/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance.custom;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.cluster.ClusterSelector;
import com.dataiku.dip.code.CodeEnvSelector;
import com.dataiku.dip.containers.exec.ContainerExecConfigSelector;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.io.CustomPythonKernelException;
import com.dataiku.dip.llm.governance.GuardrailParams;
import com.dataiku.dip.llm.governance.GuardrailRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsRegistry;
import com.dataiku.dip.llm.governance.custom.CustomGuardrailMeta;
import com.dataiku.dip.llm.governance.custom.CustomGuardrailsService;
import com.dataiku.dip.llm.governance.custom.LoadedCustomGuardrail;
import com.dataiku.dip.llm.governance.custom.PythonGuardrailServer;
import com.dataiku.dip.llm.governance.custom.PythonGuardrailServerAPI;
import com.dataiku.dip.llm.governance.custom.PythonGuardrailServerPool;
import com.dataiku.dip.llm.langchain.PythonLLMServer;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.plugins.IPluginsRegistryService;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.SmartLogTail;
import com.dataiku.dss.shadelib.org.apache.commons.io.IOUtils;
import com.google.gson.JsonObject;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;

/**
 * {@link GuardrailRunner} for plugin-provided guardrails whose logic lives in a
 * {@code guardrail.py} file inside the owning plugin's folder.
 *
 * <p>Each {@code process*} call is forwarded to a {@link PythonGuardrailServerAPI}, which is
 * obtained lazily on first use (see {@link #initOnce()}): either a dedicated
 * {@link PythonGuardrailServer} when the {@code dku.llm.guardrails.custom.forceDevMode}
 * parameter is set, or an API handed out by the {@link PythonGuardrailServerPool} otherwise.
 */
public class CustomGuardrailRunner extends GuardrailRunner {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.guardrails.custom");

    /** Instance parameter that forces a dedicated, log-producing server per runner. */
    private static final String FORCE_DEV_MODE_PARAM = "dku.llm.guardrails.custom.forceDevMode";

    private final DSSAuthCtx authCtx;
    private final String contextProjectKey;
    private final CustomGuardrailMeta meta;
    private final CustomGuardrailParams params;
    private final GuardrailsPipelineSettings.GuardrailsPipelineElement elt;
    private final String bypassToken;

    /** Server handle; {@code null} until {@link #initOnce()} has succeeded. */
    protected PythonGuardrailServerAPI serverAPI;

    /**
     * Builds a runner for one pipeline element. No server is contacted here; that happens on
     * the first call that needs it.
     *
     * @param authCtx           authentication context used when selecting the execution setup
     * @param contextProjectKey key of the project in whose context the guardrail runs
     * @param elt               pipeline element; supplies the guardrail type and its parameters
     * @param bypassToken       token copied into every query sent to the server
     * @throws Exception if the element parameters or the guardrail meta cannot be resolved
     */
    public CustomGuardrailRunner(DSSAuthCtx authCtx, String contextProjectKey, GuardrailsPipelineSettings.GuardrailsPipelineElement elt, String bypassToken) throws Exception {
        this.authCtx = authCtx;
        this.contextProjectKey = contextProjectKey;
        this.params = elt.getParamsCopyAs(CustomGuardrailParams.class);
        this.meta = (CustomGuardrailMeta)GuardrailsRegistry.getMeta(elt.type);
        this.elt = elt;
        this.bypassToken = bypassToken;
    }

    /** No eager initialization: the server is obtained lazily by {@link #initOnce()}. */
    @Override
    public void init() throws Exception {
    }

    /**
     * Sends a completion query to the Python guardrail and maps the server's query-side
     * answer onto a {@link GuardrailRunner.CompletionQueryGuardrailResponse}.
     *
     * @param context guardrail context forwarded to the server
     * @param query   completion query to check
     * @param trace   span that receives the server-side trace, when the server returns one
     * @return the guardrail decision, with the query and context as returned by the server
     * @throws Exception if the server cannot be initialized or the call fails
     */
    @Override
    public GuardrailRunner.CompletionQueryGuardrailResponse processCompletionQuery(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        this.initOnce();
        PythonGuardrailServerAPI.Query q = this.newQuery(context);
        q.completionQuery = query;
        PythonGuardrailServerAPI.Response r = this.serverAPI.processAsync(q).get();

        GuardrailRunner.CompletionQueryGuardrailResponse ret = new GuardrailRunner.CompletionQueryGuardrailResponse();
        ret.completionQuery = r.completionQuery;
        ret.action = r.queryGuardrailResponse.action;
        ret.error = r.queryGuardrailResponse.error;
        ret.context = r.context;
        ret.auditData = r.queryGuardrailResponse.auditData;
        ret.overriddenResponseText = r.queryGuardrailResponse.overriddenResponseText;
        this.attachTrace(trace, r);
        this.fixupErrorIfNeeded(ret);
        return ret;
    }

    /**
     * Sends a completion query together with its response to the Python guardrail and maps
     * the server's response-side answer onto a
     * {@link GuardrailRunner.CompletionResponseGuardrailResponse}.
     *
     * @param context  guardrail context forwarded to the server
     * @param query    completion query that produced {@code response}
     * @param response completion response (or error) to check
     * @param trace    span that receives the server-side trace, when the server returns one
     * @return the guardrail decision, with the response and context as returned by the server
     * @throws Exception if the server cannot be initialized or the call fails
     */
    @Override
    public GuardrailRunner.CompletionResponseGuardrailResponse processCompletionResponse(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError response, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        this.initOnce();
        PythonGuardrailServerAPI.Query q = this.newQuery(context);
        q.completionQuery = query;
        q.completionResponse = response;
        PythonGuardrailServerAPI.Response r = this.serverAPI.processAsync(q).get();

        GuardrailRunner.CompletionResponseGuardrailResponse ret = new GuardrailRunner.CompletionResponseGuardrailResponse();
        ret.completionResponse = r.completionResponse;
        ret.context = r.context;
        ret.action = r.responseGuardrailResponse.action;
        ret.error = r.responseGuardrailResponse.error;
        ret.auditData = r.responseGuardrailResponse.auditData;
        ret.updatedMessagesForRetry = r.responseGuardrailResponse.updatedMessagesForRetry;
        this.attachTrace(trace, r);
        this.fixupErrorIfNeeded(ret);
        return ret;
    }

    /**
     * Not implemented by this runner: always throws.
     *
     * @throws Exception always, with the message {@code "Unreachable"}
     */
    @Override
    public LLMClient.StreamedCompletionResponseConsumer newStreamedCompletionResponseHandler(LLMClient.StreamedCompletionResponseConsumer underlying, GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        throw new Exception("Unreachable");
    }

    /**
     * Sends an embedding query to the Python guardrail and maps the server's query-side
     * answer onto a {@link GuardrailRunner.EmbeddingQueryGuardrailResponse}.
     *
     * @param context guardrail context forwarded to the server
     * @param query   embedding query to check
     * @param trace   span that receives the server-side trace, when the server returns one
     * @return the guardrail decision, with the query and context as returned by the server
     * @throws Exception if the server cannot be initialized or the call fails
     */
    @Override
    public GuardrailRunner.EmbeddingQueryGuardrailResponse processEmbeddingQuery(GuardrailRunner.GuardrailContext context, LLMClient.EmbeddingQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        this.initOnce();
        PythonGuardrailServerAPI.Query q = this.newQuery(context);
        q.embeddingQuery = query;
        PythonGuardrailServerAPI.Response r = this.serverAPI.processAsync(q).get();

        GuardrailRunner.EmbeddingQueryGuardrailResponse ret = new GuardrailRunner.EmbeddingQueryGuardrailResponse();
        ret.embeddingQuery = r.embeddingQuery;
        ret.context = r.context;
        ret.action = r.queryGuardrailResponse.action;
        ret.auditData = r.queryGuardrailResponse.auditData;
        ret.error = r.queryGuardrailResponse.error;
        this.attachTrace(trace, r);
        this.fixupErrorIfNeeded(ret);
        return ret;
    }

    /**
     * Sends an image-generation query to the Python guardrail and maps the server's
     * query-side answer onto a {@link GuardrailRunner.ImageGenerationQueryGuardrailResponse}.
     *
     * @param context guardrail context forwarded to the server
     * @param query   image-generation query to check
     * @param trace   span that receives the server-side trace, when the server returns one
     * @return the guardrail decision, with the query and context as returned by the server
     * @throws Exception if the server cannot be initialized or the call fails
     */
    @Override
    public GuardrailRunner.ImageGenerationQueryGuardrailResponse processImageGenerationQuery(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        this.initOnce();
        PythonGuardrailServerAPI.Query q = this.newQuery(context);
        q.imageGenerationQuery = query;
        PythonGuardrailServerAPI.Response r = this.serverAPI.processAsync(q).get();

        GuardrailRunner.ImageGenerationQueryGuardrailResponse ret = new GuardrailRunner.ImageGenerationQueryGuardrailResponse();
        ret.imageGenerationQuery = r.imageGenerationQuery;
        ret.context = r.context;
        ret.action = r.queryGuardrailResponse.action;
        ret.auditData = r.queryGuardrailResponse.auditData;
        ret.error = r.queryGuardrailResponse.error;
        this.attachTrace(trace, r);
        this.fixupErrorIfNeeded(ret);
        return ret;
    }

    /**
     * Image-generation responses are not sent to the Python guardrail; the response is
     * returned through {@code ImageGenerationResponseGuardrailResponse.pass}.
     */
    @Override
    public GuardrailRunner.ImageGenerationResponseGuardrailResponse processImageGenerationResponse(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.ImageGenerationResponseOrError response, LLMClient.LLMMeshTraceSpan trace) {
        return GuardrailRunner.ImageGenerationResponseGuardrailResponse.pass(context, response);
    }

    /**
     * Obtains {@link #serverAPI} on first use; later calls return immediately. The method is
     * {@code synchronized}, so concurrent first calls on the same runner initialize once.
     *
     * <p>When {@code dku.llm.guardrails.custom.forceDevMode} is set, a dedicated
     * {@link PythonGuardrailServer} is started with a per-runner log directory (and closed
     * again if it fails to start). Otherwise the API is requested from the
     * {@link PythonGuardrailServerPool}.
     *
     * @throws IOException wrapping any initialization failure, except
     *                     {@link CustomPythonKernelException}, which is rethrown unchanged
     */
    protected synchronized void initOnce() throws IOException {
        if (this.serverAPI != null) {
            return;
        }
        try {
            LoadedCustomGuardrail desc = this.meta.getLoadedDesc();
            String envName = new CodeEnvSelector().selectForCustomPythonRecipe(desc.ownerPluginId);
            String containerConfName = new ContainerExecConfigSelector().selectConfName_autoTXN(this.authCtx, this.contextProjectKey, this.params.containerExecSelection);
            String clusterId = new ClusterSelector().selectForProject(this.authCtx, this.contextProjectKey).getClusterId();

            // Locate and load the guardrail's Python source from the owning plugin.
            IPluginsRegistryService pluginsService = (IPluginsRegistryService)SpringUtils.getBean(IPluginsRegistryService.class);
            File pluginFolder = pluginsService.getActualPluginFolder(desc.ownerPluginId);
            File codeFolder = new File(pluginFolder, desc.folderName);
            CustomGuardrailsService customGuardrailsService = (CustomGuardrailsService)SpringUtils.getBean(CustomGuardrailsService.class);
            String libFolder = customGuardrailsService.getPythonLibFolder(desc.getType());
            String code = DKUFileUtils.readFileToStringUTF8(new File(codeFolder, "guardrail.py"));
            JsonObject pluginConfig = new JsonObject();

            boolean forceDevMode = DKUApp.getParams().getBoolParam(FORCE_DEV_MODE_PARAM, false);
            if (forceDevMode) {
                // A random suffix keeps the log directories of distinct runners apart.
                String idForLogs = SecretKeyGenerator.generate(6);
                File logBaseDir = DKUApp.getFile(new String[]{"python-guardrails", this.contextProjectKey, desc.ownerPluginId + idForLogs, "dev-logs"});
                PythonGuardrailServer grServer = new PythonGuardrailServer(this.authCtx, this.contextProjectKey, code, envName, containerConfName, desc.ownerPluginId, libFolder, this.params.config, pluginConfig, false);
                try {
                    grServer.start(true, logBaseDir);
                }
                catch (Exception e) {
                    // Do not leave a half-started server behind.
                    logger.warn((Object)"Kernel failed to start, closing it", (Throwable)e);
                    IOUtils.closeQuietly((Closeable)grServer, null);
                    throw e;
                }
                this.serverAPI = grServer;
            } else {
                PythonGuardrailServerPool kernelPool = (PythonGuardrailServerPool)SpringUtils.getBean(PythonGuardrailServerPool.class);
                this.serverAPI = kernelPool.getServerAPI(this.authCtx, this.contextProjectKey, this.elt, code, desc.ownerPluginId, libFolder, false, this.params.config, null, envName, containerConfName, clusterId);
            }
        }
        catch (CustomPythonKernelException e) {
            throw e;
        }
        catch (Exception e) {
            throw new IOException("Failed to initialize Python Guardrail", e);
        }
    }

    /**
     * Returns the tail of the Python server's log. Real logs are only available when
     * {@code dku.llm.guardrails.custom.forceDevMode} is set; otherwise a one-line
     * placeholder explaining how to get logs is returned. The server is initialized in
     * both cases.
     *
     * @return the server log tail in dev mode, or a placeholder tail otherwise
     * @throws IllegalStateException if dev mode is set but the server held by this runner
     *                               is not a {@link PythonLLMServer} (for instance because
     *                               it was obtained before dev mode was switched on)
     * @throws Exception             if the server cannot be initialized
     */
    public SmartLogTail getKernelLog() throws Exception {
        this.initOnce();
        boolean forceDevMode = DKUApp.getParams().getBoolParam(FORCE_DEV_MODE_PARAM, false);
        if (!forceDevMode) {
            SmartLogTail fakeSLT = new SmartLogTail();
            fakeSLT.appendLine("You need to switch to development mode to get logs. Contact support for more info.");
            return fakeSLT;
        }
        // An assert would be skipped on a default JVM (assertions are off unless -ea is
        // given), leaving only an opaque ClassCastException; check explicitly instead.
        if (!(this.serverAPI instanceof PythonLLMServer)) {
            throw new IllegalStateException("Kernel logs are unavailable: the guardrail server was not started in development mode");
        }
        return ((PythonLLMServer)((Object)this.serverAPI)).getKernelLogTail();
    }

    /** Closes the server handle, if any, ignoring errors raised while closing. */
    @Override
    public void close() {
        IOUtils.closeQuietly((Closeable)this.serverAPI, null);
    }

    /** Creates a query pre-filled with the fields every request carries. */
    private PythonGuardrailServerAPI.Query newQuery(GuardrailRunner.GuardrailContext context) {
        PythonGuardrailServerAPI.Query q = new PythonGuardrailServerAPI.Query();
        q.context = context;
        q.bypassToken = this.bypassToken;
        return q;
    }

    /** Adds the server-side trace to {@code trace} when the response carries one. */
    private void attachTrace(LLMClient.LLMMeshTraceSpan trace, PythonGuardrailServerAPI.Response r) {
        if (r.trace != null) {
            trace.addObservation(r.trace);
        }
    }

    /** Copies the error message into the detailed message when only the former is set. */
    private void fixupErrorIfNeeded(GuardrailRunner.QueryGuardrailResponse qr) {
        if (qr.error != null && qr.error.message != null && qr.error.detailedMessage == null) {
            qr.error.detailedMessage = qr.error.message;
        }
    }

    /** Copies the error message into the detailed message when only the former is set. */
    private void fixupErrorIfNeeded(GuardrailRunner.ResponseGuardrailResponse qr) {
        if (qr.error != null && qr.error.message != null && qr.error.detailedMessage == null) {
            qr.error.detailedMessage = qr.error.message;
        }
    }

    /** Parameters of a custom guardrail pipeline element. */
    public static class CustomGuardrailParams
    implements GuardrailParams {
        /** Guardrail configuration, passed through to the Python server. */
        public JsonObject config;
        /** Container execution selection used to resolve the container configuration name. */
        public ContainerExecSelection containerExecSelection = new ContainerExecSelection();
    }
}

