/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance.toxicity;

import com.dataiku.dip.license.LicenseRestrictionException;
import com.dataiku.dip.llm.governance.GuardrailMeta;
import com.dataiku.dip.llm.governance.GuardrailParams;
import com.dataiku.dip.llm.governance.GuardrailRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.governance.SentenceBufferingStreamFilter;
import com.dataiku.dip.llm.governance.SimpleNonModifyingGuardrailDetector;
import com.dataiku.dip.llm.governance.toxicity.LocalHuggingFaceToxicityDetectionPipeline;
import com.dataiku.dip.llm.governance.toxicity.OpenAIToxicityDetectionPipelineElement;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.services.licensing.AbstractLicenseFeaturesStatusBuilder;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.apache.commons.io.IOUtils;
import java.io.Closeable;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.stream.Collectors;

/**
 * Guardrail that runs a toxicity detector over LLM completion queries/responses, embedding
 * queries and image-generation queries.
 *
 * <p>The detector backend is selected by {@link Params#engine}: the OpenAI API
 * ({@link OpenAIToxicityDetectionPipelineElement}) or a local Hugging Face model
 * ({@link LocalHuggingFaceToxicityDetectionPipeline}), the latter requiring the
 * "Advanced LLM Mesh" license add-on.
 */
public class ToxicityDetectionGuardrail {
    /** Guardrail descriptor: type name, params class, capability flags and runner factory. */
    public static final GuardrailMeta META = new GuardrailMeta(){

        @Override
        public String getType() {
            return "ToxicityDetector";
        }

        @Override
        public Class<? extends GuardrailParams> paramsClass() {
            return Params.class;
        }

        /**
         * Derives capability flags from the element's params: queries and/or responses are
         * declared as operated on depending on {@code filterQueries}/{@code filterResponses};
         * retry and streaming flags are only declared when responses are filtered.
         */
        @Override
        public EnumSet<GuardrailMeta.GuardrailFlag> getFlags(GuardrailsPipelineSettings.GuardrailsPipelineElement elt) {
            Params p = elt.getParamsCopyAs(Params.class);
            EnumSet<GuardrailMeta.GuardrailFlag> es = EnumSet.noneOf(GuardrailMeta.GuardrailFlag.class);
            if (p.filterQueries) {
                es.add(GuardrailMeta.GuardrailFlag.OPERATES_ON_QUERIES);
            }
            if (p.filterResponses) {
                es.add(GuardrailMeta.GuardrailFlag.OPERATES_ON_RESPONSES);
                if (p.retryOnToxicResponses) {
                    es.add(GuardrailMeta.GuardrailFlag.MAY_REQUEST_RETRY_ON_RESPONSES);
                }
                if (p.sentenceBasedResponseProcessing) {
                    es.add(GuardrailMeta.GuardrailFlag.CAN_STREAM_RESPONSES);
                }
            }
            return es;
        }

        /** Builds a {@link Runner}; {@code bypassToken} is not used by this guardrail. */
        @Override
        public GuardrailRunner buildRunner(AuthCtx authCtx, String projectKey, GuardrailsPipelineSettings.GuardrailsPipelineElement elt, String bypassToken) throws Exception {
            return new Runner(authCtx, projectKey, elt);
        }
    };

    /**
     * Per-pipeline runner. Delegates the actual detection to a
     * {@link SimpleNonModifyingGuardrailDetector} created in {@link #init()}; detection failures
     * surface as exceptions thrown by the detector (see {@link #processCompletionResponse}).
     */
    private static class Runner
    extends GuardrailRunner {
        private final AuthCtx authCtx;
        private final String projectKey;
        private final GuardrailsPipelineSettings.GuardrailsPipelineElement elt;
        private final Params p;
        SimpleNonModifyingGuardrailDetector detector;

        Runner(AuthCtx authCtx, String projectKey, GuardrailsPipelineSettings.GuardrailsPipelineElement elt) {
            this.authCtx = authCtx;
            this.projectKey = projectKey;
            this.elt = elt;
            this.p = elt.getParamsCopyAs(Params.class);
        }

        /**
         * Instantiates the detector for the configured engine.
         *
         * @throws IllegalArgumentException if no engine is configured
         * @throws LicenseRestrictionException if the local Hugging Face engine is selected without
         *     the "Advanced LLM Mesh" add-on
         * @throws IllegalStateException if the engine value is not handled here
         */
        @Override
        public void init() throws Exception {
            // Params default to OPENAI_API, but an explicit null in the stored settings would
            // otherwise make the switch below fail with an uninformative NullPointerException.
            if (this.p.engine == null) {
                throw new IllegalArgumentException("Toxicity detection engine is not set");
            }
            switch (this.p.engine) {
                case OPENAI_API: {
                    this.detector = new OpenAIToxicityDetectionPipelineElement(this.authCtx, this.p, this.projectKey);
                    break;
                }
                case HUGGINGFACE_LOCAL: {
                    AbstractLicenseFeaturesStatusBuilder.LicenseFeaturesStatus featuresStatus = GuardrailsPipelineUtils.getLicensing();
                    if (!featuresStatus.advancedLLMMeshAllowed) {
                        throw new LicenseRestrictionException("Local Hugging Face toxicity detection requires the \"Advanced LLM Mesh\" add-on");
                    }
                    this.detector = new LocalHuggingFaceToxicityDetectionPipeline(this.authCtx, this.projectKey, this.p);
                    break;
                }
                default:
                    // Fail fast rather than leaving the detector null for a newly added engine.
                    throw new IllegalStateException("Unsupported toxicity detection engine: " + this.p.engine);
            }
        }

        /** Quietly releases the detector if it holds closeable resources. */
        @Override
        public void close() {
            // instanceof also covers a null detector (init() not run or failed) and avoids a
            // ClassCastException for detector implementations that are not Closeable.
            if (this.detector instanceof Closeable) {
                IOUtils.closeQuietly((Closeable) this.detector, null);
            }
        }

        /** Runs the detector on the query when query filtering is enabled; otherwise passes. */
        @Override
        public GuardrailRunner.CompletionQueryGuardrailResponse processCompletionQuery(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            if (this.p.filterQueries) {
                this.detector.processCompletionQuery(query, trace);
            }
            return GuardrailRunner.CompletionQueryGuardrailResponse.pass(context, query);
        }

        /**
         * Runs the detector on a (non-streamed) completion response.
         *
         * <p>When the detector raises {@link GuardrailsPipelineRunner.LLMUsageEnforcerException}
         * the response fails; if {@code retryOnToxicResponses} is set, the failure is turned into a
         * RETRY carrying the original conversation, the rejected answer, and a system instruction
         * asking for a rewrite.
         */
        @Override
        public GuardrailRunner.CompletionResponseGuardrailResponse processCompletionResponse(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError response, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            if (!this.p.filterResponses) {
                return GuardrailRunner.CompletionResponseGuardrailResponse.pass(context, response);
            }
            try {
                this.detector.processCompletionResponse(query, response, trace);
                return GuardrailRunner.CompletionResponseGuardrailResponse.pass(context, response);
            }
            catch (GuardrailsPipelineRunner.LLMUsageEnforcerException e) {
                GuardrailRunner.CompletionResponseGuardrailResponse ret = GuardrailRunner.CompletionResponseGuardrailResponse.fail(context, (Throwable) e);
                if (this.p.retryOnToxicResponses) {
                    ret.action = GuardrailRunner.ResponseGuardrailAction.RETRY;
                    // Deep-copy so the retry conversation does not mutate the original query.
                    // toCollection(ArrayList::new) guarantees a mutable list: Collectors.toList()
                    // makes no mutability guarantee, and we append to the list right after.
                    ret.updatedMessagesForRetry = query.messages.stream()
                            .map(m -> (LLMClient.ChatMessage) JSON.deepCopy((Object) m))
                            .collect(Collectors.toCollection(ArrayList::new));
                    ret.updatedMessagesForRetry.add(new LLMClient.ChatMessage("assistant", response.text));
                    ret.updatedMessagesForRetry.add(new LLMClient.ChatMessage("system", "Your last answer was not acceptable, as it contained toxic content. Please rewrite."));
                }
                return ret;
            }
        }

        /**
         * Wraps the downstream consumer so streamed text is checked by the detector before being
         * forwarded. Only expected to be called when responses are filtered sentence-by-sentence
         * (the only configuration in which CAN_STREAM_RESPONSES is declared).
         */
        @Override
        public LLMClient.StreamedCompletionResponseConsumer newStreamedCompletionResponseHandler(LLMClient.StreamedCompletionResponseConsumer underlying, GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            assert (this.p.filterResponses && this.p.sentenceBasedResponseProcessing);
            StreamedToxicityDetector std = new StreamedToxicityDetector(underlying, query, trace);
            return new SentenceBufferingStreamFilter(std);
        }

        /** Runs the detector on the embedding query when query filtering is enabled. */
        @Override
        public GuardrailRunner.EmbeddingQueryGuardrailResponse processEmbeddingQuery(GuardrailRunner.GuardrailContext context, LLMClient.EmbeddingQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            if (this.p.filterQueries) {
                this.detector.processEmbeddingQuery(query, trace);
            }
            return GuardrailRunner.EmbeddingQueryGuardrailResponse.pass(context, query);
        }

        /** Runs the detector on the image-generation query when query filtering is enabled. */
        @Override
        public GuardrailRunner.ImageGenerationQueryGuardrailResponse processImageGenerationQuery(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            if (this.p.filterQueries) {
                this.detector.processImageGenerationQuery(query, trace);
            }
            return GuardrailRunner.ImageGenerationQueryGuardrailResponse.pass(context, query);
        }

        /** Image-generation responses are never checked; they always pass. */
        @Override
        public GuardrailRunner.ImageGenerationResponseGuardrailResponse processImageGenerationResponse(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.ImageGenerationResponseOrError response, LLMClient.LLMMeshTraceSpan trace) {
            return GuardrailRunner.ImageGenerationResponseGuardrailResponse.pass(context, response);
        }

        /**
         * Stream filter that runs each text-bearing chunk through the detector (as a synthetic
         * completion response) before forwarding it; chunks without text are forwarded as-is.
         * Any exception from the detector propagates and the chunk is not forwarded.
         * NOTE(review): chunk granularity is presumably sentence-level because this filter is fed
         * by a SentenceBufferingStreamFilter — confirm against that class.
         */
        class StreamedToxicityDetector
        extends LLMClient.StreamedCompletionResponseFilter {
            private final LLMClient.LLMMeshTraceSpan trace;
            private final LLMClient.SingleCompletionQuery q;

            public StreamedToxicityDetector(LLMClient.StreamedCompletionResponseConsumer underlying, LLMClient.SingleCompletionQuery q, LLMClient.LLMMeshTraceSpan trace) {
                super(underlying);
                this.trace = trace;
                this.q = q;
            }

            @Override
            public void onStreamChunk(LLMClient.StreamedCompletionResponseChunk chunk) throws Exception {
                if (chunk.text != null) {
                    // Reuse the non-streaming detector API by wrapping the chunk text in a blank response.
                    LLMClient.SimpleCompletionResponseOrError fakeResp = LLMClient.SimpleCompletionResponseOrError.blank();
                    fakeResp.text = chunk.text;
                    Runner.this.detector.processCompletionResponse(this.q, fakeResp, this.trace);
                }
                this.underlying.onStreamChunk(chunk);
            }
        }
    }

    /** User-configurable settings of the toxicity guardrail. */
    public static class Params
    implements GuardrailParams {
        /** Whether queries (completion, embedding, image-generation) are checked. */
        public boolean filterQueries = true;
        /** Whether completion responses are checked. */
        public boolean filterResponses = true;
        /** When responses are filtered, enables streaming with per-chunk checks. */
        public boolean sentenceBasedResponseProcessing;
        /** Detection backend; selects the detector created in {@code Runner.init()}. */
        public ToxicityDetectionEngine engine = ToxicityDetectionEngine.OPENAI_API;
        /** Passed to the OpenAI detector; presumably names the connection to use — confirm. */
        public String openAIConnectionName;
        /** Passed to the local Hugging Face detector; presumably names the connection — confirm. */
        public String huggingFaceLocalConnectionName;
        /** Passed to the local Hugging Face detector; presumably the model to load — confirm. */
        public String huggingFaceLocalModelId;
        /** Passed to the local Hugging Face detector; presumably a score cutoff — confirm. */
        public float huggingFaceLocalThreshold = 0.5f;
        /** When a response is flagged, request a retry asking the model to rewrite its answer. */
        public boolean retryOnToxicResponses = false;
    }

    /** Available toxicity detection backends. */
    public static enum ToxicityDetectionEngine {
        OPENAI_API,
        HUGGINGFACE_LOCAL
    }
}

