/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.langchain;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.llm.kernel.KernelPool;
import com.dataiku.dip.llm.kernel.KernelPoolThreadFactory;
import com.dataiku.dip.llm.kernel.KernelScalingStrategyBuilder;
import com.dataiku.dip.llm.langchain.PythonLLMServer;
import com.dataiku.dip.llm.langchain.PythonLLMServerAPI;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.stereotype.Service;

@Service
public class PythonLLMServerKernelPool {
    private final KernelPool<PythonLLMServer, KernelDesc, KernelDesc> manager;
    private final ExecutorService executorService = Executors.newCachedThreadPool(new ThreadFactoryBuilder().setNameFormat("python-llm-poolmgr-%d").build());
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.python.pool");

    public void invalidateKernels(String projectKey, String savedModelId) {
        this.manager.clearKernels(kd -> projectKey.equals(kd.projectKey) && savedModelId.equals(kd.savedModelId), KernelPool.DeathReason.OUTDATED);
    }

    public PythonLLMServerAPI getServerAPI(DSSAuthCtx authCtx, String projectKey, String savedModelId, String savedModelVersionId, String pyClazz, String code, JsonObject config, JsonObject pluginConfig, String envName, String containerConfName, String pluginId, String libFolder, boolean loadPythonLibs, String clusterId, SavedModel.AgentSettings settings) {
        final KernelDesc kernelDesc = new KernelDesc();
        kernelDesc.authCtx = authCtx;
        kernelDesc.projectKey = projectKey;
        kernelDesc.savedModelId = savedModelId;
        kernelDesc.cruContext = ComputeResourceUsageContext.forPythonTool((String)projectKey);
        kernelDesc.savedModelVersionId = savedModelVersionId;
        kernelDesc.pyClazz = pyClazz;
        kernelDesc.code = code;
        kernelDesc.config = config;
        kernelDesc.pluginConfig = pluginConfig;
        kernelDesc.envName = envName;
        kernelDesc.containerConfName = containerConfName;
        kernelDesc.pluginId = pluginId;
        kernelDesc.libFolder = libFolder;
        kernelDesc.loadPythonLibs = loadPythonLibs;
        kernelDesc.clusterId = clusterId;
        kernelDesc.settings = settings;
        kernelDesc.poolKey = authCtx.getIdentifier() + "-" + projectKey + "-" + savedModelId + "-" + savedModelVersionId + "-" + DigestUtils.sha1Hex((String)StringUtils.join((Object[])new String[]{pyClazz, code, JSON.json((Object)config), JSON.json((Object)pluginConfig), envName, containerConfName, pluginId, libFolder, clusterId, JSON.json((Object)settings)}, (String)"__DKU__"));
        return new PythonLLMServerAPI(){

            @Override
            public CompletableFuture<LLMClient.SimpleCompletionResponse> processAsync(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
                return PythonLLMServerKernelPool.this.manager.handle(kernel -> kernel.processAsync(query, settings), kernelDesc, kernelDesc.poolKey, query);
            }

            @Override
            public CompletableFuture<Integer> streamProcessAsync(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) {
                return PythonLLMServerKernelPool.this.manager.handle(kernel -> kernel.streamProcessAsync(query, settings, consumer), kernelDesc, kernelDesc.poolKey, query);
            }

            @Override
            public void close() throws IOException {
            }
        };
    }

    public KernelPool.PoolDump dump(boolean full) {
        return this.manager.dump(full);
    }

    public PythonLLMServerKernelPool() {
        this.manager = new KernelPool<PythonLLMServer, KernelDesc, KernelDesc>(new KernelPool.KernelController<PythonLLMServer, KernelDesc, KernelDesc>(){

            @Override
            public PythonLLMServer createKernel(KernelDesc kernelDesc) {
                File logBaseDir = DKUApp.getFile((String[])new String[]{"saved_models", kernelDesc.projectKey, kernelDesc.savedModelId, "versions", kernelDesc.savedModelVersionId, "logs"});
                return new PythonLLMServer(kernelDesc.authCtx, kernelDesc.projectKey, kernelDesc.savedModelId, kernelDesc.savedModelVersionId, kernelDesc.pyClazz, kernelDesc.code, kernelDesc.envName, kernelDesc.containerConfName, kernelDesc.pluginId, kernelDesc.libFolder, logBaseDir, kernelDesc.config, kernelDesc.pluginConfig, false, kernelDesc.loadPythonLibs);
            }

            @Override
            public CompletableFuture<Void> startKernel(PythonLLMServer kernel, KernelDesc kernelDesc) {
                return DKUCompletableFuture.runAsync(() -> {
                    this.setKernelContext(kernelDesc.cruContext, kernelDesc.jobContext, logger);
                    kernel.start();
                }, (Executor)PythonLLMServerKernelPool.this.executorService);
            }

            @Override
            public int getGlobalMaxKernelCount() {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.python.maxKernels", Integer.valueOf(50));
            }

            @Override
            public int getAutoscaleTimeWindowSeconds() {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.python.autoscaleWindowSeconds", Integer.valueOf(600));
            }

            @Override
            public int getHardMaxParallelRequests(KernelDesc kernelDesc) {
                return kernelDesc.settings.maxParallelRequestsPerProcess;
            }

            @Override
            public int getSoftMaxParallelRequests(KernelDesc kernelDesc) {
                return kernelDesc.settings.maxParallelRequestsPerProcess;
            }

            @Override
            public CompletableFuture<Void> killKernel(PythonLLMServer kernel) {
                return CompletableFuture.runAsync(() -> {
                    try {
                        kernel.close();
                    }
                    catch (Exception e) {
                        logger.error((Object)"Error while closing kernel", (Throwable)e);
                    }
                }, PythonLLMServerKernelPool.this.executorService);
            }

            @Override
            public boolean isAlive(PythonLLMServer kernel) {
                return kernel.isAlive();
            }

            @Override
            public String getKernelId(PythonLLMServer kernel) {
                return kernel.getKernelId();
            }
        }, new KernelScalingStrategyBuilder(), new KernelPoolThreadFactory("python-llm"), logger);
    }

    public void killAllRequests() {
        this.manager.killAllRequests();
    }

    public void killAllKernels(KernelPool.DeathReason reason) {
        this.manager.killAllKernels(reason);
    }

    static class KernelDesc {
        DSSAuthCtx authCtx;
        String projectKey;
        String savedModelId;
        String savedModelVersionId;
        String pyClazz;
        String code;
        JsonObject config;
        JsonObject pluginConfig;
        String envName;
        String containerConfName;
        String pluginId;
        String libFolder;
        boolean loadPythonLibs;
        String clusterId;
        String poolKey;
        SavedModel.AgentSettings settings;
        ComputeResourceUsageContext cruContext;
        JobContext jobContext;

        KernelDesc() {
        }
    }
}

