/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance;

import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.cuspol.CustomPolicyHooksRegistry;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.llm.governance.GuardrailMeta;
import com.dataiku.dip.llm.governance.GuardrailRunner;
import com.dataiku.dip.llm.governance.GuardrailsBypassTokensService;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsRegistry;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.util.JsonUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.DKUtils;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonArray;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import javax.annotation.Nullable;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Runs the chain of guardrails configured in a {@link GuardrailsPipelineSettings} over LLM
 * queries and responses.
 *
 * <p>Construction builds and initializes one {@link GuardrailRunner} per enabled pipeline element
 * and obtains a bypass token from {@link GuardrailsBypassTokensService}. The token is handed to
 * every runner when it is built. {@link #close()} invalidates that token and closes every runner,
 * so instances should be used in a try-with-resources block.
 *
 * <p>NOTE(review): several methods below could not be decompiled (CFR "Started 2 blocks at
 * once"). Their bodies are stubs that always throw {@link IllegalStateException}.
 */
public class GuardrailsPipelineRunner implements AutoCloseable {
    @Autowired
    private CustomPolicyHooksRegistry customPolicyHooksRegistry;
    @Autowired
    private GuardrailsBypassTokensService bypassTokensService;
    /** Token obtained at construction, passed to each runner, invalidated in {@link #close()}. */
    private final String bypassToken;
    private final AuthCtx authCtx;
    private final String contextProjectKey;
    /** Pipeline settings as given to the constructor; may be {@code null}. */
    private final GuardrailsPipelineSettings settings;
    /** Enabled runners, in the order they appear in {@code settings.guardrails}. */
    private final List<GuardrailRunnerWithContext> runners = new ArrayList<>();
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.guardrails.runner");

    /**
     * Builds and initializes a runner for every enabled element of {@code settings}.
     *
     * <p>Dependencies are autowired through {@link SpringUtils} before the bypass token is
     * requested. If building or initializing any runner fails, the bypass token is invalidated and
     * the runners registered so far are closed before the failure is rethrown. This matters
     * because the caller never obtains an instance it could close itself.
     *
     * @param authCtx authentication context passed to every runner
     * @param contextProjectKey project key passed to every runner
     * @param settings pipeline settings; {@code null} means no guardrail is run
     * @throws Exception if autowiring, token creation, or building/initializing a runner fails
     */
    public GuardrailsPipelineRunner(AuthCtx authCtx, String contextProjectKey, @Nullable GuardrailsPipelineSettings settings) throws Exception {
        this.authCtx = authCtx;
        this.settings = settings;
        this.contextProjectKey = contextProjectKey;
        SpringUtils.getInstance().autowire((Object) this);
        this.bypassToken = this.bypassTokensService.newBypassToken();
        if (settings == null) {
            return;
        }
        try {
            for (GuardrailsPipelineSettings.GuardrailsPipelineElement elt : settings.guardrails) {
                if (!elt.enabled) {
                    continue;
                }
                GuardrailMeta meta = GuardrailsRegistry.getMeta(elt.type);
                GuardrailRunner runner = meta.buildRunner(authCtx, contextProjectKey, elt, this.bypassToken);
                GuardrailRunnerWithContext grwc = new GuardrailRunnerWithContext(elt, meta.getFlags(elt), runner);
                logger.info("Adding runner " + grwc.elt.type + " flags=" + JSON.json(grwc.flags));
                // Registered before init() so that a runner whose init() fails is still closed.
                this.runners.add(grwc);
                grwc.runner.init();
            }
        } catch (Throwable t) {
            // Nobody can call close() on a half-built instance: release resources here.
            try {
                releaseResources();
            } catch (Throwable cleanupFailure) {
                t.addSuppressed(cleanupFailure);
            }
            throw t;
        }
    }

    /**
     * Invalidates the bypass token and closes every runner.
     *
     * <p>Runners are closed even if invalidating the token throws. That exception is propagated
     * once the runners have been closed.
     */
    @Override
    public void close() throws Exception {
        releaseResources();
    }

    /**
     * Shared cleanup for {@link #close()} and the constructor's failure path. It is private so
     * the constructor never calls an overridable method.
     */
    private void releaseResources() throws Exception {
        try {
            this.bypassTokensService.invalidateBypassToken(this.bypassToken);
        } finally {
            for (GuardrailRunnerWithContext r : this.runners) {
                // NOTE(review): closeACWithLog presumably logs rather than throws on close failure — confirm.
                DKUtils.closeACWithLog((AutoCloseable) r.runner, (String) r.elt.type);
            }
        }
    }

    /**
     * Runs the completion query through the guardrail pipeline.
     *
     * <p>NOTE(review): the decompiler could not recover this method. As written it always throws
     * {@link IllegalStateException}.
     */
    public GuardrailRunner.CompletionQueryGuardrailResponse processCompletionQuery(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.LLMMeshTraceSpan rootSpan) throws Exception {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Started 2 blocks at once
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.getStartingBlocks(Op04StructuredStatement.java:412)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:487)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Runs a completion response (or error) through the guardrail pipeline.
     *
     * <p>NOTE(review): the decompiler could not recover this method. As written it always throws
     * {@link IllegalStateException}.
     */
    public GuardrailRunner.CompletionResponseGuardrailResponse processCompletionResponse(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError completionResponseOrError, LLMClient.CompletionSettings settings, LLMClient.LLMMeshTraceSpan rootSpan) throws Exception {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Started 2 blocks at once
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.getStartingBlocks(Op04StructuredStatement.java:412)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:487)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Wraps {@code underlying} with the streamed-response handler of every runner flagged
     * {@code OPERATES_ON_RESPONSES}, in pipeline order. Each handler is built around the one
     * produced for the previous matching runner, so the handler returned is the one built by the
     * last matching runner. With no matching runner, {@code underlying} is returned unchanged.
     *
     * <p>The {@code CAN_STREAM_RESPONSES} requirement is checked only by an {@code assert}. It is
     * therefore not enforced when assertions are disabled. NOTE(review): callers presumably check
     * streamability beforehand — confirm.
     */
    public LLMClient.StreamedCompletionResponseConsumer newStreamedCompletionResponseHandler(LLMClient.StreamedCompletionResponseConsumer underlying, GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        LLMClient.StreamedCompletionResponseConsumer cur = underlying;
        for (GuardrailRunnerWithContext r : this.runners) {
            if (!r.flags.contains(GuardrailMeta.GuardrailFlag.OPERATES_ON_RESPONSES)) {
                continue;
            }
            assert r.flags.contains(GuardrailMeta.GuardrailFlag.CAN_STREAM_RESPONSES);
            cur = r.runner.newStreamedCompletionResponseHandler(cur, context, query, trace);
        }
        return cur;
    }

    /**
     * Runs an embedding query through every guardrail in pipeline order.
     *
     * <p>Before any guardrail runs, the custom policy hooks are notified. Exceptions from that
     * call or from creating a child span propagate to the caller. Each guardrail then runs under
     * its own child span {@code GUARDRAIL_<type>_<index>}, which is closed exactly once. The
     * outcome depends on the guardrail's response:
     * <ul>
     *   <li>{@code FAIL}: that response is returned immediately.</li>
     *   <li>{@code PASS_WITH_AUDIT}: its audit data is accumulated.</li>
     *   <li>{@code RESPOND}: rejected for embeddings (turned into a failure below).</li>
     *   <li>Otherwise, the context and query it returns are fed to the next guardrail.</li>
     * </ul>
     * Any exception thrown while a guardrail runs is logged and converted into a {@code fail}
     * response.
     *
     * @param settings not read by this method
     * @return the last guardrail's response carrying all accumulated audit data, a {@code pass}
     *     response when no guardrail is configured, or the first failing response
     */
    public GuardrailRunner.EmbeddingQueryGuardrailResponse processEmbeddingQuery(GuardrailRunner.GuardrailContext context, LLMClient.EmbeddingQuery query, LLMClient.EmbeddingSettings settings, LLMClient.LLMMeshTraceSpan rootSpan) throws Exception {
        // NOTE(review): presumably warns when a DB transaction is attached to this thread — confirm.
        TransactionContext.warnAttachedTransaction();
        this.customPolicyHooksRegistry.onAboutToSendLLMQuery(this.authCtx, null, this.contextProjectKey, null, query);
        JsonArray aggregatedAuditData = null;
        GuardrailRunner.EmbeddingQueryGuardrailResponse lastEnforcerResponse = null;
        int i = 0;
        for (GuardrailRunnerWithContext runner : this.runners) {
            try (LLMClient.LLMMeshTraceSpan span = rootSpan.withChildSpan("GUARDRAIL_" + runner.elt.type + "_" + i)) {
                i++;
                try {
                    logger.info("Processing query with guardrail " + runner.elt.type);
                    lastEnforcerResponse = runner.runner.processEmbeddingQuery(context, query, span);
                    logger.debug("Guardrail " + runner.elt.type + " response:" + lastEnforcerResponse.action);
                    switch (lastEnforcerResponse.action) {
                        case PASS:
                            break;
                        case FAIL:
                            return lastEnforcerResponse;
                        case PASS_WITH_AUDIT:
                            aggregatedAuditData = JsonUtils.appendIntoJsonArray(aggregatedAuditData, lastEnforcerResponse.auditData);
                            break;
                        case RESPOND:
                            // Caught just below and reported as a guardrail failure.
                            throw new IllegalArgumentException("An embedding guardrail can't return RESPOND");
                    }
                    context = lastEnforcerResponse.context;
                    query = lastEnforcerResponse.embeddingQuery;
                } catch (Exception e) {
                    logger.warn("Guardrail failed, rejecting the query", e);
                    return GuardrailRunner.EmbeddingQueryGuardrailResponse.fail(context, e);
                }
            }
        }
        if (lastEnforcerResponse == null) {
            lastEnforcerResponse = GuardrailRunner.EmbeddingQueryGuardrailResponse.pass(context, query);
        }
        if (aggregatedAuditData != null) {
            lastEnforcerResponse.auditData = aggregatedAuditData;
        }
        return lastEnforcerResponse;
    }

    /**
     * Runs an image-generation query through every guardrail in pipeline order.
     *
     * <p>Each guardrail runs under its own child span {@code GUARDRAIL_<type>_<index>}, which is
     * closed exactly once. Exceptions from creating a span propagate to the caller. The outcome
     * depends on the guardrail's response:
     * <ul>
     *   <li>{@code FAIL}: that response is returned immediately.</li>
     *   <li>{@code PASS_WITH_AUDIT}: its audit data is accumulated.</li>
     *   <li>{@code RESPOND}: rejected for image generation (turned into a failure below).</li>
     *   <li>Otherwise, the context and query it returns are fed to the next guardrail.</li>
     * </ul>
     * Any exception thrown while a guardrail runs is logged and converted into a {@code fail}
     * response.
     *
     * <p>NOTE(review): unlike {@link #processEmbeddingQuery}, this method does not call
     * {@code customPolicyHooksRegistry.onAboutToSendLLMQuery}. Confirm this is intentional.
     *
     * @return the last guardrail's response carrying all accumulated audit data, a {@code pass}
     *     response when no guardrail is configured, or the first failing response
     */
    public GuardrailRunner.ImageGenerationQueryGuardrailResponse processImageGenerationQuery(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.LLMMeshTraceSpan rootSpan) throws Exception {
        // NOTE(review): presumably warns when a DB transaction is attached to this thread — confirm.
        TransactionContext.warnAttachedTransaction();
        JsonArray aggregatedAuditData = null;
        GuardrailRunner.ImageGenerationQueryGuardrailResponse lastEnforcerResponse = null;
        int i = 0;
        for (GuardrailRunnerWithContext runner : this.runners) {
            try (LLMClient.LLMMeshTraceSpan span = rootSpan.withChildSpan("GUARDRAIL_" + runner.elt.type + "_" + i)) {
                i++;
                try {
                    logger.info("Processing query with guardrail " + runner.elt.type);
                    lastEnforcerResponse = runner.runner.processImageGenerationQuery(context, query, span);
                    logger.debug("Guardrail " + runner.elt.type + " response:" + lastEnforcerResponse.action);
                    switch (lastEnforcerResponse.action) {
                        case PASS:
                            break;
                        case FAIL:
                            return lastEnforcerResponse;
                        case PASS_WITH_AUDIT:
                            aggregatedAuditData = JsonUtils.appendIntoJsonArray(aggregatedAuditData, lastEnforcerResponse.auditData);
                            break;
                        case RESPOND:
                            // Caught just below and reported as a guardrail failure.
                            throw new IllegalArgumentException("An image generation guardrail can't return RESPOND");
                    }
                    context = lastEnforcerResponse.context;
                    query = lastEnforcerResponse.imageGenerationQuery;
                } catch (Exception e) {
                    logger.warn("Guardrail failed, rejecting the query", e);
                    return GuardrailRunner.ImageGenerationQueryGuardrailResponse.fail(context, e);
                }
            }
        }
        if (lastEnforcerResponse == null) {
            lastEnforcerResponse = GuardrailRunner.ImageGenerationQueryGuardrailResponse.pass(context, query);
        }
        if (aggregatedAuditData != null) {
            lastEnforcerResponse.auditData = aggregatedAuditData;
        }
        return lastEnforcerResponse;
    }

    /**
     * Runs an image-generation response (or error) through the guardrail pipeline.
     *
     * <p>NOTE(review): the decompiler could not recover this method. As written it always throws
     * {@link IllegalStateException}.
     */
    public GuardrailRunner.ImageGenerationResponseGuardrailResponse processImageGenerationResponse(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.ImageGenerationResponseOrError imageGenerationResponse, LLMClient.LLMMeshTraceSpan rootSpan) throws Exception {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Started 2 blocks at once
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.getStartingBlocks(Op04StructuredStatement.java:412)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:487)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    /** An enabled pipeline element together with its flags and its built runner. */
    private static final class GuardrailRunnerWithContext {
        final GuardrailsPipelineSettings.GuardrailsPipelineElement elt;
        final EnumSet<GuardrailMeta.GuardrailFlag> flags;
        final GuardrailRunner runner;

        GuardrailRunnerWithContext(GuardrailsPipelineSettings.GuardrailsPipelineElement elt, EnumSet<GuardrailMeta.GuardrailFlag> flags, GuardrailRunner runner) {
            this.elt = elt;
            this.flags = flags;
            this.runner = runner;
        }
    }

    /**
     * Kind of request.
     *
     * <p>NOTE(review): the compiled class declares no constants (its synthetic {@code $values()}
     * returned an empty array). This is the compilable equivalent of the decompiler's desugared
     * {@code extends Enum} form.
     */
    public static enum RequestKind {
    }

    /** Coded exception that also records the {@link LLMClient.LLMResponseErrorSource} of the error. */
    public static class LLMUsageEnforcerException
    extends CodedException {
        public LLMClient.LLMResponseErrorSource errorSource;

        public LLMUsageEnforcerException(LLMClient.LLMResponseErrorSource errorSource, InfoMessage.MessageCode code, String message) {
            super(code, message);
            this.errorSource = errorSource;
        }
    }
}

