/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance.promptinjection;

import com.dataiku.dip.license.LicenseRestrictionException;
import com.dataiku.dip.llm.governance.GuardrailMeta;
import com.dataiku.dip.llm.governance.GuardrailParams;
import com.dataiku.dip.llm.governance.GuardrailRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.governance.SimpleNonModifyingGuardrailDetector;
import com.dataiku.dip.llm.governance.promptinjection.LLMAsAJudgePromptInjectionDetectionPipelineElement;
import com.dataiku.dip.llm.governance.promptinjection.LocalHuggingFacePromptInjectionDetectionPipeline;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.licensing.AbstractLicenseFeaturesStatusBuilder;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dss.shadelib.org.apache.commons.io.IOUtils;
import java.io.Closeable;
import java.util.EnumSet;

/**
 * Guardrail that runs a prompt-injection detector over incoming queries.
 *
 * <p>The detection backend is selected by {@link Params#engine}: a local Hugging Face classifier
 * ({@link LocalHuggingFacePromptInjectionDetectionPipeline}) or an LLM used as a judge
 * ({@link LLMAsAJudgePromptInjectionDetectionPipelineElement}). Completion and image-generation
 * queries are handed to the detector; embedding queries and all responses are passed through
 * unchanged.
 */
public class PromptInjectionDetectionGuardrail {
    /** Registration metadata for the {@code "PromptInjectionDetector"} guardrail type. */
    public static final GuardrailMeta META = new GuardrailMeta(){

        @Override
        public String getType() {
            return "PromptInjectionDetector";
        }

        @Override
        public Class<? extends GuardrailParams> paramsClass() {
            return Params.class;
        }

        /** Always declares query-only operation, independently of the element settings. */
        @Override
        public EnumSet<GuardrailMeta.GuardrailFlag> getFlags(GuardrailsPipelineSettings.GuardrailsPipelineElement elt) {
            return EnumSet.of(GuardrailMeta.GuardrailFlag.OPERATES_ON_QUERIES);
        }

        /** Builds a runner for one pipeline element; {@code bypassToken} is not used by this guardrail. */
        @Override
        public GuardrailRunner buildRunner(AuthCtx authCtx, String projectKey, GuardrailsPipelineSettings.GuardrailsPipelineElement elt, String bypassToken) throws Exception {
            return new Runner(authCtx, projectKey, elt);
        }
    };
    // NOTE(review): not referenced anywhere in this file.
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.guardrails.promptinjection");

    /**
     * Per-element runner. The detector is created in {@link #init()} (after the license check)
     * and released in {@link #close()}.
     */
    private static class Runner
    extends GuardrailRunner {
        private final AuthCtx authCtx;
        private final String projectKey;
        private final GuardrailsPipelineSettings.GuardrailsPipelineElement elt;
        private final Params p;
        /** Detector built by {@link #init()}; {@code null} until init has completed successfully. */
        private SimpleNonModifyingGuardrailDetector detector;

        /**
         * @param authCtx    caller's auth context, forwarded to the detector
         * @param projectKey project key, forwarded to the detector
         * @param elt        pipeline element whose params are copied as {@link Params}
         */
        Runner(AuthCtx authCtx, String projectKey, GuardrailsPipelineSettings.GuardrailsPipelineElement elt) {
            this.authCtx = authCtx;
            this.projectKey = projectKey;
            this.elt = elt;
            this.p = elt.getParamsCopyAs(Params.class);
            // Cast kept from the original so autowire overload resolution is unchanged.
            SpringUtils.getInstance().autowire((Object)this);
        }

        /**
         * Checks licensing, then builds the detector for the configured engine.
         *
         * @throws LicenseRestrictionException if the "Advanced LLM Mesh" add-on is not licensed
         * @throws IllegalArgumentException    if no engine, or an unsupported engine, is configured
         */
        @Override
        public void init() throws Exception {
            AbstractLicenseFeaturesStatusBuilder.LicenseFeaturesStatus featuresStatus = GuardrailsPipelineUtils.getLicensing();
            if (!featuresStatus.advancedLLMMeshAllowed) {
                throw new LicenseRestrictionException("Prompt Injection detection requires the \"Advanced LLM Mesh\" add-on");
            }
            PromptDetectionEngine engine = this.p.engine;
            if (engine == null) {
                // An explicit null in the settings would otherwise surface as a bare NPE from the switch.
                throw new IllegalArgumentException("No prompt injection detection engine configured");
            }
            switch (engine) {
                case PROMPT_INJECTION_CLASSIFIER:
                    this.detector = new LocalHuggingFacePromptInjectionDetectionPipeline(this.authCtx, this.projectKey, this.p);
                    break;
                case LLM_AS_A_JUDGE:
                    this.detector = new LLMAsAJudgePromptInjectionDetectionPipelineElement(this.authCtx, this.projectKey, this.p);
                    break;
                default:
                    // Without this, detector would stay null and fail later with an opaque NPE.
                    throw new IllegalArgumentException("Unsupported prompt injection detection engine: " + engine);
            }
        }

        /** Closes the detector quietly; a {@code null} detector (init not run or failed) is tolerated. */
        @Override
        public void close() {
            IOUtils.closeQuietly((Closeable)this.detector, null);
        }

        /**
         * Returns the detector, failing with a descriptive error if {@link #init()} has not completed.
         */
        private SimpleNonModifyingGuardrailDetector requireDetector() {
            SimpleNonModifyingGuardrailDetector d = this.detector;
            if (d == null) {
                throw new IllegalStateException("Prompt injection detector used before init()");
            }
            return d;
        }

        /**
         * Runs the detector on the completion query, then passes the query through unchanged.
         * NOTE(review): any result of the detector call is ignored here, so a detected injection is
         * presumably signalled by an exception from the detector. Confirm in
         * SimpleNonModifyingGuardrailDetector.
         */
        @Override
        public GuardrailRunner.CompletionQueryGuardrailResponse processCompletionQuery(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            requireDetector().processCompletionQuery(query, trace);
            return GuardrailRunner.CompletionQueryGuardrailResponse.pass(context, query);
        }

        /** Responses are not inspected: always passes the response through. */
        @Override
        public GuardrailRunner.CompletionResponseGuardrailResponse processCompletionResponse(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError response, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            return GuardrailRunner.CompletionResponseGuardrailResponse.pass(context, response);
        }

        /**
         * Not expected to be called. getFlags only declares OPERATES_ON_QUERIES, so the pipeline
         * presumably never wraps a streamed response with this guardrail. Confirm against
         * GuardrailRunner's callers.
         */
        @Override
        public LLMClient.StreamedCompletionResponseConsumer newStreamedCompletionResponseHandler(LLMClient.StreamedCompletionResponseConsumer underlying, GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            throw new AssertionError("unreachable");
        }

        /** Embedding queries are not inspected: always passes the query through. */
        @Override
        public GuardrailRunner.EmbeddingQueryGuardrailResponse processEmbeddingQuery(GuardrailRunner.GuardrailContext context, LLMClient.EmbeddingQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            return GuardrailRunner.EmbeddingQueryGuardrailResponse.pass(context, query);
        }

        /** Runs the detector on the image-generation query, then passes the query through unchanged. */
        @Override
        public GuardrailRunner.ImageGenerationQueryGuardrailResponse processImageGenerationQuery(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            requireDetector().processImageGenerationQuery(query, trace);
            return GuardrailRunner.ImageGenerationQueryGuardrailResponse.pass(context, query);
        }

        /** Image-generation responses are not inspected: always passes the response through. */
        @Override
        public GuardrailRunner.ImageGenerationResponseGuardrailResponse processImageGenerationResponse(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.ImageGenerationResponseOrError response, LLMClient.LLMMeshTraceSpan trace) {
            return GuardrailRunner.ImageGenerationResponseGuardrailResponse.pass(context, response);
        }
    }

    /** User-configurable settings of a prompt-injection guardrail element. */
    public static class Params
    implements GuardrailParams {
        /** Which detector {@code Runner.init()} builds; defaults to the local classifier. */
        public PromptDetectionEngine engine = PromptDetectionEngine.PROMPT_INJECTION_CLASSIFIER;
        /** Judge prompt variant; presumably only read by the LLM-as-a-judge detector. Confirm. */
        public LLMAsAJudgeMode llmAsAJudgeMode = LLMAsAJudgeMode.GENERAL_DETECTION;
        /** Presumably the connection used by the local Hugging Face classifier. Confirm. */
        public String huggingFaceLocalConnectionName;
        /** Presumably the model id of the local Hugging Face classifier. Confirm. */
        public String huggingFaceLocalModelId;
        /** Presumably the LLM used as judge when engine is LLM_AS_A_JUDGE. Confirm. */
        public String genericTextCompletionLlmId;
        /** Presumably the judge prompt used with {@link LLMAsAJudgeMode#CUSTOM}. Confirm. */
        public String customPromptForInjectionDetection;
        /** Defaults to 0.1. NOTE(review): presumably a classifier score threshold; confirm semantics. */
        public float huggingFaceLocalThreshold = 0.1f;
    }

    /** Variants of the LLM-as-a-judge check. NOTE(review): exact prompt semantics live in the judge detector. */
    public static enum LLMAsAJudgeMode {
        GENERAL_DETECTION,
        DETECTION_AGAINST_SYSTEM_PROMPT,
        CUSTOM;

    }

    /** Detection backend; each constant maps to the detector built in {@code Runner.init()}. */
    public static enum PromptDetectionEngine {
        /** Builds a {@link LocalHuggingFacePromptInjectionDetectionPipeline}. */
        PROMPT_INJECTION_CLASSIFIER,
        /** Builds a {@link LLMAsAJudgePromptInjectionDetectionPipelineElement}. */
        LLM_AS_A_JUDGE;

    }
}

