/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.agents.tools;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.agents.tools.AbstractPythonAgentToolRunner;
import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolRunner;
import com.dataiku.dip.agents.tools.AgentToolsDAO;
import com.dataiku.dip.agents.tools.AgentToolsRegistry;
import com.dataiku.dip.agents.tools.PythonAgentToolServer;
import com.dataiku.dip.agents.tools.PythonAgentToolServerAPI;
import com.dataiku.dip.agents.tools.PythonAgentToolServerKernelDesc;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.kernel.DSSKernelUtils;
import com.dataiku.dip.kernel.KernelPool;
import com.dataiku.dip.llm.LLMRelatedPoolablePythonServerKernelPool;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.IsolationLevel;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.Params;
import com.dataiku.dip.utils.SmartLogTail;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Kernel pool running Python agent tools in {@link PythonAgentToolServer} kernels.
 *
 * <p>Each kernel is described by a {@link PythonAgentToolServerKernelDesc}. Requests are routed
 * through {@code manager.handle(...)} using the descriptor's {@code poolKey}, which
 * {@link #buildKernelDesc} derives from the caller identity, the tool and every input used to
 * build the kernel. Descriptors flagged {@code isDevKernel} are "dev" kernels. Only those can
 * return their logs through {@link PythonAgentToolServerAPI#getKernelLog()}, and they can be
 * stopped on demand with {@link #stopDevKernel}.
 */
@Service
public class PythonAgentToolServerPool
extends LLMRelatedPoolablePythonServerKernelPool<PythonAgentToolServer, PythonAgentToolServerKernelDesc, PythonAgentToolServerKernelDesc> {
    @Autowired
    private AgentToolsDAO agentToolsDAO;
    @Autowired
    private TransactionService transactionService;
    private static final DKULogger logger = DKULogger.getLogger("dku.agents.tools.python.pool");

    /**
     * Clears every pooled kernel of {@code tool} (same project key and tool id), dev and
     * production alike, with death reason {@code OUTDATED}.
     *
     * @param tool the agent tool whose kernels must be dropped
     */
    public void invalidateKernels(AgentTool tool) {
        this.manager.clearKernels(kd -> tool.projectKey.equals(kd.projectKey) && tool.id.equals(kd.toolId), KernelPool.DeathReason.OUTDATED);
    }

    /**
     * Force-stops the dev kernel of {@code tool} that matches {@code authCtx}, if there is one.
     *
     * @param authCtx the identity the dev kernel was started for
     * @param tool    the agent tool
     * @return {@code false} when no matching dev kernel is found, otherwise the result of
     *         {@code forceStopKernel}
     */
    public boolean stopDevKernel(DSSAuthCtx authCtx, AgentTool tool) {
        // Probe descriptor used only for matching through isSameDevKernel.
        // NOTE(review): assumes isSameDevKernel only compares the fields set here — confirm.
        PythonAgentToolServerKernelDesc probe = new PythonAgentToolServerKernelDesc();
        probe.authCtx = authCtx;
        probe.projectKey = tool.projectKey;
        probe.toolId = tool.id;
        probe.isDevKernel = true;
        Optional<String> kernelId = this.manager.getKernelIdFiltered(kd -> kd.isSameDevKernel(probe));
        return kernelId
                .map(id -> this.manager.forceStopKernel(id, KernelPool.DeathReason.USER_REQUEST))
                .orElse(false);
    }

    /**
     * Builds the kernel descriptor for running {@code tool}.
     *
     * <p>The resulting {@code poolKey} is
     * {@code <identity>-<projectKey>-<toolId>-<sha1 of all inputs>}. Descriptors built from
     * identical inputs therefore get identical pool keys, and any change in the tool JSON, code,
     * configuration, environment, container, cluster or dev flag yields a different key.
     *
     * @param authCtx           identity the kernel runs for; also the first poolKey segment
     * @param tool              the agent tool (its version number, dkuProperties and
     *                          singleInstance flag are copied into the descriptor)
     * @param pyClazz           Python class to instantiate in the kernel
     * @param code              Python code for the tool
     * @param pluginId          plugin id, if any
     * @param libFolder         library folder, if any
     * @param loadPythonLibs    whether Python libs should be loaded
     * @param config            tool configuration
     * @param pluginConfig      plugin-level configuration
     * @param envName           code env name
     * @param containerConfName container configuration name
     * @param clusterId         cluster id
     * @param devKernel         whether this is a dev kernel
     * @return a fully populated descriptor
     */
    public static PythonAgentToolServerKernelDesc buildKernelDesc(DSSAuthCtx authCtx, AgentTool tool, String pyClazz, String code, String pluginId, String libFolder, boolean loadPythonLibs, JsonObject config, JsonObject pluginConfig, String envName, String containerConfName, String clusterId, boolean devKernel) {
        PythonAgentToolServerKernelDesc kernelDesc = new PythonAgentToolServerKernelDesc();
        kernelDesc.authCtx = authCtx;
        kernelDesc.projectKey = tool.projectKey;
        kernelDesc.toolId = tool.id;
        // Captured so isOutdated() can detect that the tool changed after this kernel was built.
        kernelDesc.toolVersionNumber = tool.versionTag != null ? tool.versionTag.versionNumber : 0L;
        kernelDesc.isDevKernel = devKernel;
        kernelDesc.cruContext = ComputeResourceUsageContext.forPythonAgentTool((AuthCtx)authCtx, tool.projectKey, tool.id);
        kernelDesc.pyClazz = pyClazz;
        kernelDesc.code = code;
        kernelDesc.pluginId = pluginId;
        kernelDesc.libFolder = libFolder;
        kernelDesc.loadPythonLibs = loadPythonLibs;
        kernelDesc.config = config;
        kernelDesc.pluginConfig = pluginConfig;
        kernelDesc.dkuProperties = tool.dkuProperties;
        kernelDesc.singleInstance = tool.singleInstance;
        kernelDesc.envName = envName;
        kernelDesc.containerConfName = containerConfName;
        kernelDesc.clusterId = clusterId;
        // StringUtils.join renders null parts as "", so optional inputs hash stably.
        String[] fingerprintParts = {
                JSON.json((Object)tool), pyClazz, code, pluginId, libFolder, "" + loadPythonLibs,
                JSON.json((Object)config), JSON.json((Object)pluginConfig), envName, containerConfName, clusterId, "" + devKernel
        };
        String fingerprint = DigestUtils.sha1Hex(StringUtils.join(fingerprintParts, "__DKU__"));
        kernelDesc.poolKey = authCtx.getIdentifier() + "-" + tool.projectKey + "-" + tool.id + "-" + fingerprint;
        return kernelDesc;
    }

    /**
     * Sends a no-op request to the pool for the kernel of {@code at}. Presumably this makes the
     * pool start (or keep) a kernel before real requests arrive.
     *
     * <p>Nothing is sent when the tool's runner is not an {@link AbstractPythonAgentToolRunner}.
     * The runner is closed before the request is issued, and the returned future is not awaited.
     *
     * @throws Exception if building or closing the runner fails
     */
    public void wakeUp(DSSAuthCtx authCtx, String projectKey, AgentTool at) throws Exception {
        PythonAgentToolServerKernelDesc desc = null;
        AgentToolMeta meta = AgentToolsRegistry.getMeta(at.type);
        try (AgentToolRunner baseRunner = meta.buildRunner(authCtx, projectKey, at)) {
            if (baseRunner instanceof AbstractPythonAgentToolRunner) {
                desc = ((AbstractPythonAgentToolRunner)baseRunner).buildKernelDesc();
            }
        }
        if (desc != null) {
            this.manager.handle(kernel -> CompletableFuture.completedFuture(null), desc, desc.poolKey, null);
        }
    }

    /**
     * Returns an API view that routes every call to a kernel of this pool matching
     * {@code kernelDesc}.
     *
     * @param kernelDesc descriptor (typically from {@link #buildKernelDesc}) selecting the kernel
     * @return a facade whose {@code close()} is a no-op
     */
    public PythonAgentToolServerAPI getServerAPI(final PythonAgentToolServerKernelDesc kernelDesc) {
        return new PythonAgentToolServerAPI(){

            @Override
            public CompletableFuture<AgentToolRunner.AgentToolOutput> runAsync(AgentToolRunner.AgentToolInput input) {
                return PythonAgentToolServerPool.this.manager.handle(kernel -> kernel.runAsync(input), kernelDesc, kernelDesc.poolKey, input);
            }

            @Override
            public CompletableFuture<AgentToolMeta.ToolDescriptor> getResultingDescriptorAsync(AgentTool tool) throws Exception {
                return PythonAgentToolServerPool.this.manager.handle(kernel -> kernel.getResultingDescriptorAsync(tool), kernelDesc, kernelDesc.poolKey, tool);
            }

            @Override
            public CompletableFuture<AgentToolMeta.ToolCallDescription> getToolCallDescriptionAsync(AgentTool tool, LLMClient.FunctionTool descriptor, AgentToolRunner.AgentToolInput input) {
                return PythonAgentToolServerPool.this.manager.handle(kernel -> kernel.getToolCallDescriptionAsync(tool, descriptor, input), kernelDesc, kernelDesc.poolKey, tool);
            }

            @Override
            public CompletableFuture<JsonObject> loadSampleQuery(AgentTool tool) {
                return PythonAgentToolServerPool.this.manager.handle(kernel -> kernel.loadSampleQuery(tool), kernelDesc, kernelDesc.poolKey, tool);
            }

            /**
             * Returns the logs of the matching dev kernel.
             *
             * <p>This never throws. Any failure, including being called for a production kernel,
             * is reported as a one-line log that carries the error message.
             */
            @Override
            public SmartLogTail getKernelLog() {
                try {
                    if (!kernelDesc.isDevKernel) {
                        throw new IllegalArgumentException("Production kernels can't return logs");
                    }
                    String kernelID = (String)PythonAgentToolServerPool.this.manager.getKernelIdFiltered(kd -> kd.isSameDevKernel(kernelDesc))
                            .orElseThrow(() -> new IllegalArgumentException("Dev kernel not found: " + kernelDesc.getDevKernelKey()));
                    return (SmartLogTail)PythonAgentToolServerPool.this.manager.getKernelLogs(kernelID)
                            .orElseThrow(() -> new IllegalArgumentException("Logs not found for dev kernel: " + kernelDesc.getDevKernelKey()));
                }
                catch (Exception e) {
                    // getMessage() may be null (e.g. NullPointerException); never append a null line.
                    return singleLineLog(e.getMessage() != null ? e.getMessage() : e.toString());
                }
            }

            /** Wraps a single message into a log tail, used to report failures as log content. */
            private SmartLogTail singleLineLog(String line) {
                SmartLogTail log = new SmartLogTail();
                log.appendLine(line);
                return log;
            }

            @Override
            public void close() throws IOException {
            }
        };
    }

    /**
     * Creates the pool under the name {@code "python-agent-tool"}. Tunables are read under the
     * controller's {@code propertiesPrefix}, which is built from {@code "dku.agents.tools.python"}.
     */
    public PythonAgentToolServerPool() {
        super("python-agent-tool");
        this.setManager(new KernelPool((KernelPool.KernelController)new LLMRelatedPoolablePythonServerKernelPool.LLMRelatedPoolablePythonServerKernelController("dku.agents.tools.python"){

            /** Instantiates (without starting) a Python agent tool server for {@code kernelDesc}. */
            @Nonnull
            public PythonAgentToolServer createKernel(PythonAgentToolServerKernelDesc kernelDesc) {
                return new PythonAgentToolServer(kernelDesc.authCtx, kernelDesc.projectKey, kernelDesc.toolId, kernelDesc.pyClazz, kernelDesc.code, kernelDesc.envName, kernelDesc.containerConfName, kernelDesc.pluginId, kernelDesc.libFolder, kernelDesc.config, kernelDesc.pluginConfig, kernelDesc.loadPythonLibs);
            }

            /**
             * Starts {@code kernel} on the pool's executor. Logs go under
             * {@code agent-tools/<projectKey>/<toolId>/logs}.
             */
            @Nonnull
            public CompletableFuture<Void> startKernel(PythonAgentToolServer kernel, PythonAgentToolServerKernelDesc kernelDesc) {
                File logBaseDir = DKUApp.getFile(new String[]{"agent-tools", kernelDesc.projectKey, kernelDesc.toolId, "logs"});
                return DKUCompletableFuture.runAsync(() -> {
                    // Runs on the executor thread: set the kernel context there, not on the caller's thread.
                    DSSKernelUtils.setKernelContext(kernelDesc.cruContext, kernelDesc.jobContext, logger);
                    kernel.start(kernelDesc.isDevKernel, false, logBaseDir);
                }, (Executor)PythonAgentToolServerPool.this.executorService);
            }

            /** Queued-request timeout: {@code .queuedRequestTimeoutInS} (default 1800 s), converted to nanoseconds. */
            public Long getQueuedRequestTimeoutInNs() {
                long timeoutSeconds = ApplicationConfigurator.getParams().getLongParam(this.propertiesPrefix + ".queuedRequestTimeoutInS", 1800L);
                // toNanos saturates at Long.MAX_VALUE instead of wrapping to a negative timeout.
                return TimeUnit.SECONDS.toNanos(timeoutSeconds);
            }

            /** Pool-wide kernel cap: {@code .maxKernels} (default 50). */
            @Override
            public int getGlobalMaxKernelCount() {
                return ApplicationConfigurator.getParams().getIntParam(this.propertiesPrefix + ".maxKernels", Integer.valueOf(50));
            }

            /** Autoscale window: {@code .dev.autoscaleWindowSeconds} (default 120) for dev kernels, otherwise {@code .autoscaleWindowSeconds} (default 600). */
            public int getAutoscaleTimeWindowSeconds(PythonAgentToolServerKernelDesc kernelDesc) {
                if (kernelDesc.isDevKernel) {
                    return ApplicationConfigurator.getParams().getIntParam(this.propertiesPrefix + ".dev.autoscaleWindowSeconds", Integer.valueOf(120));
                }
                return ApplicationConfigurator.getParams().getIntParam(this.propertiesPrefix + ".autoscaleWindowSeconds", Integer.valueOf(600));
            }

            /**
             * Minimum retention time. Dev kernels get 0. Otherwise the tool's own non-negative
             * {@code minimumRetentionTimeSeconds} is used, falling back to
             * {@code .minimumRetentionTimeSeconds} (default 1800).
             */
            public int getMinimumRetentionTimeSeconds(PythonAgentToolServerKernelDesc kernelDesc) {
                if (kernelDesc.isDevKernel) {
                    return 0;
                }
                AgentTool tool = this.getAgentToolUnsafe(kernelDesc);
                if (tool != null && tool.minimumRetentionTimeSeconds != null && tool.minimumRetentionTimeSeconds >= 0) {
                    return tool.minimumRetentionTimeSeconds;
                }
                return ApplicationConfigurator.getParams().getIntParam(this.propertiesPrefix + ".minimumRetentionTimeSeconds", Integer.valueOf(1800));
            }

            /** Per-tool kernel cap, resolved by DSSKernelUtils from the tool's properties, the global params, singleInstance and the global cap. */
            public Integer getMaxKernelCount(PythonAgentToolServerKernelDesc kernelDesc) {
                return DSSKernelUtils.getMaxKernelCount(this.toolParams(kernelDesc), this.propertiesPrefix + ".maxKernelPerAgentTool", this.propertiesPrefix + ".maxKernelProportionPerAgentTool", kernelDesc.singleInstance, this.getGlobalMaxKernelCount());
            }

            /** Hard cap of parallel requests per kernel: {@code .hardMaxRequestsPerKernel} with a fallback of 16. */
            public int getHardMaxParallelRequests(PythonAgentToolServerKernelDesc kernelDesc) {
                return DSSKernelUtils.getIntParamWithFallback(this.toolParams(kernelDesc), this.propertiesPrefix + ".hardMaxRequestsPerKernel", 16);
            }

            /** Soft cap of parallel requests per kernel: {@code .softMaxRequestsPerKernel} with a fallback of 16. */
            public int getSoftMaxParallelRequests(PythonAgentToolServerKernelDesc kernelDesc) {
                return DSSKernelUtils.getIntParamWithFallback(this.toolParams(kernelDesc), this.propertiesPrefix + ".softMaxRequestsPerKernel", 16);
            }

            /** Converts the tool's {@code dkuProperties} (copied into the descriptor) into {@link Params}. */
            private Params toolParams(PythonAgentToolServerKernelDesc kernelDesc) {
                return AbstractSQLConnection.CustomDatabaseProperty.toParams(kernelDesc.dkuProperties);
            }

            /**
             * A kernel is outdated in two cases:
             * <ul>
             * <li>its tool can no longer be loaded (null lookup, or logged I/O failure);</li>
             * <li>the tool's versionTag number differs from the one captured in the descriptor.</li>
             * </ul>
             * A tool without a versionTag is never outdated by version.
             */
            public boolean isOutdated(PythonAgentToolServerKernelDesc kernelDesc) {
                AgentTool tool = this.getAgentToolUnsafe(kernelDesc);
                return tool == null || tool.versionTag != null && tool.versionTag.versionNumber != kernelDesc.toolVersionNumber;
            }

            /**
             * Loads the descriptor's tool through {@code agentToolsDAO.getOrNullUnsafe} inside a
             * transaction from {@code retrieveOrBeginRead(IsolationLevel.YOLO)}.
             *
             * @return the DAO result (possibly {@code null}), or {@code null} when an
             *         {@link IOException} is raised while opening, using or closing the
             *         transaction; that error is logged
             */
            @Nullable
            private AgentTool getAgentToolUnsafe(PythonAgentToolServerKernelDesc kernelDesc) {
                // try-with-resources: the IOException handler also covers opening and closing the
                // transaction, and a null transaction is simply not closed.
                try (Transaction t = PythonAgentToolServerPool.this.transactionService.retrieveOrBeginRead(IsolationLevel.YOLO)) {
                    return (AgentTool)PythonAgentToolServerPool.this.agentToolsDAO.getOrNullUnsafe(kernelDesc.projectKey, kernelDesc.toolId);
                }
                catch (IOException e) {
                    logger.error((Object)String.format("Failed to load agent tool %s.%s", kernelDesc.projectKey, kernelDesc.toolId), (Throwable)e);
                    return null;
                }
            }

            /** Returns the kernel's own log tail. */
            public SmartLogTail getKernelLog(PythonAgentToolServer kernel) {
                return kernel.getKernelLog();
            }
        }, "python-agent-tool", logger));
    }
}

