/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.pii;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.cluster.ClusterSelector;
import com.dataiku.dip.code.CodeEnvModel;
import com.dataiku.dip.code.CodeEnvResolutionService;
import com.dataiku.dip.containers.exec.ContainerExecConfigSelector;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.containers.exec.WorkloadType;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.kernel.KernelPool;
import com.dataiku.dip.llm.kernel.KernelPoolThreadFactory;
import com.dataiku.dip.llm.kernel.KernelScalingStrategyBuilder;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.pii.PIIClient;
import com.dataiku.dip.llm.pii.PresidioBasedPIIHandlingServer;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.apache.commons.codec.binary.Hex;
import com.dataiku.dss.shadelib.org.apache.commons.codec.digest.DigestUtils;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.IOException;
import java.security.MessageDigest;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.log4j.NDC;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Spring service owning a {@link KernelPool} of {@link PresidioBasedPIIHandlingServer} kernels and
 * handing out {@link PIIClient} instances whose requests are dispatched through that pool.
 *
 * <p>Every request sent by a client carries a group key computed by {@link #getGroupKey} from the
 * PII settings, the selected container exec config, the selected cluster and the configured code env.
 * Kernel start and stop work is executed on a dedicated cached thread pool.
 */
@Service
public class PIIKernelPool {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.piipool");

    @Autowired
    private CodeEnvResolutionService codeEnvResolutionService;

    /** Runs kernel start/stop tasks; threads are named "pii-kernel-pool-startstop-N". */
    private final ExecutorService executorService = PIIKernelPool.newStartStopExecutor();

    /** Underlying pool to which all dispatch and lifecycle operations are delegated. */
    private final KernelPool<PresidioBasedPIIHandlingServer, KernelDesc, KernelDesc> manager;

    /**
     * Builds the underlying kernel pool, wiring in the controller from
     * {@link #createKernelController()}, a default scaling strategy builder and a "pii" thread factory.
     */
    public PIIKernelPool() {
        KernelPool.KernelController<PresidioBasedPIIHandlingServer, KernelDesc, KernelDesc> controller = this.createKernelController();
        this.manager = new KernelPool<PresidioBasedPIIHandlingServer, KernelDesc, KernelDesc>(controller, new KernelScalingStrategyBuilder(), new KernelPoolThreadFactory("pii"), logger);
    }

    /** Creates the cached thread pool used for starting and stopping kernels. */
    private static ExecutorService newStartStopExecutor() {
        return Executors.newCachedThreadPool(new ThreadFactoryBuilder().setNameFormat("pii-kernel-pool-startstop-%d").build());
    }

    /** Reads an integer DSS parameter, falling back to {@code fallback} as the supplied default. */
    private static int readIntParam(String key, int fallback) {
        return ApplicationConfigurator.getParams().getIntParam(key, Integer.valueOf(fallback));
    }

    /**
     * Computes the group key sent with every request for a given kernel configuration.
     *
     * @return "pii-" followed by the first 10 hex characters of a SHA-256 digest fed, via
     *         {@code JSON.updateDigest}, with the settings, container config name, cluster id and env name
     */
    private static String getGroupKey(PresidioBasedPIIHandlingServer.PresidioBasedPIIHandlingSettings settings, String containerConfName, String clusterId, String envName) {
        MessageDigest sha256 = DigestUtils.getSha256Digest();
        // Parts are statically typed as Object so the same updateDigest overload is used for each one.
        Object[] parts = new Object[]{settings, containerConfName, clusterId, envName};
        for (Object part : parts) {
            JSON.updateDigest(sha256, part);
        }
        String fullHex = Hex.encodeHexString(sha256.digest());
        return "pii-" + fullHex.substring(0, 10);
    }

    /**
     * Resolves the container exec config name and cluster id for the descriptor's project and
     * stores them on the descriptor.
     *
     * @throws RuntimeException wrapping any DKUSecurityException or IOException raised while selecting
     */
    private static void resolveExecutionTarget(KernelDesc desc) {
        try {
            ContainerExecSelection selection = ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN().generativeAISettings.presidioBasedPIIDetectionContainerExecSelection;
            desc.containerConfName = new ContainerExecConfigSelector().selectName_autoTXN(desc.projectKey, selection, WorkloadType.USER_CODE);
            desc.clusterId = new ClusterSelector().selectForProject(desc.authCtx, desc.projectKey).getClusterId();
        }
        catch (DKUSecurityException | IOException failure) {
            throw new RuntimeException(failure);
        }
    }

    /**
     * Assembles the descriptor for a project and its PII settings.
     *
     * <p>The steps run in this order:
     * <ol>
     *   <li>resource-usage context;</li>
     *   <li>identity fields;</li>
     *   <li>container/cluster selection;</li>
     *   <li>code env name, checked via {@code checkEnvExists};</li>
     *   <li>group key;</li>
     *   <li>current job context.</li>
     * </ol>
     *
     * @throws IOException propagated from the steps outside the container/cluster selection
     */
    private KernelDesc buildKernelDesc(AuthCtx authCtx, String projectKey, PresidioBasedPIIHandlingServer.PresidioBasedPIIHandlingSettings settings) throws IOException {
        KernelDesc desc = new KernelDesc();
        desc.cruContext = ComputeResourceUsageContext.forPIIModel(projectKey);
        desc.projectKey = projectKey;
        desc.settings = settings;
        desc.authCtx = authCtx;
        PIIKernelPool.resolveExecutionTarget(desc);
        desc.envName = ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN().generativeAISettings.getPresidioBasedPIIDetectionCodeEnv();
        // NOTE(review): presumably fails when the configured Python env is missing — confirm in CodeEnvResolutionService.
        this.codeEnvResolutionService.checkEnvExists(CodeEnvModel.EnvLang.PYTHON, desc.envName);
        desc.hash = PIIKernelPool.getGroupKey(desc.settings, desc.containerConfName, desc.clusterId, desc.envName);
        desc.jobContext = JobContext.getCurrentJobContext();
        return desc;
    }

    /**
     * Returns a client whose four {@code processAsync} variants all forward to the pool's
     * {@code handle}, passing the descriptor, its group key and the request itself.
     *
     * @throws IOException propagated from building the kernel descriptor
     */
    public PIIClient getClient(AuthCtx authCtx, String projectKey, PresidioBasedPIIHandlingServer.PresidioBasedPIIHandlingSettings settings) throws IOException {
        final KernelDesc desc = this.buildKernelDesc(authCtx, projectKey, settings);
        return new PIIClient(){

            @Override
            public CompletableFuture<PIIClient.CompletionQueryPIIDetectionResponse> processAsync(LLMClient.SingleCompletionQuery completionQuery) {
                return PIIKernelPool.this.manager.handle(server -> server.processAsync(completionQuery), desc, desc.hash, completionQuery);
            }

            @Override
            public CompletableFuture<PIIClient.CompletionResponsePIIDetectionResponse> processAsync(LLMClient.SimpleCompletionResponseOrError completionResponse) {
                return PIIKernelPool.this.manager.handle(server -> server.processAsync(completionResponse), desc, desc.hash, completionResponse);
            }

            @Override
            public CompletableFuture<PIIClient.EmbeddingQueryPIIDetectionResponse> processAsync(LLMClient.EmbeddingQuery embeddingQuery) {
                return PIIKernelPool.this.manager.handle(server -> server.processAsync(embeddingQuery), desc, desc.hash, embeddingQuery);
            }

            @Override
            public CompletableFuture<PIIClient.ImageGenerationQueryPIIDetectionResponse> processAsync(LLMClient.ImageGenerationQuery imageQuery) {
                return PIIKernelPool.this.manager.handle(server -> server.processAsync(imageQuery), desc, desc.hash, imageQuery);
            }
        };
    }

    /** Delegates to the underlying pool's dump. */
    public KernelPool.PoolDump dump(boolean full) {
        return this.manager.dump(full);
    }

    /** Delegates to the underlying pool's killAllRequests. */
    public void killAllRequests() {
        this.manager.killAllRequests();
    }

    /** Delegates to the underlying pool's killAllKernels with the given reason. */
    public void killAllKernels(KernelPool.DeathReason reason) {
        this.manager.killAllKernels(reason);
    }

    /**
     * Builds the controller the pool uses to create, start, stop and inspect PII kernels, and to read
     * its sizing limits from DSS parameters.
     */
    private KernelPool.KernelController<PresidioBasedPIIHandlingServer, KernelDesc, KernelDesc> createKernelController() {
        return new KernelPool.KernelController<PresidioBasedPIIHandlingServer, KernelDesc, KernelDesc>(){

            @Override
            public PresidioBasedPIIHandlingServer createKernel(KernelDesc desc) {
                return new PresidioBasedPIIHandlingServer(desc.authCtx, desc.projectKey, desc.settings, desc.envName, desc.containerConfName);
            }

            /**
             * Starts the kernel on the start/stop executor under a "start-pii-kernel" NDC tag, after
             * calling setKernelContext with the descriptor's resource-usage and job contexts.
             */
            @Override
            public CompletableFuture<Void> startKernel(PresidioBasedPIIHandlingServer kernel, KernelDesc desc) {
                return DKUCompletableFuture.runAsync(() -> {
                    NDC.push("start-pii-kernel: " + kernel.getKernelId());
                    try {
                        this.setKernelContext(desc.cruContext, desc.jobContext, logger);
                        kernel.start();
                    }
                    finally {
                        NDC.pop();
                    }
                }, (Executor)PIIKernelPool.this.executorService);
            }

            /** Parameter "dku.llm.pii.maxKernels", default 10. */
            @Override
            public int getGlobalMaxKernelCount() {
                return PIIKernelPool.readIntParam("dku.llm.pii.maxKernels", 10);
            }

            /** Parameter "dku.llm.pii.autoscaleWindowSeconds", default 600. */
            @Override
            public int getAutoscaleTimeWindowSeconds() {
                return PIIKernelPool.readIntParam("dku.llm.pii.autoscaleWindowSeconds", 600);
            }

            /** Parameter "dku.llm.pii.hardMaxRequestsPerKernel", default 128; same for every descriptor. */
            @Override
            public int getHardMaxParallelRequests(KernelDesc desc) {
                return PIIKernelPool.readIntParam("dku.llm.pii.hardMaxRequestsPerKernel", 128);
            }

            /** Parameter "dku.llm.pii.softMaxRequestsPerKernel", default 32; same for every descriptor. */
            @Override
            public int getSoftMaxParallelRequests(KernelDesc desc) {
                return PIIKernelPool.readIntParam("dku.llm.pii.softMaxRequestsPerKernel", 32);
            }

            /** Closes the kernel on the start/stop executor under a "stop-pii-kernel" NDC tag. */
            @Override
            public CompletableFuture<Void> killKernel(PresidioBasedPIIHandlingServer kernel) {
                Runnable shutdown = () -> {
                    NDC.push("stop-pii-kernel: " + kernel.getKernelId());
                    try {
                        kernel.close();
                    }
                    catch (Exception closeFailure) {
                        // Best effort: the failure is logged and the returned future still completes normally.
                        logger.error((Object)"Error while closing kernel", (Throwable)closeFailure);
                    }
                    finally {
                        NDC.pop();
                    }
                };
                return CompletableFuture.runAsync(shutdown, PIIKernelPool.this.executorService);
            }

            @Override
            public boolean isAlive(PresidioBasedPIIHandlingServer kernel) {
                return kernel.isAlive();
            }

            @Override
            public String getKernelId(PresidioBasedPIIHandlingServer kernel) {
                return kernel.getKernelId();
            }
        };
    }

    /**
     * Inputs needed to create a PII kernel plus the group key sent with each request; filled in by
     * buildKernelDesc.
     */
    static class KernelDesc {
        String containerConfName;
        PresidioBasedPIIHandlingServer.PresidioBasedPIIHandlingSettings settings;
        /** Output of getGroupKey for this descriptor's settings, container config, cluster and env. */
        String hash;
        AuthCtx authCtx;
        String envName;
        String projectKey;
        String clusterId;
        ComputeResourceUsageContext cruContext;
        /** Job context captured on the thread that built the descriptor. */
        JobContext jobContext;

        KernelDesc() {
        }
    }
}

