/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.pii;

import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.io.SimplePythonKernel;
import com.dataiku.dip.io.SimplePythonKernelFactory;
import com.dataiku.dip.llm.governance.GuardrailsCodes;
import com.dataiku.dip.llm.governance.GuardrailsPipelineRunner;
import com.dataiku.dip.llm.io.commands.ProcessSingleEmbeddingCommand;
import com.dataiku.dip.llm.io.commands.ProcessSingleImageGenerationCommand;
import com.dataiku.dip.llm.io.commands.ProcessSinglePromptCommand;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.pii.PIIClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.com.google.common.base.Strings;
import com.dataiku.j2py.annotations.PyModel;
import com.google.common.collect.Lists;
import com.google.gson.JsonElement;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import org.apache.log4j.Logger;

/**
 * {@link PIIClient} implementation that delegates PII detection to a Python kernel running the
 * {@code dataiku.llm.pii.presidio_server} module.
 *
 * <p>Lifecycle: construct, call {@link #start()} to launch the kernel and send it the settings,
 * issue {@code processAsync} requests, then call {@link #close()} to release the kernel.
 * Any request that actually needs the kernel fails with an {@link IllegalStateException} if
 * {@link #start()} has not been called.
 */
public class PresidioBasedPIIHandlingServer
implements PIIClient {
    private final DSSAuthCtx authCtx;
    private final String projectKey;
    private final String envName;
    private final PresidioBasedPIIHandlingSettings settings;
    /** Identifier passed to the kernel factory: {@code "pii-presidio-"} plus a random suffix. */
    private final String kernelId;
    /** The Python kernel; {@code null} until {@link #start()} is called. */
    private SimplePythonKernel kernel;
    private final String containerConfName;
    private static final Logger logger = Logger.getLogger("dku.llm.pii.server");

    /**
     * Creates a server handle. No kernel is launched until {@link #start()} is called.
     *
     * @param authCtx           authentication context; cast to {@link DSSAuthCtx}
     * @param projectKey        project the kernel is prepared for
     * @param settings          detection settings sent to the kernel on start
     * @param envName           code environment name used to prepare the kernel
     * @param containerConfName container configuration name used to prepare the kernel
     */
    public PresidioBasedPIIHandlingServer(AuthCtx authCtx, String projectKey, PresidioBasedPIIHandlingSettings settings, String envName, String containerConfName) {
        this.authCtx = (DSSAuthCtx)authCtx;
        this.projectKey = projectKey;
        this.settings = settings;
        this.envName = envName;
        this.containerConfName = containerConfName;
        this.kernelId = "pii-presidio-" + SecretKeyGenerator.generateSmall();
    }

    /**
     * Closes the underlying kernel if one was created. Any failure is logged and not rethrown,
     * so this is safe to call from cleanup paths.
     */
    public void close() {
        if (this.kernel != null) {
            try {
                this.kernel.close();
            }
            catch (Throwable e) {
                logger.error("Failed to kill kernel", e);
            }
        }
    }

    /** Returns the identifier this server uses for its kernel. */
    public String getKernelId() {
        return this.kernelId;
    }

    /**
     * Runs PII detection on a completion query in the Python kernel.
     * Kernel-side failures go through {@link #piiErrorMapper(Throwable)}.
     *
     * @throws IllegalStateException if {@link #start()} has not been called
     */
    @Override
    public CompletableFuture<PIIClient.CompletionQueryPIIDetectionResponse> processAsync(LLMClient.SingleCompletionQuery completionQuery) {
        return this.requireKernel().getAsyncLink().asyncSendRequest((Object)new ProcessSinglePromptCommand(completionQuery, null, false), PIIClient.CompletionQueryPIIDetectionResponse.class, PresidioBasedPIIHandlingServer::piiErrorMapper);
    }

    /**
     * Runs PII detection on a completion response.
     *
     * <p>A response whose text is {@code null} or empty is not sent to the kernel. It is returned
     * unchanged as the {@code redactedResponse} of an already-completed future.
     *
     * @throws IllegalStateException if the text is non-empty and {@link #start()} has not been called
     */
    @Override
    public CompletableFuture<PIIClient.CompletionResponsePIIDetectionResponse> processAsync(LLMClient.SimpleCompletionResponseOrError completionResponse) {
        if (Strings.isNullOrEmpty((String)completionResponse.text)) {
            logger.info("Skipping PII detection of empty completionResponse: " + JSON.log((Object)completionResponse));
            PIIClient.CompletionResponsePIIDetectionResponse ret = new PIIClient.CompletionResponsePIIDetectionResponse();
            ret.redactedResponse = completionResponse;
            return CompletableFuture.completedFuture(ret);
        }
        ProcessCompletionResponseCommand preq = new ProcessCompletionResponseCommand();
        preq.completionResponse = completionResponse;
        return this.requireKernel().getAsyncLink().asyncSendRequest((Object)preq, PIIClient.CompletionResponsePIIDetectionResponse.class, PresidioBasedPIIHandlingServer::piiErrorMapper);
    }

    /**
     * Runs PII detection on an embedding query.
     *
     * <p>A query whose text is {@code null} or empty is not sent to the kernel. It is returned
     * unchanged as the {@code redactedQuery} of an already-completed future.
     *
     * @throws IllegalStateException if the text is non-empty and {@link #start()} has not been called
     */
    @Override
    public CompletableFuture<PIIClient.EmbeddingQueryPIIDetectionResponse> processAsync(LLMClient.EmbeddingQuery embeddingQuery) {
        if (Strings.isNullOrEmpty((String)embeddingQuery.text)) {
            logger.info("Skipping PII detection of empty embedding query: " + JSON.log((Object)embeddingQuery.getSafeForLoggingCopy()));
            PIIClient.EmbeddingQueryPIIDetectionResponse ret = new PIIClient.EmbeddingQueryPIIDetectionResponse();
            ret.redactedQuery = embeddingQuery;
            return CompletableFuture.completedFuture(ret);
        }
        return this.requireKernel().getAsyncLink().asyncSendRequest((Object)new ProcessSingleEmbeddingCommand(embeddingQuery, null), PIIClient.EmbeddingQueryPIIDetectionResponse.class, PresidioBasedPIIHandlingServer::piiErrorMapper);
    }

    /**
     * Runs PII detection on an image-generation query in the Python kernel.
     *
     * @throws IllegalStateException if {@link #start()} has not been called
     */
    @Override
    public CompletableFuture<PIIClient.ImageGenerationQueryPIIDetectionResponse> processAsync(LLMClient.ImageGenerationQuery imageGenQuery) {
        return this.requireKernel().getAsyncLink().asyncSendRequest((Object)new ProcessSingleImageGenerationCommand(imageGenQuery), PIIClient.ImageGenerationQueryPIIDetectionResponse.class, PresidioBasedPIIHandlingServer::piiErrorMapper);
    }

    /** Returns {@code true} if a kernel has been created and reports itself alive. */
    public boolean isAlive() {
        return this.kernel != null && this.kernel.isAlive();
    }

    /**
     * Prepares and starts the Python kernel running {@code dataiku.llm.pii.presidio_server}.
     * It then sends the kernel a {@link StartCommand} carrying this server's settings.
     *
     * <p>If a kernel from an earlier {@code start()} exists, it is closed first so that a
     * restart does not leak the previous process.
     *
     * @throws Exception propagated from kernel preparation, kernel start-up, or the start command
     */
    public void start() throws Exception {
        if (this.kernel != null) {
            this.close();
            this.kernel = null;
        }
        this.kernel = SimplePythonKernelFactory.prepareKernel(this.authCtx, this.projectKey, GeneralSettingsDAO.CGrouppableProcessType.ML_KERNEL, this.envName, "dataiku.llm.pii.presidio_server", false, this.containerConfName, this.kernelId);
        this.kernel.start();
        this.kernel.getAsyncLink().request((Object)new StartCommand(this.settings), JsonElement.class);
    }

    /**
     * Returns the started kernel.
     *
     * @throws IllegalStateException if {@link #start()} has not been called yet
     */
    private SimplePythonKernel requireKernel() {
        SimplePythonKernel current = this.kernel;
        if (current == null) {
            throw new IllegalStateException("PII detection kernel " + this.kernelId + " has not been started");
        }
        return current;
    }

    /**
     * Translates a kernel-side failure into the guardrails exception reported to callers.
     *
     * @param t failure reported for a PII request
     * @return an {@code LLMUsageEnforcerException} describing the failure
     */
    private static Throwable piiErrorMapper(Throwable t) {
        // Failures travelling through CompletableFuture stages may arrive wrapped in a
        // CompletionException whose message is just cause.toString(); report the real failure.
        Throwable root = t;
        while (root instanceof CompletionException && root.getCause() != null) {
            root = root.getCause();
        }
        String detail = root.getMessage() != null ? root.getMessage() : root.getClass().getName();
        // NOTE(review): this mapper is also used for completion-response checks, yet it always
        // reports QUERY_PII_DETECTION / ERR_LLM_QUERY_PII and "rejected query" — confirm whether
        // a response-specific error source/code should be used there.
        return new GuardrailsPipelineRunner.LLMUsageEnforcerException(LLMClient.LLMResponseErrorSource.QUERY_PII_DETECTION, GuardrailsCodes.ERR_LLM_QUERY_PII, "PII detection rejected query: " + detail);
    }

    /**
     * Detection settings. The whole object is sent to the Python kernel inside a
     * {@link StartCommand}. The meaning of each field is defined by the Python side.
     */
    public static class PresidioBasedPIIHandlingSettings {
        public boolean enabled;
        public String language;
        public List<String> supportedLanguages = Lists.newArrayList("en", "fr", "de", "nl", "es", "it");
        public double confidenceThreshold = 0.2;
        public PIIDetectionAction detectionAction = PIIDetectionAction.FAIL;
        public AnalyzerDetectionEntitiesMode entitiesMode = AnalyzerDetectionEntitiesMode.ALL;
        public PIIDetectionUnsupportedLangueAction unsupportedLanguageAction = PIIDetectionUnsupportedLangueAction.IGNORE;
        public int charsToMask = 12;
        public List<String> includedEntities = new ArrayList<String>();
        public List<String> excludedEntities = new ArrayList<String>();
    }

    /** Kernel request asking for PII detection on a completion response. */
    @PyModel
    public static class ProcessCompletionResponseCommand {
        // Instance field (not static) so it is part of the serialized request; presumably the
        // Python server dispatches on it — confirm.
        public final String type = "process-completion-response";
        public LLMClient.SimpleCompletionResponseOrError completionResponse;
    }

    /** Initial request sent to the kernel by {@link #start()}, carrying the detection settings. */
    public static class StartCommand {
        public final String type = "start";
        public final PresidioBasedPIIHandlingSettings settings;

        public StartCommand(PresidioBasedPIIHandlingSettings settings) {
            this.settings = settings;
        }
    }

    /** Which entity types the analyzer should look for. */
    public static enum AnalyzerDetectionEntitiesMode {
        ALL,
        EXPLICIT_INCLUDE,
        EXPLICIT_EXCLUDE
    }

    /**
     * What to do when the text's language is not supported. The name's spelling ("Langue") is
     * kept for compatibility with existing callers.
     */
    public static enum PIIDetectionUnsupportedLangueAction {
        FAIL,
        IGNORE
    }

    /** Action to apply when PII is detected; interpreted by the Python server. */
    public static enum PIIDetectionAction {
        FAIL,
        REPLACE,
        HASH,
        REDACT,
        MASK,
        FLAG_ONLY
    }
}

