/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance.forbiddenterms;

import com.dataiku.dip.llm.governance.GuardrailMeta;
import com.dataiku.dip.llm.governance.GuardrailParams;
import com.dataiku.dip.llm.governance.GuardrailRunner;
import com.dataiku.dip.llm.governance.GuardrailsCodes;
import com.dataiku.dip.llm.governance.GuardrailsPipelineRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.forbiddenterms.IForbiddenTermsService;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import com.google.re2j.Pattern;
import java.util.EnumSet;
import java.util.Locale;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;

/**
 * LLM Mesh guardrail that rejects queries and/or responses containing a forbidden term.
 *
 * <p>The list of forbidden terms is fetched from {@link IForbiddenTermsService} when the runner is
 * initialised. Text is then checked according to {@link Params#matchingMode}. When a term is
 * detected, a {@link GuardrailsPipelineRunner.LLMUsageEnforcerException} is thrown.
 */
public class ForbiddenTermsDetectionGuardrail {
    /** Guardrail descriptor: type name, params class, flags and runner factory. */
    public static final GuardrailMeta META = new GuardrailMeta(){

        @Override
        public String getType() {
            return "ForbiddenTermsDetector";
        }

        @Override
        public Class<? extends GuardrailParams> paramsClass() {
            return Params.class;
        }

        /**
         * Declares which content this guardrail operates on.
         *
         * <p>The answer is derived from the {@code filterQueries} / {@code filterResponses} params.
         */
        @Override
        public EnumSet<GuardrailMeta.GuardrailFlag> getFlags(GuardrailsPipelineSettings.GuardrailsPipelineElement elt) {
            Params p = elt.getParamsCopyAs(Params.class);
            EnumSet<GuardrailMeta.GuardrailFlag> es = EnumSet.noneOf(GuardrailMeta.GuardrailFlag.class);
            if (p.filterQueries) {
                es.add(GuardrailMeta.GuardrailFlag.OPERATES_ON_QUERIES);
            }
            if (p.filterResponses) {
                es.add(GuardrailMeta.GuardrailFlag.OPERATES_ON_RESPONSES);
            }
            return es;
        }

        @Override
        public GuardrailRunner buildRunner(AuthCtx authCtx, String projectKey, GuardrailsPipelineSettings.GuardrailsPipelineElement elt, String bypassToken) throws Exception {
            return new Runner(authCtx, projectKey, elt);
        }
    };
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.guardrails.forbiddenterms");

    /**
     * Builds the exception raised when a forbidden term is detected.
     *
     * @param isQuery       {@code true} for query content, {@code false} for response content.
     *                      Selects the error source and error code.
     * @param contentSource label of the inspected content (e.g. "query"), used in the message.
     * @param forbiddenTerm the term (or pattern) that was found, used in the message.
     * @return the exception to throw.
     */
    private static GuardrailsPipelineRunner.LLMUsageEnforcerException buildEnforcerException(boolean isQuery, String contentSource, String forbiddenTerm) {
        return new GuardrailsPipelineRunner.LLMUsageEnforcerException(
                isQuery ? LLMClient.LLMResponseErrorSource.QUERY_FORBIDDEN_TERMS : LLMClient.LLMResponseErrorSource.RESPONSE_FORBIDDEN_TERMS,
                isQuery ? GuardrailsCodes.ERR_LLM_QUERY_FORBIDDEN_TERM : GuardrailsCodes.ERR_LLM_RESPONSE_FORBIDDEN_TERM,
                String.format("LLM %s denied: forbidden term %s found", contentSource, forbiddenTerm));
    }

    /** Per-pipeline-element runner that checks text against the configured forbidden terms. */
    private static class Runner
    extends GuardrailRunner {
        private final AuthCtx authCtx;
        private final String projectKey;
        private final GuardrailsPipelineSettings.GuardrailsPipelineElement elt;
        private final Params p;
        /** Populated by {@link #init()}; {@code null} before that. */
        private Set<String> forbiddenTerms;

        Runner(AuthCtx authCtx, String projectKey, GuardrailsPipelineSettings.GuardrailsPipelineElement elt) {
            this.authCtx = authCtx;
            this.projectKey = projectKey;
            this.elt = elt;
            this.p = elt.getParamsCopyAs(Params.class);
        }

        /** Loads the forbidden terms for these params from the {@link IForbiddenTermsService} bean. */
        @Override
        public void init() throws Exception {
            this.forbiddenTerms = ((IForbiddenTermsService)SpringUtils.getBean(IForbiddenTermsService.class)).getForbiddenTerms_NoCheck(this.authCtx, this.p);
        }

        @Override
        public void close() {
        }

        /**
         * Checks one piece of text against every forbidden term and throws on the first hit.
         *
         * @param content       text to inspect; {@code null} or blank text is skipped.
         * @param isQuery       {@code true} when the text comes from a query, {@code false} for a response.
         * @param contentSource label of the content ("query"/"response"), used in log and error messages.
         * @throws GuardrailsPipelineRunner.LLMUsageEnforcerException if a forbidden term is found. The
         *         exception names the term/pattern that actually matched.
         */
        private void processContentChunk(String content, boolean isQuery, String contentSource) throws GuardrailsPipelineRunner.LLMUsageEnforcerException {
            if (StringUtils.isBlank(content)) {
                logger.info(String.format("Skipping Forbidden terms detection as %s is empty", contentSource));
                return;
            }
            ForbiddenTermsMatchingMode mode = this.p.matchingMode;
            if (mode == null) {
                // An unset mode has never matched anything; keep that behaviour.
                return;
            }
            // Lowercase once per chunk rather than once per term.
            String lowered = content.toLowerCase(Locale.ENGLISH);
            for (String forbiddenTerm : this.forbiddenTerms) {
                if (StringUtils.isEmpty(forbiddenTerm)) {
                    // Two reasons to skip: an empty term is "contained" in every text and would
                    // block everything, and a null term would NPE.
                    continue;
                }
                if (isForbiddenTermPresent(mode, lowered, forbiddenTerm, contentSource)) {
                    // Report the term that matched (previously the regex branch reported an unrelated term).
                    throw ForbiddenTermsDetectionGuardrail.buildEnforcerException(isQuery, contentSource, forbiddenTerm);
                }
            }
        }

        /**
         * Tests a single forbidden term against already-lowercased content.
         *
         * <p>{@code CONTAINS_IGNORE_CASE}: substring test after lowercasing the term with the same
         * locale as the content, so the comparison is case-insensitive whatever case the term is
         * stored in.
         *
         * <p>{@code REGEXP_MATCHES}: the term is compiled as a case-insensitive RE2/J pattern and must
         * match the <em>whole</em> content. NOTE(review): {@code matches()} requires a full-content
         * match, and without {@code DOTALL} a {@code .} does not cross the newlines used to join
         * multi-message queries. Confirm that whole-content semantics (rather than {@code find()})
         * are intended.
         */
        private static boolean isForbiddenTermPresent(ForbiddenTermsMatchingMode mode, String loweredContent, String forbiddenTerm, String contentSource) {
            switch (mode) {
                case CONTAINS_IGNORE_CASE:
                    return loweredContent.contains(forbiddenTerm.toLowerCase(Locale.ENGLISH));
                case REGEXP_MATCHES:
                    // Deliberately do not log the content itself: it is user/LLM text.
                    logger.debug("Checking " + contentSource + " against forbidden term pattern " + forbiddenTerm);
                    return Pattern.compile(forbiddenTerm, Pattern.CASE_INSENSITIVE).matcher(loweredContent).matches();
                default:
                    return false;
            }
        }

        /**
         * Inspects the query, if enabled.
         *
         * <p>The text of all query messages is joined with newlines and inspected as one chunk.
         */
        @Override
        public GuardrailRunner.CompletionQueryGuardrailResponse processCompletionQuery(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            if (this.p.filterQueries) {
                String content = query.messages.stream().map(m -> m.getTextEvenIfNotTextOnly()).collect(Collectors.joining("\n"));
                this.processContentChunk(content, true, "query");
            }
            return GuardrailRunner.CompletionQueryGuardrailResponse.pass(context, query);
        }

        /** Inspects the response text, if enabled; a null/blank text is skipped. */
        @Override
        public GuardrailRunner.CompletionResponseGuardrailResponse processCompletionResponse(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError response, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            if (this.p.filterResponses) {
                this.processContentChunk(response.text, false, "response");
            }
            return GuardrailRunner.CompletionResponseGuardrailResponse.pass(context, response);
        }

        /**
         * Not supported by this guardrail.
         *
         * <p>NOTE(review): presumably never invoked for this guardrail — confirm. It is left as an
         * {@link Error}, because switching to an Exception would change how callers' catch blocks
         * treat it.
         */
        @Override
        public LLMClient.StreamedCompletionResponseConsumer newStreamedCompletionResponseHandler(LLMClient.StreamedCompletionResponseConsumer underlying, GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            throw new Error("unreachable");
        }

        /** Inspects the text to embed, if query filtering is enabled. */
        @Override
        public GuardrailRunner.EmbeddingQueryGuardrailResponse processEmbeddingQuery(GuardrailRunner.GuardrailContext context, LLMClient.EmbeddingQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            if (this.p.filterQueries) {
                this.processContentChunk(query.text, true, "query");
            }
            return GuardrailRunner.EmbeddingQueryGuardrailResponse.pass(context, query);
        }

        /** Inspects both the prompts and the negative prompts, if query filtering is enabled. */
        @Override
        public GuardrailRunner.ImageGenerationQueryGuardrailResponse processImageGenerationQuery(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.LLMMeshTraceSpan trace) throws GuardrailsPipelineRunner.LLMUsageEnforcerException {
            if (this.p.filterQueries) {
                this.processContentChunk(query.getConcatenatedPrompts(), true, "query");
                this.processContentChunk(query.getConcatenatedNegativePrompts(), true, "query");
            }
            return GuardrailRunner.ImageGenerationQueryGuardrailResponse.pass(context, query);
        }

        /** Generated images are not inspected; the response is always passed through. */
        @Override
        public GuardrailRunner.ImageGenerationResponseGuardrailResponse processImageGenerationResponse(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.ImageGenerationResponseOrError response, LLMClient.LLMMeshTraceSpan trace) {
            return GuardrailRunner.ImageGenerationResponseGuardrailResponse.pass(context, response);
        }
    }

    /** User-configurable parameters of the guardrail. */
    public static class Params
    implements GuardrailParams {
        /** Whether queries (completion, embedding, image-generation prompts) are inspected. */
        public boolean filterQueries = true;
        /** Whether completion responses are inspected. */
        public boolean filterResponses = true;
        /** Where the forbidden terms come from; only {@code DATASET} exists. */
        public ForbiddenTermsDetectionSource source = ForbiddenTermsDetectionSource.DATASET;
        // The three dataset* fields are passed (via this Params object) to IForbiddenTermsService,
        // presumably to locate the dataset column holding the terms — confirm in the service.
        public String datasetProject;
        public String datasetName;
        public String datasetColumn;
        /** How each term is compared with the text. */
        public ForbiddenTermsMatchingMode matchingMode = ForbiddenTermsMatchingMode.CONTAINS_IGNORE_CASE;
    }

    /** Strategy used to compare a forbidden term with the inspected text. */
    public static enum ForbiddenTermsMatchingMode {
        /** Case-insensitive substring search. */
        CONTAINS_IGNORE_CASE,
        /** Case-insensitive regular expression that must match the whole text. */
        REGEXP_MATCHES
    }

    /** Origin of the forbidden-terms list. */
    public static enum ForbiddenTermsDetectionSource {
        DATASET
    }
}

