/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.evaluation;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.analysis.model.core.CustomMetricResult;
import com.dataiku.dip.code.CodeEnvSelector;
import com.dataiku.dip.containers.exec.ContainerExecConfigSelector;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.io.JavaBlockLink;
import com.dataiku.dip.io.SimplePythonKernel;
import com.dataiku.dip.io.SimplePythonKernelFactory;
import com.dataiku.dip.kernel.KernelPool;
import com.dataiku.dip.kernel.KernelPoolThreadFactory;
import com.dataiku.dip.kernel.KernelScalingStrategyBuilder;
import com.dataiku.dip.recipes.nlp.llm_evaluation.LLMEvaluationRecipeParams;
import com.dataiku.dip.recipes.nlp.llm_evaluation.LLMEvaluationRecipePayloadParams;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dss.shadelib.com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nonnull;
import org.apache.log4j.NDC;
import org.springframework.stereotype.Service;

/**
 * Service that evaluates a single custom metric of an LLM evaluation recipe ("test custom metric")
 * inside pooled Python kernels.
 *
 * <p>Kernels are started with the {@code dataiku.llm.evaluation.server} Python package and are
 * grouped by project, recipe, user and code env (see {@link #getGroupKey}). The pool controller
 * declares both the hard and the soft maximum of parallel requests per kernel as 1.
 */
@Service
public class TestCustomMetricKernelPool {
    // NOTE(review): not referenced anywhere in this class; confirm whether it was meant to drive
    // kernel idle lifetime / the autoscale window, or remove it.
    private static final int KERNEL_KEEPALIVE_MINUTES = 5;
    /** Python package that the pooled kernels are started with. */
    private static final String LLM_EVALUATION_PYTHON_PACKAGE = "dataiku.llm.evaluation.server";
    private final KernelPool<SimplePythonKernel, KernelDesc, KernelDesc> manager;
    /** Runs kernel start and kill operations off the calling thread. */
    private final ExecutorService executorService = Executors.newCachedThreadPool(new ThreadFactoryBuilder().setNameFormat("test-custom-metric-kernel-pool-startstop-%d").build());
    /**
     * Runs the blocking request/response exchange with a kernel.
     *
     * <p>Without an explicit executor, {@code CompletableFuture.supplyAsync} would run that
     * blocking I/O on {@code ForkJoinPool.commonPool()}. A call that times out could then keep a
     * common-pool thread blocked and starve unrelated users of the common pool.
     */
    private final ExecutorService requestExecutorService = Executors.newCachedThreadPool(new ThreadFactoryBuilder().setNameFormat("test-custom-metric-kernel-pool-request-%d").build());
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.test-custom-metricpool");

    /**
     * Builds the pool group key that decides which requests may share a kernel.
     *
     * <p>The project key is part of the key for two reasons. Each kernel is created for one
     * specific project, and the dataset smart names sent in requests are resolved relative to the
     * recipe's project. Without the project key, same-named recipes in two projects (if recipe ids
     * are only unique per project, which is presumably the case; confirm) run by the same user
     * with the same code env could reuse a kernel bound to the wrong project.
     *
     * @param projectKey key of the project owning the recipe
     * @param recipeId   recipe identifier
     * @param userId     identifier of the requesting user
     * @param envName    selected code env name (may be {@code null}; it is then rendered as "null")
     * @return the group key
     */
    private static String getGroupKey(String projectKey, String recipeId, String userId, String envName) {
        return "test-custom-metric-" + projectKey + "-" + recipeId + "-" + userId + "-" + envName;
    }

    /**
     * Computes one custom metric of an LLM evaluation recipe on a pooled Python kernel.
     *
     * @param authCtx          caller identity; must be allowed to run code ("safe code")
     * @param recipe           the LLM evaluation recipe; provides the project, code env and
     *                         container selection
     * @param recipeDesc       recipe payload sent to the kernel
     * @param metricIndex      index of the custom metric to compute
     * @param inputDataset     dataset the metric is computed on
     * @param timeoutInMinutes maximum time to wait for the kernel's response
     * @return the metric result deserialized from the kernel's JSON response
     * @throws Exception if the caller may not run code, the kernel cannot be obtained, the request
     *                   fails, or the timeout elapses (as surfaced by
     *                   {@code DKUCompletableFuture.collectResponse})
     */
    public CustomMetricResult testCustomMetric(AuthCtx authCtx, SerializedRecipe recipe, LLMEvaluationRecipePayloadParams recipeDesc, int metricIndex, Dataset inputDataset, int timeoutInMinutes) throws Exception {
        authCtx.failIfNoSafeCode("You do not have the required permission to run code");
        LLMEvaluationRecipeParams recipeParams = recipe.getParamsAs(LLMEvaluationRecipeParams.class);
        KernelDesc kernelDesc = new KernelDesc();
        kernelDesc.authCtx = authCtx;
        kernelDesc.projectKey = recipe.getProjectKey();
        kernelDesc.codeEnvName = new CodeEnvSelector().selectForPythonRecipe(recipe.getProjectKey(), recipeParams.getCodeEnvSelection());
        kernelDesc.containerExecConfig = new ContainerExecConfigSelector().select_autoTXN(authCtx, recipe.getProjectKey(), recipeParams.getContainerSelection());
        kernelDesc.hash = TestCustomMetricKernelPool.getGroupKey(kernelDesc.projectKey, recipe.getId(), authCtx.getIdentifier(), kernelDesc.codeEnvName);
        // Smart name relative to the recipe's project (qualified only when the dataset lives elsewhere).
        String inputDatasetSmartName = inputDataset.getSmartName(recipe.getProjectKey());
        TestCustomMetricCommand command = new TestCustomMetricCommand(recipeDesc, inputDatasetSmartName, metricIndex);
        CompletableFuture<?> future = this.manager.handle(kernel -> CompletableFuture.supplyAsync(() -> {
            try {
                JavaBlockLink link = kernel.getLink();
                link.sendRequest((Object)command);
                return (CustomMetricResult)link.receiveJsonResponse(CustomMetricResult.class);
            }
            catch (Exception e) {
                throw new RuntimeException("Error while running python kernel for custom metric", e);
            }
        // orTimeout only completes the future exceptionally; it does not interrupt the blocked
        // worker. NOTE(review): presumably the pool kills the kernel on failure
        // (killKernelOnRequestFailure() returns true), which unblocks the worker; confirm.
        }, this.requestExecutorService).orTimeout(timeoutInMinutes, TimeUnit.MINUTES), (Object)kernelDesc, kernelDesc.hash, (Object)command);
        return (CustomMetricResult)DKUCompletableFuture.collectResponse((CompletableFuture)future);
    }

    /**
     * Creates the kernel pool, together with the controller that creates, starts, kills and
     * sizes the Python kernels.
     */
    public TestCustomMetricKernelPool() {
        this.manager = new KernelPool((KernelPool.KernelController)new KernelPool.KernelController<SimplePythonKernel, KernelDesc, KernelDesc>(){

            /** Prepares (but does not start) a Python kernel for the given descriptor. */
            @Nonnull
            public SimplePythonKernel createKernel(KernelDesc kernelDesc) {
                try {
                    return SimplePythonKernelFactory.prepareKernel(kernelDesc.authCtx, kernelDesc.projectKey, GeneralSettingsDAO.CGrouppableProcessType.ML_RECIPE, kernelDesc.codeEnvName, TestCustomMetricKernelPool.LLM_EVALUATION_PYTHON_PACKAGE, true, false, null, kernelDesc.containerExecConfig, "test-custom-metric-" + SecretKeyGenerator.generateSmall(), "test-custom-metric-", false, true, null);
                }
                catch (Exception e) {
                    throw new RuntimeException("Error while creating python kernel for custom metric", e);
                }
            }

            /** Starts the kernel asynchronously on the start/stop executor, tagging logs via NDC. */
            @Nonnull
            public CompletableFuture<Void> startKernel(SimplePythonKernel kernel, KernelDesc kernelDesc) {
                return CompletableFuture.runAsync(() -> {
                    NDC.push((String)("start-test-custom-metric-kernel: " + kernel.getId()));
                    try {
                        kernel.start();
                    }
                    catch (Exception e) {
                        throw new RuntimeException("Error while starting python kernel for custom metric", e);
                    }
                    finally {
                        NDC.pop();
                    }
                }, TestCustomMetricKernelPool.this.executorService);
            }

            /** Upper bound on kernels, read from "dku.llm.eval.testMetric.maxKernels" (default 5). */
            public int getGlobalMaxKernelCount() {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.eval.testMetric.maxKernels", Integer.valueOf(5));
            }

            public int getAutoscaleTimeWindowSeconds() {
                return 300;
            }

            public int getHardMaxParallelRequests(KernelDesc kernelDesc) {
                return 1;
            }

            public int getSoftMaxParallelRequests(KernelDesc kernelDesc) {
                return 1;
            }

            /** Closes the kernel asynchronously; close failures are logged, not propagated. */
            @Nonnull
            public CompletableFuture<Void> killKernel(SimplePythonKernel kernel) {
                return CompletableFuture.runAsync(() -> {
                    NDC.push((String)("stop-test-custom-metric-kernel: " + kernel.getId()));
                    try {
                        kernel.close();
                    }
                    catch (Exception e) {
                        logger.error((Object)"Error while closing kernel", (Throwable)e);
                    }
                    finally {
                        NDC.pop();
                    }
                }, TestCustomMetricKernelPool.this.executorService);
            }

            public boolean isAlive(SimplePythonKernel kernel) {
                return kernel.isAlive();
            }

            public String getKernelId(SimplePythonKernel kernel) {
                return kernel.getId();
            }

            public boolean killKernelOnRequestFailure() {
                return true;
            }
        }, new KernelScalingStrategyBuilder(), new KernelPoolThreadFactory("test-custom-metric"), logger);
    }

    /** Delegates to {@code KernelPool.killAllRequests()}. */
    public void killAllRequests() {
        this.manager.killAllRequests();
    }

    /** Delegates to {@code KernelPool.killAllKernels(reason)}. */
    public void killAllKernels(KernelPool.DeathReason reason) {
        this.manager.killAllKernels(reason);
    }

    /** Returns the pool's dump; {@code full} is forwarded unchanged. */
    public KernelPool.PoolDump dump(boolean full) {
        return this.manager.dump(full);
    }

    /** Describes the kernel a request needs; {@code hash} is the pool group key. */
    static class KernelDesc {
        AuthCtx authCtx;
        String projectKey;
        String codeEnvName;
        ContainerExecRuntimeConfig containerExecConfig;
        String hash;

        KernelDesc() {
        }
    }

    /** Request payload sent to the kernel over its link. */
    static class TestCustomMetricCommand {
        public LLMEvaluationRecipePayloadParams recipeDesc;
        public String inputDatasetSmartName;
        public int indexOfMetricToCompute;

        TestCustomMetricCommand(LLMEvaluationRecipePayloadParams recipeDesc, String inputDatasetSmartName, int indexOfMetricToCompute) {
            this.recipeDesc = recipeDesc;
            this.inputDatasetSmartName = inputDatasetSmartName;
            this.indexOfMetricToCompute = indexOfMetricToCompute;
        }
    }
}

