/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.rag_embedding;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.code.CodeEnvSelector;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.RecipeRunnableSubgraph;
import com.dataiku.dip.dataflow.exec.AbstractPythonRecipeRunner;
import com.dataiku.dip.dataflow.exec.CodeBasedRecipeDatasetInfoHelper;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.dataflow.utils.FlowJobUtils;
import com.dataiku.dip.llm.retrieval.RetrievableKnowledge;
import com.dataiku.dip.llm.retrieval.RetrievableKnowledgeDAO;
import com.dataiku.dip.recipes.InitializableAbortableRecipeRunner;
import com.dataiku.dip.recipes.nlp.common.EmbeddingRecipePayloadBaseParams;
import com.dataiku.dip.recipes.nlp.common.EmbeddingRecipeRunnerBase;
import com.dataiku.dip.recipes.nlp.rag_embedding.RAGEmbeddingRecipePayloadParams;
import com.dataiku.dip.resourceusage.ComputeResourceUsageReportingService;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.util.AutoDelete;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Recipe runner for the RAG embedding recipe: reads the recipe's "main" input dataset and
 * (re)builds the "knowledge_bank" output by launching the Python module
 * {@code dataiku.llm.rag.rag_embedding_recipe} against a freshly created version folder.
 *
 * <p>Note on casts: the explicit {@code (Object)} / {@code (File)} casts on arguments passed to
 * project APIs ({@code logger.info}, {@code JSON.prettyToFile}, {@code JSON.json}, ...) are kept
 * on purpose so that overload resolution stays exactly what it was.
 */
public class RAGEmbeddingRecipeRunner extends EmbeddingRecipeRunnerBase {
    @Autowired
    protected DatasetsDAO datasetsDAO;
    @Autowired
    protected RetrievableKnowledgeDAO retrievableKnowledgeDAO;
    @Autowired
    protected ComputeResourceUsageReportingService cruReportingService;
    /** Parsed recipe payload; set by {@link #setPayload(String)}. */
    protected RAGEmbeddingRecipePayloadParams desc;
    protected FullModelId outputFMI;
    static DKULogger logger = DKULogger.getLogger("dku.recipes.nlp.rag_embedding");

    /**
     * Creates a runner bound to the given job activity and to the recipe of its subgraph.
     *
     * @param activity the job activity this runner executes
     */
    public RAGEmbeddingRecipeRunner(JobActivity activity) {
        this.activity = activity;
        this.recipe = ((RecipeRunnableSubgraph)activity.getSubgraph()).getRecipe();
    }

    /**
     * Parses the JSON payload and stores it as this runner's recipe description.
     *
     * @param payload JSON-serialized {@link RAGEmbeddingRecipePayloadParams}
     */
    @Override
    public void setPayload(String payload) {
        this.desc = (RAGEmbeddingRecipePayloadParams)JSON.parse(payload, RAGEmbeddingRecipePayloadParams.class);
    }

    /**
     * Parses a JSON payload without storing it on the runner.
     *
     * @param payload JSON-serialized {@link RAGEmbeddingRecipePayloadParams}
     * @return the parsed parameters
     */
    @Override
    public EmbeddingRecipePayloadBaseParams parseRecipePayload(String payload) {
        return (EmbeddingRecipePayloadBaseParams)JSON.parse(payload, RAGEmbeddingRecipePayloadParams.class);
    }

    /** @return the payload previously stored by {@link #setPayload(String)} (may be {@code null} before that call) */
    @Override
    public EmbeddingRecipePayloadBaseParams getRecipeBaseDesc() {
        return this.desc;
    }

    /** @return the names of the files written by this runner/its bookkeeping that the base class is told to ignore */
    @Override
    protected List<String> getIgnoredVersionFiles() {
        return Arrays.asList("kb.json", "recipe_settings.json", "version_info.json");
    }

    /**
     * Tells whether the knowledge bank must be cleared when the previous version is not complete.
     *
     * @param vectorStoreUpdateMethod the update method configured on the recipe
     * @return {@code true} only for "smart" update methods
     */
    @Override
    protected boolean shouldClearKnownledgeBankIfPreviousVersionNotComplete(EmbeddingRecipePayloadBaseParams.VectorStoreUpdateMethod vectorStoreUpdateMethod) {
        if (!vectorStoreUpdateMethod.isSmart) {
            return false;
        }
        logger.info((Object)"Smart Update vector store method: need to clear the KB to create a new record manager");
        return true;
    }

    /**
     * Decides whether a change of settings between the previous and the current run requires
     * clearing the knowledge bank before embedding again.
     *
     * <p>The bank is cleared when any of the following holds:
     * <ul>
     *   <li>a metadata column is configured now that was not configured before;</li>
     *   <li>the update method switched from non-smart to smart;</li>
     *   <li>both runs are smart and the knowledge bank's source id column differs from the current recipe's;</li>
     *   <li>the knowledge column or the security tokens column changed;</li>
     *   <li>the current run is smart and the vector store type changed;</li>
     *   <li>the current method is not OVERWRITE and the embedding LLM changed;</li>
     *   <li>the document splitting mode changed, or it stayed CHARACTERS_BASED with a different
     *       chunk size or overlap.</li>
     * </ul>
     *
     * @param currentRk             knowledge bank settings for the current run
     * @param previousRk            knowledge bank settings of the previous version
     * @param currentDesc           recipe settings for the current run
     * @param previousDesc          recipe settings of the previous version
     * @param previousVersionFolder folder of the previous version (not used by this implementation)
     * @return {@code true} if the knowledge bank must be cleared
     */
    @Override
    protected boolean shouldClearKnowledgeBankAfterSettingsChange(RetrievableKnowledge currentRk, RetrievableKnowledge previousRk, EmbeddingRecipePayloadBaseParams currentDesc, EmbeddingRecipePayloadBaseParams previousDesc, File previousVersionFolder) {
        // Set-based lookup instead of List.contains inside a loop.
        Set<String> previousMetadataCols = previousDesc.metadataColumns.stream().map(col -> col.column).collect(Collectors.toSet());
        boolean metadataColumnAdded = currentDesc.metadataColumns.stream().anyMatch(col -> !previousMetadataCols.contains(col.column));

        boolean previousSmart = previousDesc.getVectorStoreUpdateMethod().isSmart;
        EmbeddingRecipePayloadBaseParams.VectorStoreUpdateMethod currentMethod = currentDesc.getVectorStoreUpdateMethod();
        boolean currentSmart = currentMethod.isSmart;

        boolean switchedToSmart = !previousSmart && currentSmart;
        // NOTE(review): this compares the *knowledge bank's* sourceIdColumn with the current recipe's,
        // not previousDesc with currentDesc like the neighbouring checks - kept as-is, confirm it is intended.
        boolean sourceIdColumnChanged = previousSmart && currentSmart && !Objects.equals(currentRk.sourceIdColumn, currentDesc.sourceIdColumn);
        boolean knowledgeColumnChanged = !Objects.equals(previousDesc.knowledgeColumn, currentDesc.knowledgeColumn);
        boolean securityTokensColumnChanged = !Objects.equals(previousDesc.securityTokensColumn, currentDesc.securityTokensColumn);
        boolean vectorStoreTypeChanged = currentSmart && !Objects.equals(currentRk.vectorStoreType, previousRk.vectorStoreType);
        boolean embeddingModelChanged = currentMethod != EmbeddingRecipePayloadBaseParams.VectorStoreUpdateMethod.OVERWRITE && !Objects.equals(previousRk.embeddingLLMId, currentRk.embeddingLLMId);

        // NOTE(review): the chunk size/overlap comparisons use != ; that is only correct if those
        // fields are primitives - confirm against EmbeddingRecipePayloadBaseParams.
        boolean splittingChanged = previousDesc.documentSplittingMode != currentDesc.documentSplittingMode
                || (previousDesc.documentSplittingMode == EmbeddingRecipePayloadBaseParams.DocumentSplittingMode.CHARACTERS_BASED
                    && (previousDesc.chunkSizeCharacters != currentDesc.chunkSizeCharacters || previousDesc.chunkOverlapCharacters != currentDesc.chunkOverlapCharacters));

        return metadataColumnAdded
                || switchedToSmart
                || sourceIdColumnChanged
                || knowledgeColumnChanged
                || securityTokensColumnChanged
                || vectorStoreTypeChanged
                || embeddingModelChanged
                || splittingChanged;
    }

    /**
     * Executes the recipe: resolves the knowledge bank output and the input dataset, prepares a
     * new timestamp-named version folder, runs the Python embedding module and finally records
     * the new version.
     *
     * @throws IllegalArgumentException if the recipe has no "knowledge_bank" output or no "main" input
     * @throws Exception                anything raised by the underlying services or by the Python run
     */
    @Override
    public void run() throws Exception {
        logger.info((Object)"RAG Embedding recipe runner started");

        // Resolve the output knowledge bank.
        SerializedRecipe.RecipeOutput kbOutput = (SerializedRecipe.RecipeOutput)this.recipe.getModel().getOutputsForRole("knowledge_bank").stream().findFirst().orElseThrow(() -> new IllegalArgumentException("Knowledge bank output not found"));
        this.rk = (RetrievableKnowledge)this.retrievableKnowledgeDAO.getMandatory(kbOutput.getLoc(this.recipe.getProjectKey()));
        this.checkKnowledgeBanksAllowed(this.rk);

        // Resolve the input dataset.
        SerializedRecipe.RecipeInput mainInput = (SerializedRecipe.RecipeInput)this.recipe.getModel().getInputsForRole("main").stream().findFirst().orElseThrow(() -> new IllegalArgumentException("dataset input not found"));
        SerializedDataset serializedDataset = (SerializedDataset)this.datasetsDAO.getMandatory(mainInput.getLoc(this.recipe.getProjectKey()));
        Dataset dataset = Dataset.fromSerialized(serializedDataset);

        File baseFolder = DKUApp.getFile(new String[]{"knowledge-banks", this.rk.projectKey, this.rk.id});
        MLPaths.createIfNeededFolderAndRestrictPermissions(baseFolder);

        // Must be decided before the new version folder exists (same order as before).
        this.desc.clearVectorStore = this.shouldClearKnowledgeBank(this.rk);

        // Each run writes into its own folder, named after the current time in milliseconds.
        String version = String.valueOf(System.currentTimeMillis());
        File rkFolder = DKUApp.getFile(baseFolder, new String[]{"versions", version});
        MLPaths.createIfNeededFolderAndRestrictPermissions(rkFolder);

        // Keep only the schema columns of the input dataset that are configured as metadata columns.
        Set<String> metadataCols = this.desc.metadataColumns.stream().map(col -> col.column).collect(Collectors.toSet());
        this.rk.metadataColumnsSchema = dataset.getSchema().columns.stream().filter(sc -> metadataCols.contains(sc.getName())).collect(Collectors.toList());

        // NOTE(review): reads the vectorStoreUpdateMethod field directly whereas the rest of this
        // class goes through getVectorStoreUpdateMethod() - kept as-is, confirm both are equivalent.
        if (!this.desc.clearVectorStore && this.desc.vectorStoreUpdateMethod != EmbeddingRecipePayloadBaseParams.VectorStoreUpdateMethod.OVERWRITE) {
            this.copyPreviousVersionFolder(this.rk, rkFolder);
        }

        // No container selection / cluster container config is provided by this runner.
        ContainerExecSelection containerSelection = null;
        ContainerExecRuntimeConfig clusterContainerConfig = null;
        InitializableAbortableRecipeRunner runner = this.createRunner(containerSelection, clusterContainerConfig, baseFolder, rkFolder, dataset.getFullName());
        SpringUtils.getInstance().autowire((Object)runner);
        runner.init();
        // Published before run() so the in-flight Python runner is reachable through this field.
        this.abortableRunner = runner;
        runner.run();

        // Only reached when the Python run returned without throwing.
        this.updateMetadataColsSchema(this.rk);
        this.updateKBVersionOnDisk(this.rk, rkFolder, version);
    }

    /**
     * Builds the Python recipe runner that performs the embedding.
     *
     * @param containerSelection     currently unused
     * @param clusterContainerConfig when {@code null} the module is executed locally; see the
     *                               review note in the returned runner's {@code run()}
     * @param baseFolder             root folder of the knowledge bank; granted full ACLs before running
     * @param rkFolder               version folder receiving {@code recipe_settings.json} and {@code kb.json},
     *                               and passed to the Python module
     * @param inputDatasetFullName   full name of the input dataset, passed to the Python module
     * @return a runner that still has to be autowired, initialized and run by the caller
     * @throws Exception if any of the underlying helpers fails
     */
    private InitializableAbortableRecipeRunner createRunner(ContainerExecSelection containerSelection, final ContainerExecRuntimeConfig clusterContainerConfig, final File baseFolder, final File rkFolder, final String inputDatasetFullName) throws Exception {
        // The returned handles are not needed here; the calls themselves are kept because they were
        // made originally - presumably they create the job's log dir/file (TODO confirm).
        FlowJobUtils.getJobMadeDir("rag-embedding-recipe", "additional-logs");
        FlowJobUtils.getJobTouchedFile("rag-embedding-recipe", "python.log");

        final Optional<CodeBasedRecipeDatasetInfoHelper.ConnectionLocationInfo> remoteVectorStoreConnectionInfo = this.getRemoteVectorStoreConnectionInfo(this.authCtx, this.rk.connection);
        JobContext.getCurrentActivitySummary().engineType = "DSS";

        return new AbstractPythonRecipeRunner(this.activity){

            /** Grants ACLs on the bank folder, dumps the settings files and executes the Python module. */
            @Override
            public void run() throws Exception {
                FilesystemACLUtils.grantFSFullACLs(this.authCtxService.getAuthCtx(), this.projectKey, baseFolder);
                String envName = new CodeEnvSelector().selectForPythonRecipe(this.recipe.getProjectKey(), RAGEmbeddingRecipeRunner.this.rk.envSelection);
                logger.info((Object)("Run embedding in code env " + StringUtils.defaultIfBlank(envName, "built-in")));
                try (AutoDelete outputTmpDir = FlowJobUtils.getTmpFolder("rag-embedding-recipe", "pyrun")) {
                    // NOTE(review): with a non-null clusterContainerConfig nothing is executed at all;
                    // the only caller in this class always passes null.
                    if (clusterContainerConfig == null) {
                        // The settings are serialized into the version folder before the module is launched.
                        JSON.prettyToFile((Object)RAGEmbeddingRecipeRunner.this.desc, new File(rkFolder, "recipe_settings.json"));
                        JSON.prettyToFile((Object)RAGEmbeddingRecipeRunner.this.rk, new File(rkFolder, "kb.json"));
                        this.executeModule(envName, (File)outputTmpDir, "dataiku.llm.rag.rag_embedding_recipe", false, rkFolder.getAbsolutePath(), inputDatasetFullName);
                    }
                }
            }

            /** Nothing to initialize. */
            @Override
            public void init() throws Exception {
            }

            /** Exposes the remote vector store connection info, when present, as DKU_KB_CONNECTION_INFO (JSON). */
            @Override
            public void enrichModuleProcess(ProcessBuilder builder) {
                remoteVectorStoreConnectionInfo.ifPresent(connectionLocationInfo -> builder.environment().put("DKU_KB_CONNECTION_INFO", JSON.json((Object)connectionLocationInfo)));
            }
        };
    }
}

