/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance;

import com.dataiku.dip.cluster.ClusterSelector;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.license.LicenseStatusService;
import com.dataiku.dip.llm.LLMAuditHelper;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.governance.GuardrailMeta;
import com.dataiku.dip.llm.governance.GuardrailRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsRegistry;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.logging.MainLoggingConfigurator;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsageReportingService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.AuditPrivilegedClient;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.server.services.licensing.AbstractLicenseFeaturesStatusBuilder;
import com.dataiku.dip.server.services.licensing.LicenseFeaturesStatusBuilder;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import java.io.IOException;
import java.util.EnumSet;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;

/**
 * Static helpers around guardrails pipelines: inspecting pipeline flags, resolving and merging
 * connection-level and LLM-level settings, copying guardrail outputs back onto LLM queries and
 * responses, and reporting audit / compute-resource usage for an {@link LLMClient}.
 */
public class GuardrailsPipelineUtils {
    /**
     * Tells whether the pipeline contains a guardrail that requires non-parallel processing.
     *
     * <p>That is the case when any enabled guardrail either operates on responses and may
     * request a retry on them, or operates on queries and may respond to them directly.
     *
     * @param settings pipeline settings; {@code null} is treated as "no guardrails"
     * @return {@code true} if at least one enabled guardrail matches one of the conditions above
     */
    public static boolean needsNonParallelProcessing(@Nullable GuardrailsPipelineSettings settings) {
        if (settings == null) {
            return false;
        }
        for (GuardrailsPipelineSettings.GuardrailsPipelineElement elt : settings.guardrails) {
            if (!elt.enabled) {
                continue;
            }
            EnumSet<GuardrailMeta.GuardrailFlag> flags = GuardrailsRegistry.getMeta(elt.type).getFlags(elt);
            boolean mayRetryResponses = flags.contains(GuardrailMeta.GuardrailFlag.OPERATES_ON_RESPONSES)
                    && flags.contains(GuardrailMeta.GuardrailFlag.MAY_REQUEST_RETRY_ON_RESPONSES);
            boolean mayAnswerQueriesDirectly = flags.contains(GuardrailMeta.GuardrailFlag.OPERATES_ON_QUERIES)
                    && flags.contains(GuardrailMeta.GuardrailFlag.MAY_RESPOND_DIRECTLY_TO_QUERIES);
            if (mayRetryResponses || mayAnswerQueriesDirectly) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tells whether the pipeline contains a guardrail that requires non-streamed, non-parallel
     * processing.
     *
     * <p>An enabled guardrail that operates on responses triggers this if any of the following
     * holds:
     * <ul>
     *   <li>it lacks {@code CAN_STREAM_RESPONSES};</li>
     *   <li>it may request a retry on responses;</li>
     *   <li>it may respond directly to queries.</li>
     * </ul>
     *
     * @param settings pipeline settings; {@code null} is treated as "no guardrails"
     * @return {@code true} if at least one enabled guardrail matches the condition above
     */
    public static boolean needsNonStreamedNonParallelProcessing(@Nullable GuardrailsPipelineSettings settings) {
        if (settings == null) {
            return false;
        }
        for (GuardrailsPipelineSettings.GuardrailsPipelineElement elt : settings.guardrails) {
            // Disabled elements are skipped before their metadata is looked up.
            if (!elt.enabled) {
                continue;
            }
            EnumSet<GuardrailMeta.GuardrailFlag> flags = GuardrailsRegistry.getMeta(elt.type).getFlags(elt);
            if (!flags.contains(GuardrailMeta.GuardrailFlag.OPERATES_ON_RESPONSES)) {
                continue;
            }
            boolean streamCompatible = flags.contains(GuardrailMeta.GuardrailFlag.CAN_STREAM_RESPONSES)
                    && !flags.contains(GuardrailMeta.GuardrailFlag.MAY_REQUEST_RETRY_ON_RESPONSES)
                    && !flags.contains(GuardrailMeta.GuardrailFlag.MAY_RESPOND_DIRECTLY_TO_QUERIES);
            if (!streamCompatible) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the effective pipeline by combining the connection-level guardrails (first) with the
     * LLM-level guardrails (second).
     *
     * <p>The connection is resolved through
     * {@link LLMClientFactory#getFinalConnectionForGovernance} inside a read transaction. If no
     * connection (or no connection params) is found, only the LLM-level settings are returned.
     *
     * @param authCtx           caller's authentication context
     * @param contextProjectKey project used to resolve {@code ref}
     * @param ref               the LLM being governed
     * @return the merged settings, possibly {@code null} when neither level defines any
     * @throws IOException          propagated from connection / saved model lookups
     * @throws DKUSecurityException propagated from connection / saved model lookups
     */
    public static GuardrailsPipelineSettings getConnectionAndLLMLevelSettings(AuthCtx authCtx, String contextProjectKey, LLMStructuredRef ref) throws IOException, DKUSecurityException {
        TransactionService transactionService = (TransactionService)SpringUtils.getBean(TransactionService.class);
        AbstractLLMConnection.AbstractLLMConnectionParams params = null;
        try (Transaction t = transactionService.retrieveOrBeginRead()) {
            AbstractLLMConnection conn = LLMClientFactory.getFinalConnectionForGovernance(authCtx, contextProjectKey, ref);
            // conn is statically an AbstractLLMConnection: a null check is all the instanceof test ever did.
            if (conn != null) {
                params = conn.getLLMConnectionParams();
            }
        }
        GuardrailsPipelineSettings connectionLevel = params != null ? params.guardrailsPipelineSettings : null;
        // mergeEnforcementSettings(null, x) returns x, so this also covers the "no connection" case.
        return GuardrailsPipelineUtils.mergeEnforcementSettings(connectionLevel, GuardrailsPipelineUtils.getLLMLevelSettings(authCtx, contextProjectKey, ref));
    }

    /**
     * Returns the guardrails settings attached to the LLM itself.
     *
     * <p>Only {@code SAVED_MODEL_AGENT} references carry such settings. They are read from the
     * saved model version named by {@code ref.savedModelVersionId}, or from the saved model's
     * active version when that is {@code null}. All other reference types yield {@code null}.
     *
     * @param authCtx           caller's authentication context (not read for the supported types)
     * @param contextProjectKey project used to resolve the saved model smart id
     * @param ref               the LLM reference
     * @return the LLM-level settings, or {@code null}
     * @throws NoSuchElementException if the requested saved model version does not exist
     * @throws IOException            propagated from the saved model lookup
     * @throws DKUSecurityException   declared for callers; propagated from lookups
     */
    public static GuardrailsPipelineSettings getLLMLevelSettings(AuthCtx authCtx, String contextProjectKey, LLMStructuredRef ref) throws IOException, DKUSecurityException {
        switch (ref.type) {
            case RETRIEVAL_AUGMENTED: {
                return null;
            }
            case SAVED_MODEL_AGENT: {
                TransactionService transactionService = (TransactionService)SpringUtils.getBean(TransactionService.class);
                try (Transaction t = transactionService.retrieveOrBeginRead()) {
                    SavedModel sm = (SavedModel)((SavedModelsDAO)SpringUtils.getBean(SavedModelsDAO.class)).getMandatory(AnyLoc.resolveSmart(contextProjectKey, ref.savedModelSmartId));
                    String smVersion = ref.savedModelVersionId != null ? ref.savedModelVersionId : sm.activeVersion;
                    // Same exception type as the bare orElseThrow(), but with a message that says what is missing.
                    SavedModel.SavedModelInlineVersion smiv = sm.getVersion(smVersion).orElseThrow(
                            () -> new NoSuchElementException("Version " + smVersion + " not found in saved model " + ref.savedModelSmartId));
                    return smiv.guardrailsPipelineSettings;
                }
            }
        }
        return null;
    }

    /**
     * Concatenates two pipelines, with the guardrails of {@code firstLevel} placed before those
     * of {@code secondLevel}.
     *
     * <p>If one side is {@code null}, the other instance is returned as-is (not copied).
     * Otherwise a new settings object is created with only its {@code guardrails} list populated.
     *
     * @param firstLevel  settings whose guardrails come first (e.g. connection level)
     * @param secondLevel settings whose guardrails come second (e.g. LLM level)
     * @return merged settings, or {@code null} when both inputs are {@code null}
     */
    public static GuardrailsPipelineSettings mergeEnforcementSettings(@Nullable GuardrailsPipelineSettings firstLevel, @Nullable GuardrailsPipelineSettings secondLevel) {
        if (firstLevel == null) {
            return secondLevel;
        }
        if (secondLevel == null) {
            return firstLevel;
        }
        GuardrailsPipelineSettings result = new GuardrailsPipelineSettings();
        result.guardrails = Stream.concat(firstLevel.guardrails.stream(), secondLevel.guardrails.stream()).collect(Collectors.toList());
        return result;
    }

    /**
     * Copies the (possibly rewritten) messages and context produced by a guardrail onto
     * {@code query}, in place.
     */
    public static void updateCompletionQueryFromGuardrailsResponse(LLMClient.SingleCompletionQuery query, GuardrailRunner.CompletionQueryGuardrailResponse guardrailResp) {
        query.messages = guardrailResp.completionQuery.messages;
        query.context = guardrailResp.completionQuery.context;
    }

    /** Copies the text produced by a guardrail onto the completion response {@code resp}, in place. */
    public static void updateCompletionResponseFromGuardrailsResponse(LLMClient.SimpleCompletionResponseOrError resp, GuardrailRunner.CompletionResponseGuardrailResponse guardrailResp) {
        resp.text = guardrailResp.completionResponse.text;
    }

    /** Copies the text produced by a guardrail onto the embedding query {@code query}, in place. */
    public static void updateEmbeddingQueryFromGuardrailsResponse(LLMClient.EmbeddingQuery query, GuardrailRunner.EmbeddingQueryGuardrailResponse guardrailResp) {
        query.text = guardrailResp.embeddingQuery.text;
    }

    /** Computes the license feature status from the current licensing status. */
    public static AbstractLicenseFeaturesStatusBuilder.LicenseFeaturesStatus getLicensing() {
        LicenseStatusService licenseStatusService = (LicenseStatusService)SpringUtils.getBean(LicenseStatusService.class);
        LicenseStatusService.LicensingStatus ls = licenseStatusService.getLicensingStatus();
        return LicenseFeaturesStatusBuilder.getFeaturesStatus(ls);
    }

    /**
     * Emits per-query completion audits and, on {@link #close()}, reports the client's total
     * completion compute-resource usage.
     *
     * <p>Outside the BACKEND process type, audits go through a dedicated
     * {@link AuditPrivilegedClient} owned (and closed) by this reporter. In the backend, they go
     * directly to the {@link AuditTrailService} bean.
     */
    public static class LLMClientAuditReporter
    implements AutoCloseable {
        private final AuthCtx authCtx;
        private final LLMStructuredRef llmRef;
        private final LLMClient llmClient;
        /** Number of completion queries seen by {@link #emitLLMCompletionAuditIfNeeded}. */
        private final AtomicLong queriesCount = new AtomicLong(0L);
        /** Non-null only when not running in the BACKEND process type. */
        @Nullable
        private final AuditPrivilegedClient auditPrivilegedClient;

        public LLMClientAuditReporter(AuthCtx authCtx, LLMStructuredRef llmRef, LLMClient llmClient) {
            this.authCtx = authCtx;
            this.llmRef = llmRef;
            this.llmClient = llmClient;
            this.auditPrivilegedClient = ClusterSelector.getContext() != MainLoggingConfigurator.ProcessType.BACKEND ? new AuditPrivilegedClient() : null;
        }

        /**
         * Counts one completion query and delegates auditing to {@link LLMAuditHelper}. The
         * delegate used depends on whether this reporter runs in the backend or elsewhere.
         */
        public void emitLLMCompletionAuditIfNeeded(LLMClient.SingleCompletionQuery singleCompletionQuery, LLMClient.SimpleCompletionResponseOrError responseOrError) {
            this.queriesCount.incrementAndGet();
            if (this.auditPrivilegedClient != null) {
                LLMAuditHelper.emitLLMCompletionAuditFromJobIfNeeded(this.authCtx, this.auditPrivilegedClient, this.llmRef, this.llmClient.getConnection(), singleCompletionQuery, responseOrError);
            } else {
                LLMAuditHelper.emitLLMCompletionAuditFromBackendIfNeeded((AuditTrailService)SpringUtils.getBean(AuditTrailService.class), this.llmRef, this.llmClient.getConnection(), singleCompletionQuery, responseOrError);
            }
        }

        /**
         * Reports the client's total completion usage, if any.
         *
         * <p>Every counted query is recorded as a cache miss.
         */
        private void reportCru() {
            ComputeResourceUsage totalCRU = this.llmClient.getTotalCRU(ComputeResourceUsage.LLMUsageType.COMPLETION, this.llmRef);
            if (totalCRU != null) {
                long queries = this.queriesCount.get();
                totalCRU.llmUsage.totalQueries = queries;
                totalCRU.llmUsage.cacheHitQueries = 0L;
                totalCRU.llmUsage.cacheMissQueries = queries;
                ((ComputeResourceUsageReportingService)SpringUtils.getBean(ComputeResourceUsageReportingService.class)).reportComplete(totalCRU);
            }
        }

        /**
         * Reports usage, then closes the owned audit client (if any), even when reporting fails.
         */
        @Override
        public void close() throws Exception {
            try {
                this.reportCru();
            }
            finally {
                if (this.auditPrivilegedClient != null) {
                    this.auditPrivilegedClient.close();
                }
            }
        }
    }
}

