/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.retrieval;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.cluster.ClusterSelector;
import com.dataiku.dip.code.CodeEnvSelection;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.containers.exec.ContainerExecConfigSelector;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.containers.exec.WorkloadType;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.kernel.DSSKernelUtils;
import com.dataiku.dip.kernel.KernelPool;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMRefEnricherService;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.langchain.DevKernelDesc;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.llm.retrieval.LangchainBasedRAGClient;
import com.dataiku.dip.llm.retrieval.RAGLLMSettings;
import com.dataiku.dip.llm.retrieval.RetrievableKnowledge;
import com.dataiku.dip.llm.retrieval.RetrievableKnowledgeUtils;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.PubSubService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.IsolationLevel;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.SmartLogTail;
import com.dataiku.dss.shadelib.com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.dataiku.dss.shadelib.org.apache.commons.codec.binary.Hex;
import com.dataiku.dss.shadelib.org.apache.commons.codec.digest.DigestUtils;
import java.io.IOException;
import java.security.MessageDigest;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import org.apache.log4j.NDC;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Pool of {@link LangchainBasedRAGClient} kernels serving retrieval-augmented LLMs.
 *
 * <p>{@link #getClient} returns an {@link LLMClient} facade whose completion calls are dispatched
 * through the underlying {@link KernelPool}; each request is tagged with a hash of the kernel
 * description (see {@link #getGroupKey}) as its group key. Kernel start and stop work runs on a
 * dedicated cached thread pool rather than on the caller's thread.
 */
@Service
public class RAGKernelPool {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.ragpool");

    @Autowired
    private SavedModelsDAO savedModelsDAO;
    @Autowired
    private TransactionService transactionService;
    // NOTE(review): not referenced anywhere in this class; kept for Spring wiring compatibility.
    @Autowired
    private PubSubService pubSubService;

    private final KernelPool<LangchainBasedRAGClient, KernelDesc, KernelDesc> manager;

    /** Executes kernel start/stop tasks; threads are named "rag-kernel-pool-startstop-N". */
    private final ExecutorService executorService = Executors.newCachedThreadPool(
            new ThreadFactoryBuilder().setNameFormat("rag-kernel-pool-startstop-%d").build());

    /**
     * Builds the kernel pool. The controller below tells the pool how to create, start, size,
     * retire and kill RAG kernels; limits are read from instance parameters at call time.
     */
    public RAGKernelPool() {
        this.manager = new KernelPool<>(new KernelPool.KernelController<LangchainBasedRAGClient, KernelDesc, KernelDesc>() {

            @Nonnull
            public LangchainBasedRAGClient createKernel(KernelDesc kernelDesc) {
                try {
                    return new LangchainBasedRAGClient(kernelDesc.authCtx, kernelDesc.llmRef, kernelDesc.projectKey,
                            kernelDesc.rk, kernelDesc.settings, kernelDesc.defaultCompletionSettings,
                            kernelDesc.envSelection, kernelDesc.containerConfName, kernelDesc.isDevKernel);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }

            @Nonnull
            public CompletableFuture<Void> startKernel(LangchainBasedRAGClient kernel, KernelDesc kernelDesc) {
                return DKUCompletableFuture.runAsync(() -> {
                    NDC.push("start-rag-kernel: " + kernel.getKernelId());
                    try {
                        // Propagate resource-usage and job context to the start thread before starting.
                        DSSKernelUtils.setKernelContext(kernelDesc.cruContext, kernelDesc.jobContext, logger);
                        kernel.start();
                    } finally {
                        NDC.pop();
                    }
                }, (Executor) RAGKernelPool.this.executorService);
            }

            public int getGlobalMaxKernelCount() {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.rag.maxKernels", 50);
            }

            public int getAutoscaleTimeWindowSeconds(KernelDesc kernelDesc) {
                if (kernelDesc.isDevKernel) {
                    return ApplicationConfigurator.getParams().getIntParam("dku.llm.rag.dev.autoscaleWindowSeconds", 120);
                }
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.rag.autoscaleWindowSeconds", 600);
            }

            public int getMinimumRetentionTimeSeconds(KernelDesc kernelDesc) {
                // Dev kernels have no minimum retention.
                if (kernelDesc.isDevKernel) {
                    return 0;
                }
                // A non-negative per-LLM setting overrides the instance-wide default.
                if (kernelDesc.settings != null
                        && kernelDesc.settings.minimumRetentionTimeSeconds != null
                        && kernelDesc.settings.minimumRetentionTimeSeconds >= 0) {
                    return kernelDesc.settings.minimumRetentionTimeSeconds;
                }
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.rag.minimumRetentionTimeSeconds", 1800);
            }

            public int getHardMaxParallelRequests(KernelDesc kernelDesc) {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.rag.hardMaxRequestsPerKernel", 128);
            }

            public int getSoftMaxParallelRequests(KernelDesc kernelDesc) {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.rag.softMaxRequestsPerKernel", 32);
            }

            @Nonnull
            public CompletableFuture<Void> killKernel(LangchainBasedRAGClient kernel) {
                return CompletableFuture.runAsync(() -> {
                    NDC.push("stop-rag-kernel: " + kernel.getKernelId());
                    try {
                        kernel.close();
                    } catch (Exception e) {
                        // Best effort: a failed close is logged, not propagated.
                        logger.error("Error while closing kernel", e);
                    } finally {
                        NDC.pop();
                    }
                }, RAGKernelPool.this.executorService);
            }

            /**
             * A kernel is outdated when its knowledge base version changed, or when it is bound to a
             * saved model version that no longer exists / whose version number changed.
             * IO failures while checking are logged and treated as "not outdated".
             */
            public boolean isOutdated(KernelDesc kernelDesc) {
                try {
                    String currentRkVersion = RetrievableKnowledgeUtils.getCurrentVersionUnsafe(kernelDesc.rk);
                    if (!Objects.equals(currentRkVersion, kernelDesc.rkVersion)) {
                        return true;
                    }
                    // Without a saved model version reference there is nothing else to compare.
                    if (kernelDesc.llmRef.savedModelSmartId == null || kernelDesc.llmRef.savedModelVersionId == null) {
                        return false;
                    }
                    try (Transaction ignored = RAGKernelPool.this.transactionService.retrieveOrBeginRead(IsolationLevel.YOLO)) {
                        AnyLoc loc = AnyLoc.resolveSmart(kernelDesc.projectKey, kernelDesc.llmRef.savedModelSmartId);
                        SavedModel sm = (SavedModel) RAGKernelPool.this.savedModelsDAO.getOrNullUnsafe(loc.getProjectKey(), loc.getId());
                        if (sm == null) {
                            return true;
                        }
                        Optional<SavedModel.SavedModelInlineVersion> smiv = sm.getVersion(kernelDesc.llmRef.savedModelVersionId);
                        if (smiv.isEmpty()) {
                            return true;
                        }
                        SavedModel.SavedModelInlineVersion version = smiv.get();
                        return version.versionTag != null
                                && version.versionTag.versionNumber != kernelDesc.savedModelVersionVersionNumber;
                    }
                } catch (IOException e) {
                    logger.error(String.format("Failed to check if KB %s is outdated", kernelDesc.rk.getFullId()), e);
                    return false;
                }
            }

            public boolean isAlive(LangchainBasedRAGClient kernel) {
                return kernel.isAlive();
            }

            public String getKernelId(LangchainBasedRAGClient kernel) {
                return kernel.getKernelId();
            }

            public SmartLogTail getKernelLog(LangchainBasedRAGClient kernel) {
                return kernel.getKernelLog();
            }
        }, "rag", logger);
    }

    /** Returns {@code llmRef.savedModelSmartId}, or "" when {@code llmRef} itself is null. */
    private static String safeGetSavedModelId(LLMStructuredRef llmRef) {
        return llmRef != null ? llmRef.savedModelSmartId : "";
    }

    /**
     * Computes the pool group key: "rag-" followed by the first 10 hex characters of a SHA-256
     * digest fed (via {@code JSON.updateDigest}) with every argument, in order.
     */
    private static String getGroupKey(RetrievableKnowledge rk, RAGLLMSettings settings, String containerConfName,
            String clusterId, String envName, String user, String savedModelSmartId, boolean isDevKernel) {
        MessageDigest digest = DigestUtils.getSha256Digest();
        JSON.updateDigest(digest, rk);
        JSON.updateDigest(digest, settings);
        JSON.updateDigest(digest, containerConfName);
        JSON.updateDigest(digest, clusterId);
        JSON.updateDigest(digest, envName);
        JSON.updateDigest(digest, user);
        JSON.updateDigest(digest, savedModelSmartId);
        JSON.updateDigest(digest, isDevKernel);
        return "rag-" + Hex.encodeHexString(digest.digest()).substring(0, 10);
    }

    /**
     * Rejects the batch unless every query's first "user" message carries non-empty text.
     * A null role or null text counts as missing rather than failing with a NullPointerException.
     */
    private static void requireUserPrompts(List<LLMClient.SingleCompletionQuery> queries) {
        for (LLMClient.SingleCompletionQuery query : queries) {
            String userText = query.messages.stream()
                    .filter(m -> "user".equals(m.role))
                    .findFirst()
                    .map(LLMClient.ChatMessage::getText)
                    .orElse(null);
            if (userText == null || userText.isEmpty()) {
                throw new IllegalArgumentException("A RAG prompt requires a User input message.");
            }
        }
    }

    /**
     * Builds an {@link LLMClient} backed by a pooled RAG kernel.
     *
     * <p>The container config is the knowledge base's own selection, or the instance default for
     * retrievable knowledge when the KB is set to INHERIT. The returned client only supports
     * (streamed) completions: embedding and reranking calls throw, and streaming is offered only
     * when no RAG-specific guardrail is enabled and the output format is not JSON.
     *
     * @param authCtx    caller identity; also part of the kernel group key
     * @param projectKey project owning the retrieval-augmented LLM
     * @param llmRef     reference to the retrieval-augmented LLM (must be non-null)
     * @param rk         knowledge base the kernel retrieves from
     * @param smiv       saved model version providing the RAG settings
     * @param devKernel  whether to use a dev kernel (logs available, no minimum retention)
     * @return a client dispatching completions to the kernel pool
     * @throws RuntimeException wrapping a DKUSecurityException/IOException raised while selecting
     *                          the container config or cluster
     * @throws Exception        propagated from {@code LLMClientFactory.get}
     */
    public LLMClient getClient(AuthCtx authCtx, String projectKey, LLMStructuredRef llmRef, RetrievableKnowledge rk,
            SavedModel.SavedModelInlineVersion smiv, boolean devKernel) throws Exception {
        final KernelDesc kernelDesc = new KernelDesc();
        kernelDesc.cruContext = ComputeResourceUsageContext.forRAGKernel(authCtx, projectKey, llmRef.id);
        try {
            ContainerExecSelection containerExecSelection =
                    rk.containerExecSelection.containerMode == ContainerExecSelection.ContainerExecMode.INHERIT
                            ? ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN().generativeAISettings.defaultRetrievableKnowledgeContainerExecSelection
                            : rk.containerExecSelection;
            kernelDesc.containerConfName = new ContainerExecConfigSelector().selectName_autoTXN(projectKey, containerExecSelection, WorkloadType.USER_CODE);
            kernelDesc.clusterId = new ClusterSelector().selectForProject(authCtx, projectKey).getClusterId();
        } catch (DKUSecurityException | IOException e) {
            throw new RuntimeException(e);
        }
        kernelDesc.projectKey = projectKey;
        kernelDesc.llmRef = llmRef;
        kernelDesc.authCtx = authCtx;
        kernelDesc.settings = smiv.ragllmSettings;
        kernelDesc.defaultCompletionSettings = smiv.ragllmSettings.completionSettings;
        kernelDesc.rk = rk;
        // 0 when the saved model version has no version tag; isOutdated() compares against this.
        kernelDesc.savedModelVersionVersionNumber = smiv.versionTag != null ? smiv.versionTag.versionNumber : 0L;
        kernelDesc.isDevKernel = devKernel;
        kernelDesc.envSelection = rk.envSelection;
        kernelDesc.hash = getGroupKey(kernelDesc.rk, kernelDesc.settings, kernelDesc.containerConfName,
                kernelDesc.clusterId, kernelDesc.envSelection.envName, authCtx.getIdentifier(),
                safeGetSavedModelId(kernelDesc.llmRef), kernelDesc.isDevKernel);
        kernelDesc.jobContext = JobContext.getCurrentJobContext();
        kernelDesc.rkVersion = RetrievableKnowledgeUtils.getCurrentVersionUnsafe(rk);
        // These two clients are only consulted to bound the facade's max parallelism.
        final LLMClient embeddingLLMClient = LLMClientFactory.get(authCtx, projectKey, LLMStructuredRef.decodeId(rk.embeddingLLMId));
        final LLMClient augmentedLLMClient = LLMClientFactory.get(authCtx, projectKey, LLMStructuredRef.decodeId(kernelDesc.settings.llmId));
        return new LLMClient() {
            /** Tracks in-flight futures so close() can cancel them. */
            private final DKUCompletableFuture.FutureCancellationTracker futureCancellationTracker = new DKUCompletableFuture.FutureCancellationTracker();

            @Override
            public void close() {
                this.futureCancellationTracker.cancelAll("RAG client closed");
            }

            @Override
            public boolean supportNativeBatch() {
                return false;
            }

            @Override
            public boolean requiresCostLimiting() {
                return false;
            }

            @Override
            public String getProviderId() {
                return null;
            }

            @Override
            public AbstractLLMConnection getConnection() {
                return null;
            }

            @Override
            public int getMaxParallelism() {
                return Math.min(embeddingLLMClient.getMaxParallelism(), augmentedLLMClient.getMaxParallelism());
            }

            /** Dispatches one completion to a pooled kernel, grouped by the kernel description hash. */
            CompletableFuture<LLMClient.SimpleCompletionResponse> complete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
                return this.futureCancellationTracker.track(() -> RAGKernelPool.this.manager.handle(
                        kernel -> kernel.asyncComplete(query, settings), kernelDesc, kernelDesc.hash, query));
            }

            @Override
            public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) {
                requireUserPrompts(queries);
                List<CompletableFuture<LLMClient.SimpleCompletionResponse>> futures = queries.stream()
                        .map(query -> this.complete(query, settings))
                        .collect(Collectors.toList());
                return DKUCompletableFuture.collectResponsesNoException(futures);
            }

            @Override
            public boolean supportsStream() {
                RAGLLMSettings.GuardrailsSettings gs = kernelDesc.settings.ragSpecificGuardrails;
                boolean anyGuardrailEnabled = gs.faithfulnessSettings.enabled
                        || gs.relevancySettings.enabled
                        || gs.multimodalFaithfulnessSettings.enabled
                        || gs.multimodalRelevancySettings.enabled;
                return !anyGuardrailEnabled && kernelDesc.settings.outputFormat != RAGLLMSettings.RAGOutputFormat.JSON;
            }

            @Override
            public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
                CompletableFuture<?> future = this.futureCancellationTracker.track(() -> RAGKernelPool.this.manager.handle(
                        kernel -> kernel.asyncStreamComplete(query, settings, consumer), kernelDesc, kernelDesc.hash, query));
                // Blocks until the stream completes; the result is only used for logging.
                Object nChunks = DKUCompletableFuture.collectResponse(future);
                logger.infoV("Fully received streamed completion response with %s chunks", new Object[]{nChunks});
            }

            @Override
            public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws IOException {
                throw new IllegalArgumentException("Embeddings not supported on this LLM");
            }

            @Override
            public List<LLMClient.SingleRerankingResponse> rerankBatch(List<LLMClient.RerankingQuery> queries, LLMClient.RerankingSettings settings) throws Exception {
                throw new IllegalArgumentException("Reranking not supported on this LLM");
            }

            @Override
            public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
                return null;
            }

            @Override
            public EnrichedLLMStructuredRef getEnrichedRef() throws Exception {
                return ((LLMRefEnricherService) SpringUtils.getBean(LLMRefEnricherService.class))
                        .getEnrichedLLMRefFromRetrievalAugmentedLLM(kernelDesc.authCtx, kernelDesc.projectKey, kernelDesc.llmRef);
            }

            @Override
            public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
                return LangchainBasedRAGClient.getFormattedPrompt(kernelDesc.settings, chatMessages);
            }

            /**
             * Returns the log of the matching dev kernel. Any failure (production kernel, kernel or
             * logs not found) is reported as a one-line log containing the error message.
             */
            @Override
            public SmartLogTail getKernelLog() {
                try {
                    if (!kernelDesc.isDevKernel) {
                        throw new IllegalArgumentException("Production kernels can't return logs");
                    }
                    String kernelId = (String) RAGKernelPool.this.manager
                            .getKernelIdFiltered(kd -> kd.isSameDevKernel(kernelDesc))
                            .orElseThrow(() -> new IllegalArgumentException("Dev kernel not found: " + kernelDesc.getDevKernelKey()));
                    return (SmartLogTail) RAGKernelPool.this.manager
                            .getKernelLogs(kernelId)
                            .orElseThrow(() -> new IllegalArgumentException("Logs not found for dev kernel: " + kernelDesc.getDevKernelKey()));
                } catch (Exception e) {
                    SmartLogTail fakeSLT = new SmartLogTail();
                    // Some exceptions carry no message; fall back to the exception type.
                    fakeSLT.appendLine(Objects.toString(e.getMessage(), e.getClass().getName()));
                    return fakeSLT;
                }
            }
        };
    }

    /**
     * Asks the pool to retire (reason OUTDATED) every kernel built on the given knowledge base,
     * matched by project key and KB id. Failures are logged, never thrown.
     */
    public void clearKernels(RetrievableKnowledge rk) {
        try {
            logger.info("Knowledge base " + rk.projectKey + "." + rk.id + " data was cleared. Sending scale down signal to kernels using this knowledge base");
            this.manager.clearKernels(
                    kernelDesc -> rk.projectKey.equals(kernelDesc.rk.projectKey) && rk.id.equals(kernelDesc.rk.id),
                    KernelPool.DeathReason.OUTDATED);
            logger.info("Done notifying kernels to scale down");
        } catch (Exception e) {
            logger.error("Error while clearing kernels for knowledge base " + rk.projectKey + "." + rk.id, e);
        }
    }

    /**
     * Force-stops the dev kernel matching (user, project, KB id, saved model id).
     *
     * @return false when no such dev kernel exists, otherwise the result of the pool's force stop
     */
    public boolean stopDevKernel(DSSAuthCtx authCtx, String projectKey, String kbId, LLMStructuredRef llmRef) throws IOException {
        // Probe descriptor: only the fields used by getDevKernelKey() are filled in.
        KernelDesc probe = new KernelDesc();
        probe.authCtx = authCtx;
        probe.projectKey = projectKey;
        probe.rk = new RetrievableKnowledge();
        probe.rk.id = kbId;
        probe.isDevKernel = true;
        probe.llmRef = llmRef;
        Optional<String> kernelId = this.manager.getKernelIdFiltered(kd -> kd.isSameDevKernel(probe));
        if (kernelId.isEmpty()) {
            return false;
        }
        return this.manager.forceStopKernel(kernelId.get(), KernelPool.DeathReason.USER_REQUEST);
    }

    /** Delegates to {@link KernelPool#killAllRequests()}. */
    public void killAllRequests() {
        this.manager.killAllRequests();
    }

    /** Delegates to {@link KernelPool#killAllKernels}, recording the given reason. */
    public void killAllKernels(KernelPool.DeathReason reason) {
        this.manager.killAllKernels(reason);
    }

    /** Returns the pool's state dump; {@code full} is forwarded unchanged. */
    public KernelPool.PoolDump dump(boolean full) {
        return this.manager.dump(full);
    }

    /** Everything needed to create a RAG kernel and to decide which pool group it belongs to. */
    static class KernelDesc
    extends DevKernelDesc {
        AuthCtx authCtx;
        LLMStructuredRef llmRef;
        String projectKey;
        RetrievableKnowledge rk;
        /** KB version at creation time; compared in isOutdated(). */
        String rkVersion;
        RAGLLMSettings settings;
        LLMClient.CompletionSettings defaultCompletionSettings;
        /** Saved model version number at creation time (0 if untagged); compared in isOutdated(). */
        long savedModelVersionVersionNumber;
        CodeEnvSelection envSelection;
        String containerConfName;
        String clusterId;
        /** Group key computed by getGroupKey(). */
        String hash;
        ComputeResourceUsageContext cruContext;
        JobContext jobContext;

        KernelDesc() {
        }

        /** Identifies a dev kernel as "user-projectKey-kbId-savedModelSmartId"; dev kernels only. */
        @Override
        public String getDevKernelKey() {
            assert (this.isDevKernel);
            return this.authCtx.getIdentifier() + "-" + this.projectKey + "-" + this.rk.id + "-" + RAGKernelPool.safeGetSavedModelId(this.llmRef);
        }
    }
}

