/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance.toxicity;

import com.dataiku.dip.connections.ConnectionUtils;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.OpenAIConnection;
import com.dataiku.dip.llm.governance.GuardrailsCodes;
import com.dataiku.dip.llm.governance.GuardrailsPipelineRunner;
import com.dataiku.dip.llm.governance.SimpleNonModifyingGuardrailDetector;
import com.dataiku.dip.llm.governance.toxicity.ToxicityDetectionGuardrail;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.openai.RawOpenAIClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.model.ICredentialsService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.Params;
import java.io.IOException;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;

public class OpenAIToxicityDetectionPipelineElement
implements SimpleNonModifyingGuardrailDetector {
    private RawOpenAIClient rawClient;
    private static DKULogger logger = DKULogger.getLogger((String)"dku.llm.governance.toxicity");

    public OpenAIToxicityDetectionPipelineElement(AuthCtx authCtx, ToxicityDetectionGuardrail.Params settings, @Nullable String projectKey) throws Exception {
        SpringUtils.getInstance().autowire((Object)this);
        if (StringUtils.isBlank((CharSequence)settings.openAIConnectionName)) {
            throw new IllegalArgumentException("This LLM connection uses a Toxicity detector that is missing its OpenAI connection name");
        }
        OpenAIConnection conn = (OpenAIConnection)ConnectionsDAO.get().getMandatoryConnection(authCtx, settings.openAIConnectionName);
        ICredentialsService.BasicCredential creds = conn.getFullyResolvedCredentials_sqlLike(new ConnectionWithBasicCredential.CredentialResolutionContext(authCtx, null), ICredentialsService.BasicCredential.class);
        Params connectionProperties = ConnectionUtils.getParamsFromProperties(conn.getDkuProperties());
        boolean trustAllSSLCertificates = connectionProperties.getBoolParam("dku.connection.llm.trustAllSSLCertificates", false);
        boolean forceContentLength = connectionProperties.getBoolParam("dku.connection.llm.forceContentLength", false);
        this.rawClient = RawOpenAIClient.forOpenAI(conn.params.customURL, creds.password, conn.params.organizationId, conn.params.customHeaders, projectKey, conn.params.networkSettings, conn.getProxySettings(), trustAllSSLCertificates, conn.useMaxCompletionToken(), forceContentLength, authCtx, conn);
    }

    @Override
    public void close() throws IOException {
        try {
            this.rawClient.close();
        }
        catch (Exception e) {
            throw new IOException("Failed to close LLM client", e);
        }
    }

    @Override
    public void processCompletionQuery(LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws IOException, GuardrailsPipelineRunner.LLMUsageEnforcerException {
        String content = query.messages.stream().map(m -> m.getTextEvenIfNotTextOnly()).collect(Collectors.joining("\n"));
        this.checkToxicity(content, true, "query", trace);
    }

    @Override
    public void processCompletionResponse(LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError response, LLMClient.LLMMeshTraceSpan trace) throws IOException, GuardrailsPipelineRunner.LLMUsageEnforcerException {
        this.checkToxicity(response.text, false, "response", trace);
    }

    @Override
    public void processEmbeddingQuery(LLMClient.EmbeddingQuery query, LLMClient.LLMMeshTraceSpan trace) throws GuardrailsPipelineRunner.LLMUsageEnforcerException, IOException {
        this.checkToxicity(query.text, true, "query", trace);
    }

    @Override
    public void processImageGenerationQuery(LLMClient.ImageGenerationQuery query, LLMClient.LLMMeshTraceSpan trace) throws IOException, GuardrailsPipelineRunner.LLMUsageEnforcerException {
        this.checkToxicity(query.getConcatenatedPrompts(), true, "query", trace);
        this.checkToxicity(query.getConcatenatedNegativePrompts(), true, "query", trace);
    }

    private void checkToxicity(String content, boolean isQuery, String contentSource, LLMClient.LLMMeshTraceSpan trace) throws IOException, GuardrailsPipelineRunner.LLMUsageEnforcerException {
        logger.info((Object)("Detecting toxicity on " + content));
        if (StringUtils.isBlank((CharSequence)content)) {
            logger.info((Object)String.format("Skipping Toxicity detection as %s is empty", contentSource));
            return;
        }
        try (LLMClient.LLMMeshTraceSpan callSpan = trace.withChildSpan("DKU_LLM_MESH_TOXICITY_OPENAI_CALL");){
            RawOpenAIClient.SimpleModerationResponse resp = this.rawClient.moderate(content);
            if (resp.flagged) {
                throw new GuardrailsPipelineRunner.LLMUsageEnforcerException(isQuery ? LLMClient.LLMResponseErrorSource.QUERY_TOXICITY_DETECTION : LLMClient.LLMResponseErrorSource.RESPONSE_TOXICITY_DETECTION, isQuery ? GuardrailsCodes.ERR_LLM_QUERY_TOXIC : GuardrailsCodes.ERR_LLM_RESPONSE_TOXIC, String.format("LLM %s denied: flagged by content moderation: %s", contentSource, JSON.json(resp.flaggedCategories)));
            }
        }
    }
}

