/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm;

import com.dataiku.common.audit.AuditContextBase;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.resourceusage.CurrentComputeResourceUsageContext;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.audit.AuditObj;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.server.services.AuditPrivilegedClient;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonObject;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;

/**
 * Static helpers that build audit objects for LLM completion, embedding and image generation
 * queries and hand them over to the audit trail.
 *
 * <p>Every public entry point first resolves the auditing mode (taken from the connection, or
 * from the global generative AI settings when there is no connection) and does nothing when that
 * mode is {@code NONE}.
 */
public class LLMAuditHelper {
    /**
     * Emits a {@code llm-completion-query} audit entry through the given audit trail service,
     * without any saved input image path.
     *
     * @see #emitLLMCompletionAuditFromBackendIfNeeded(AuditTrailService, LLMStructuredRef, AbstractLLMConnection, LLMClient.SingleCompletionQuery, LLMClient.SimpleCompletionResponseOrError, Collection)
     */
    public static void emitLLMCompletionAuditFromBackendIfNeeded(AuditTrailService auditTrailService, LLMStructuredRef llmRef, @Nullable AbstractLLMConnection connection, LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError response) {
        emitLLMCompletionAuditFromBackendIfNeeded(auditTrailService, llmRef, connection, query, response, Collections.emptyList());
    }

    /**
     * Emits a {@code llm-completion-query} audit entry through the given audit trail service,
     * unless the auditing mode is {@code NONE}.
     *
     * @param auditTrailService service used to create and emit the audit object
     * @param llmRef reference of the queried LLM
     * @param connection connection of the LLM, or {@code null} to use the global settings
     * @param query completion query; the text of its messages is joined with newlines to form the prompt
     * @param response completion response (or error) to report
     * @param savedInputImagePaths paths reported under {@code savedInputImagePaths};
     *        {@code null} is treated like an empty collection
     */
    public static void emitLLMCompletionAuditFromBackendIfNeeded(AuditTrailService auditTrailService, LLMStructuredRef llmRef, @Nullable AbstractLLMConnection connection, LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError response, Collection<String> savedInputImagePaths) {
        if (isAuditingDisabled(connection)) {
            return;
        }
        String prompt = buildPromptText(query);
        AuditTrailService.EmittableAuditObj obj = auditTrailService.generic("llm-completion-query");
        fillLLMCompletionAudit((AuditObj)obj, llmRef, connection, prompt, response, savedInputImagePaths);
        obj.emit();
    }

    /**
     * Publishes a {@code llm-completion-query} audit entry through the given audit client,
     * without any saved input image path.
     *
     * @see #emitLLMCompletionAuditFromJobIfNeeded(AuthCtx, AuditPrivilegedClient, LLMStructuredRef, AbstractLLMConnection, LLMClient.SingleCompletionQuery, LLMClient.SimpleCompletionResponseOrError, Collection)
     */
    public static void emitLLMCompletionAuditFromJobIfNeeded(AuthCtx authCtx, AuditPrivilegedClient auditClient, LLMStructuredRef llmRef, @Nullable AbstractLLMConnection llmConnection, LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError response) {
        emitLLMCompletionAuditFromJobIfNeeded(authCtx, auditClient, llmRef, llmConnection, query, response, Collections.emptyList());
    }

    /**
     * Publishes a {@code llm-completion-query} audit entry through the given audit client,
     * unless the auditing mode is {@code NONE}. The audit object is created directly (not through
     * an {@link AuditTrailService}) and filled with the given authentication context.
     *
     * @param authCtx authentication context written into the audit object
     * @param auditClient client used to publish the audit object
     * @param llmRef reference of the queried LLM
     * @param llmConnection connection of the LLM, or {@code null} to use the global settings
     * @param query completion query; the text of its messages is joined with newlines to form the prompt
     * @param response completion response (or error) to report
     * @param savedInputImagePaths paths reported under {@code savedInputImagePaths};
     *        {@code null} is treated like an empty collection
     */
    public static void emitLLMCompletionAuditFromJobIfNeeded(AuthCtx authCtx, AuditPrivilegedClient auditClient, LLMStructuredRef llmRef, @Nullable AbstractLLMConnection llmConnection, LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError response, Collection<String> savedInputImagePaths) {
        if (isAuditingDisabled(llmConnection)) {
            return;
        }
        String prompt = buildPromptText(query);
        AuditObj obj = new AuditObj("generic", "llm-completion-query");
        AuditContextBase.fillAuthCtx(authCtx, (JsonObject)obj.get());
        fillLLMCompletionAudit(obj, llmRef, llmConnection, prompt, response, savedInputImagePaths);
        auditClient.publish(obj);
    }

    /** Joins the text of every message of the query, one message per line. */
    private static String buildPromptText(LLMClient.SingleCompletionQuery query) {
        return query.messages.stream()
                .map(message -> message.getTextEvenIfNotTextOnly())
                .collect(Collectors.joining("\n"));
    }

    /**
     * Fills a completion audit object: base LLM info, usage metrics, optional prompt/response
     * (only in {@code FULL_DATA} mode), saved image paths, error details, origin context, image
     * audit folder settings, and optionally the trace and guardrails data.
     */
    private static void fillLLMCompletionAudit(AuditObj obj, LLMStructuredRef enrichedLLMRef, @Nullable AbstractLLMConnection connection, String prompt, LLMClient.SimpleCompletionResponseOrError response, @Nullable Collection<String> savedInputImagePaths) {
        fillBaseLLMInfo(obj, enrichedLLMRef);
        fillCompletionUsage(obj, response);
        if (getAuditingMode(connection) == AbstractLLMConnection.LLMQueriesAuditingMode.FULL_DATA) {
            obj.with("prompt", prompt).with("response", response.text);
        }
        // A null collection is treated as "no saved image" instead of failing the whole audit entry.
        if (savedInputImagePaths != null && !savedInputImagePaths.isEmpty()) {
            obj.with("savedInputImagePaths", JSON.toJsonArray(savedInputImagePaths));
        }
        fillCompletionError(obj, response);
        fillAuditObjFromCRUContext(obj);
        fillImageAuditFolderRef(obj, connection);
        boolean auditTrace = shouldAuditCompletionTraces(connection);
        if (response.trace != null && auditTrace) {
            obj.with("trace", JSON.toJsonObject((Object)response.trace));
        }
        if (response.guardrailsAuditData != null) {
            obj.with("guardrailsAuditData", response.guardrailsAuditData);
        }
    }

    /** Reports token counts, estimated cost and cache status of a completion response. */
    private static void fillCompletionUsage(AuditObj obj, LLMClient.SimpleCompletionResponseOrError response) {
        if (response.promptTokens != null) {
            obj.with("promptTokens", (Number)response.promptTokens);
        }
        if (response.completionTokens != null) {
            obj.with("completionTokens", (Number)response.completionTokens);
        }
        if (response.totalTokens != null) {
            obj.with("totalTokens", (Number)response.totalTokens);
        }
        if (Boolean.TRUE.equals(response.tokenCountsAreEstimated)) {
            obj.with("tokenCountsAreEstimated", true);
        }
        // The estimated cost is not reported for responses served from the cache.
        if (response.estimatedCost != null && !response.fromCache) {
            obj.with("estimatedCostUSD", (Number)response.estimatedCost);
        }
        obj.with("fromCache", response.fromCache);
    }

    /** Reports the error message, source and code of a completion response, when present. */
    private static void fillCompletionError(AuditObj obj, LLMClient.SimpleCompletionResponseOrError response) {
        if (StringUtils.isNotBlank(response.errorMessage)) {
            obj.with("errorMessage", response.errorMessage);
        }
        if (response.errorSource != null) {
            obj.with("errorSource", response.errorSource.toString());
        }
        if (StringUtils.isNotBlank(response.errorCode)) {
            obj.with("errorCode", response.errorCode);
        }
    }

    /**
     * Emits a {@code llm-embedding-query} audit entry through the given audit trail service,
     * unless the auditing mode is {@code NONE}.
     *
     * @param auditTrailService service used to create and emit the audit object
     * @param enrichedLLMRef reference of the queried LLM
     * @param connection connection of the LLM; when {@code null} the auditing mode comes from the global settings
     * @param llmUsageType usage type reported under {@code llmUsageType}
     * @param text embedded text, only reported in {@code FULL_DATA} mode and when not {@code null}
     * @param response embedding response (or error) to report
     */
    public static void emitLLMEmbeddingAudit(AuditTrailService auditTrailService, EnrichedLLMStructuredRef enrichedLLMRef, AbstractLLMConnection connection, ComputeResourceUsage.LLMUsageType llmUsageType, @Nullable String text, LLMClient.SimpleEmbeddingResponseOrError response) {
        if (isAuditingDisabled(connection)) {
            return;
        }
        AuditTrailService.EmittableAuditObj obj = auditTrailService.generic("llm-embedding-query");
        fillLLMEmbeddingAudit((AuditObj)obj, enrichedLLMRef, connection, llmUsageType, text, response);
        obj.emit();
    }

    /**
     * Fills an embedding audit object: base LLM info, usage type, optional text (only in
     * {@code FULL_DATA} mode), origin context and guardrails data.
     */
    private static void fillLLMEmbeddingAudit(AuditObj obj, LLMStructuredRef enrichedLLMRef, @Nullable AbstractLLMConnection connection, ComputeResourceUsage.LLMUsageType llmUsageType, @Nullable String text, LLMClient.SimpleEmbeddingResponseOrError response) {
        fillBaseLLMInfo(obj, enrichedLLMRef).with("llmUsageType", llmUsageType.toString());
        boolean fullData = getAuditingMode(connection) == AbstractLLMConnection.LLMQueriesAuditingMode.FULL_DATA;
        if (fullData && text != null) {
            obj.with("text", text);
        }
        fillAuditObjFromCRUContext(obj);
        if (response.guardrailsAuditData != null) {
            obj.with("guardrailsAuditData", response.guardrailsAuditData);
        }
    }

    /**
     * Emits a {@code llm-image-generation-query} audit entry without input image path and with
     * no output image path.
     *
     * @see #emitLLMImageGenerationAudit(AuditTrailService, LLMStructuredRef, AbstractLLMConnection, LLMClient.ImageGenerationQuery, LLMClient.ImageGenerationResponseOrError, String, List)
     */
    public static void emitLLMImageGenerationAudit(AuditTrailService auditTrailService, LLMStructuredRef enrichedLLMRef, AbstractLLMConnection connection, LLMClient.ImageGenerationQuery query, LLMClient.ImageGenerationResponseOrError response) {
        emitLLMImageGenerationAudit(auditTrailService, enrichedLLMRef, connection, query, response, null, Collections.emptyList());
    }

    /**
     * Emits a {@code llm-image-generation-query} audit entry through the given audit trail
     * service, unless the auditing mode is {@code NONE}.
     *
     * @param auditTrailService service used to create and emit the audit object
     * @param enrichedLLMRef reference of the queried LLM
     * @param connection connection of the LLM; when {@code null} the global settings are used
     * @param query image generation query; its audited form is obtained with {@code query.audit(inputImagePath)}
     * @param response image generation response (or error) to report
     * @param inputImagePath value handed to {@code query.audit(...)}, may be {@code null}
     * @param outputImagePaths value handed to {@code response.audit(...)} when there is no error message
     */
    public static void emitLLMImageGenerationAudit(AuditTrailService auditTrailService, LLMStructuredRef enrichedLLMRef, AbstractLLMConnection connection, LLMClient.ImageGenerationQuery query, LLMClient.ImageGenerationResponseOrError response, @Nullable String inputImagePath, List<String> outputImagePaths) {
        if (isAuditingDisabled(connection)) {
            return;
        }
        AuditTrailService.EmittableAuditObj obj = auditTrailService.generic("llm-image-generation-query");
        fillLLMImageGenerationAudit((AuditObj)obj, enrichedLLMRef, connection, query, response, inputImagePath, outputImagePaths);
        obj.emit();
    }

    /**
     * Fills an image generation audit object: base LLM info, estimated cost (when strictly
     * positive), audited query, either the error message or the audited response, image audit
     * folder settings, origin context, and optionally the trace and guardrails data.
     */
    private static void fillLLMImageGenerationAudit(AuditObj obj, LLMStructuredRef enrichedLLMRef, @Nullable AbstractLLMConnection connection, LLMClient.ImageGenerationQuery query, LLMClient.ImageGenerationResponseOrError response, @Nullable String inputImagePath, List<String> outputImagePaths) {
        fillBaseLLMInfo(obj, enrichedLLMRef);
        if (response.estimatedCost > 0.0) {
            obj.with("estimatedCostUSD", (Number)response.estimatedCost);
        }
        LLMClient.ImageGenerationQuery auditedQuery = query.audit(inputImagePath);
        obj.with("query", JSON.toJsonObject((Object)auditedQuery));
        // The response is only reported when there is no error message to report instead.
        if (StringUtils.isNotBlank(response.errorMessage)) {
            obj.with("errorMessage", response.errorMessage);
        } else {
            LLMClient.ImageGenerationResponse auditResponse = response.audit(outputImagePaths);
            obj.with("response", JSON.toJsonObject((Object)auditResponse));
        }
        fillImageAuditFolderRef(obj, connection);
        fillAuditObjFromCRUContext(obj);
        boolean auditTrace = shouldAuditImageGenerationTraces(connection);
        if (response.trace != null && auditTrace) {
            obj.with("trace", JSON.toJsonObject((Object)response.trace));
        }
        if (response.guardrailsAuditData != null) {
            obj.with("guardrailsAuditData", response.guardrailsAuditData);
        }
    }

    /** Tells whether the auditing mode resolved for the given connection is {@code NONE}. */
    private static boolean isAuditingDisabled(@Nullable AbstractLLMConnection connection) {
        return getAuditingMode(connection) == AbstractLLMConnection.LLMQueriesAuditingMode.NONE;
    }

    /**
     * Resolves the auditing mode: the connection's own setting, or the global setting for LLMs
     * without connection when {@code connection} is {@code null}.
     */
    private static AbstractLLMConnection.LLMQueriesAuditingMode getAuditingMode(@Nullable AbstractLLMConnection connection) {
        if (connection == null) {
            return ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN().generativeAISettings.globalLLMAuditSettings.auditingModeForLLMsWithoutConnection;
        }
        return connection.getLLMConnectionParams().auditingMode;
    }

    /**
     * Resolves whether completion traces must be audited: the connection's own setting, or the
     * global setting for LLMs without connection when {@code connection} is {@code null}.
     */
    private static boolean shouldAuditCompletionTraces(@Nullable AbstractLLMConnection connection) {
        if (connection == null) {
            return ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN().generativeAISettings.globalLLMAuditSettings.auditCompletionTracesForLLMsWithoutConnection;
        }
        return connection.getLLMConnectionParams().auditCompletionTraces;
    }

    /**
     * Resolves whether image generation traces must be audited: the connection's own setting, or
     * the global setting for LLMs without connection when {@code connection} is {@code null}.
     */
    private static boolean shouldAuditImageGenerationTraces(@Nullable AbstractLLMConnection connection) {
        if (connection == null) {
            return ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN().generativeAISettings.globalLLMAuditSettings.auditImageGenerationTracesForLLMsWithoutConnection;
        }
        return connection.getLLMConnectionParams().auditImageGenerationTraces;
    }

    /**
     * Writes the LLM id, connection, type and model name into the audit object.
     *
     * @return the audit object returned by the last {@code with} call, so calls can be chained
     */
    private static AuditObj fillBaseLLMInfo(AuditObj obj, LLMStructuredRef enrichedLLMRef) {
        return obj.with("llmId", enrichedLLMRef.encodeToId())
                .with("llmConnection", enrichedLLMRef.connection)
                .with("llmType", enrichedLLMRef.type.toString())
                .with("llmModel", enrichedLLMRef.getModelNameForAudit());
    }

    /**
     * Writes the origin of the query (job activity or prompt studio) into the audit object, based
     * on the current compute resource usage context. Nothing is written when there is no current
     * context or when its type is neither of those two.
     */
    private static void fillAuditObjFromCRUContext(AuditObj obj) {
        ComputeResourceUsageContext cruContext = CurrentComputeResourceUsageContext.get();
        if (cruContext == null) {
            return;
        }
        switch (cruContext.type) {
            case JOB_ACTIVITY: {
                obj.with("originContext", "JOB_ACTIVITY")
                        .with("projectKey", cruContext.projectKey)
                        .with("jobId", cruContext.jobId)
                        .with("jobActivityId", cruContext.activityId);
                break;
            }
            case PROMPT_STUDIO: {
                obj.with("originContext", "PROMPT_STUDIO")
                        .with("projectKey", cruContext.projectKey)
                        .with("promptStudioId", cruContext.promptStudioId);
                break;
            }
            default: {
                // Other context types do not contribute any origin information.
                break;
            }
        }
    }

    /**
     * Writes the connection's image storage settings ({@code storeImages} and
     * {@code imageAuditManagedFolderRef}) into the audit object. Does nothing without connection.
     */
    private static void fillImageAuditFolderRef(AuditObj obj, @Nullable AbstractLLMConnection connection) {
        if (connection == null) {
            return;
        }
        obj.with("storeImages", connection.getLLMConnectionParams().storeImages);
        obj.with("imageAuditManagedFolderRef", connection.getLLMConnectionParams().imageAuditManagedFolderRef);
    }
}

