/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.agents.tools.vectorstore;

import com.dataiku.dip.agents.tools.AgentToolRunner;
import com.dataiku.dip.agents.tools.filtering.SimpleFilter;
import com.dataiku.dip.agents.tools.vectorstore.VectorStoreQueryTool;
import com.dataiku.dip.agents.tools.vectorstore.VectorStoreQueryToolServerAPI;
import com.dataiku.dip.code.CodeEnvSelection;
import com.dataiku.dip.code.CodeEnvSelector;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.dataflow.exec.CodeBasedRecipeDatasetInfoHelper;
import com.dataiku.dip.io.SimplePythonKernelFactory;
import com.dataiku.dip.llm.LLMRelatedPoolablePythonServer;
import com.dataiku.dip.llm.retrieval.LangchainBasedRAGServer;
import com.dataiku.dip.llm.retrieval.RetrievableKnowledge;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.util.JsonUtils;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.SmartLogTail;
import com.dataiku.dss.shadelib.com.google.common.base.Strings;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.apache.log4j.Logger;

/**
 * Agent tool server that answers vector-store (Knowledge Bank) queries by delegating to a
 * pooled Python kernel running {@code dataiku.llm.agent_tools.vector_store_query_tool_server}.
 *
 * <p>{@link #init} starts the kernel and sends it a {@link StartCommand}. Each
 * {@link #runAsync} call then sends a {@link RunToolCommand} and completes with the last
 * {@link AgentToolRunner.AgentToolOutput} streamed back by the kernel.
 */
public class VectorStoreQueryToolServer
extends LLMRelatedPoolablePythonServer
implements VectorStoreQueryToolServerAPI {
    protected final DSSAuthCtx authCtx;
    protected final String projectKey;
    protected final RetrievableKnowledge rk;
    /** Name of the code env selected for the Python kernel at construction time. */
    protected final String envName;
    protected VectorStoreQueryTool.VectorStoreQueryToolParams params;
    public String containerConfName;
    public String clusterId;
    private static final Logger logger = Logger.getLogger("dku.agents.tools.vectorstore");

    /**
     * Creates a server for the given Knowledge Bank. The Python kernel is not started until
     * {@link #init} is called.
     *
     * @param authCtx caller auth context; must be a {@link DSSAuthCtx}
     * @param projectKey project in which the tool runs (used for code env selection and kernel setup)
     * @param rk the Knowledge Bank to query
     * @param params tool settings (security, filtering options, ...)
     * @param codeEnvSelection code env selection resolved into {@link #envName}
     * @param containerConfName container configuration passed to the kernel factory
     * @param clusterId cluster identifier, stored in the public {@link #clusterId} field
     * @throws ClassCastException if {@code authCtx} is not a {@link DSSAuthCtx}
     * @throws UncheckedIOException if the code env cannot be selected
     */
    public VectorStoreQueryToolServer(AuthCtx authCtx, String projectKey, RetrievableKnowledge rk, VectorStoreQueryTool.VectorStoreQueryToolParams params, CodeEnvSelection codeEnvSelection, String containerConfName, String clusterId) {
        super("dku.agents.tools.vectorstore", "vstool-query-" + rk.projectKey + "-" + rk.id + "-" + SecretKeyGenerator.generateSmall());
        this.authCtx = (DSSAuthCtx)authCtx;
        this.projectKey = projectKey;
        this.rk = rk;
        this.params = params;
        this.containerConfName = containerConfName;
        // The cluster id is accepted as an argument, so keep it on the exposed field.
        this.clusterId = clusterId;
        try {
            this.envName = new CodeEnvSelector().selectForPythonRecipe(projectKey, codeEnvSelection);
        }
        catch (IOException e) {
            throw new UncheckedIOException("Failed to select code env for project " + projectKey, e);
        }
    }

    /**
     * Runs one search against the Knowledge Bank.
     *
     * <p>Builds a {@link RunToolCommand} from the tool input ({@code searchQuery}, optional
     * {@code filter}) and the caller context ({@code callerSecurityTokens}, {@code callerFilters}).
     * Which of these are used depends on the tool params.
     *
     * @param input tool input; {@code input.input} must be a JSON object with a {@code searchQuery}
     * @return a future completed with the last output streamed by the kernel, or an immediately
     *     completed empty result when document-level security is enforced and the caller has no
     *     security tokens
     * @throws IllegalArgumentException if {@code searchQuery} is missing or null
     * @throws RuntimeException if document-level security is enforced and no tokens were provided
     */
    @Override
    public CompletableFuture<AgentToolRunner.AgentToolOutput> runAsync(AgentToolRunner.AgentToolInput input) {
        JsonObject toolInput = input.input.getAsJsonObject();
        JsonElement searchQueryElt = toolInput.get("searchQuery");
        if (searchQueryElt == null || searchQueryElt.isJsonNull()) {
            throw new IllegalArgumentException("Unable to process query: the tool input must contain a 'searchQuery'.");
        }
        RunToolCommand req = new RunToolCommand();
        req.searchQuery = searchQueryElt.getAsString();
        req.params = this.params;
        if (this.params.enforceDocumentLevelSecurity) {
            JsonElement securityTokensElt = JsonUtils.getOrNull((JsonElement)input.context, "callerSecurityTokens");
            // An explicit JSON null is treated like missing tokens rather than crashing below.
            if (securityTokensElt == null || securityTokensElt.isJsonNull()) {
                throw new RuntimeException("Unable to process query: The document-level security requires valid security tokens. Please ensure your security tokens are included with your query context.");
            }
            req.securityTokens = (List)JSON.parse((String)JSON.json((Object)securityTokensElt), JSON.StringList.class);
            if (req.securityTokens.isEmpty()) {
                // With no tokens nothing is visible: answer directly without calling the kernel.
                return CompletableFuture.completedFuture(emptyDocumentsOutput());
            }
        }
        if (this.params.filter != null && this.params.performFiltering) {
            req.filter = SimpleFilter.fromComplexFilter(this.params.filter, Optional.empty());
        }
        if (this.params.allowDynamicFiltering) {
            JsonArray callerFilters = JsonUtils.getOrEmptyArr(input.context, "callerFilters");
            logger.info("Adding user defined filters: " + JSON.json((Object)callerFilters));
            // Copy into a list we own so an agent-inferred filter can always be appended below.
            req.callerFilters = new ArrayList<LangchainBasedRAGServer.RagQueryFilter>((List)JSON.parse((String)JSON.json((Object)callerFilters), List.class));
        }
        if (this.params.allowAgentInferredFiltering) {
            JsonElement jsonFilter = toolInput.get("filter");
            // An LLM may emit "filter": null; treat that the same as no filter.
            if (jsonFilter != null && !jsonFilter.isJsonNull()) {
                SimpleFilter simpleFilter = (SimpleFilter)JSON.parse((JsonElement)jsonFilter, SimpleFilter.class);
                logger.info("Adding agent-inferred filter: " + JSON.json((Object)jsonFilter));
                LangchainBasedRAGServer.RagQueryFilter rqf = new LangchainBasedRAGServer.RagQueryFilter();
                rqf.filter = simpleFilter;
                if (req.callerFilters == null) {
                    req.callerFilters = new ArrayList<LangchainBasedRAGServer.RagQueryFilter>();
                }
                req.callerFilters.add(rqf);
            }
        }
        return this.kernel.getLink().getAsyncLink().asyncStreamRequest((Object)req, AgentToolRunner.AgentToolOutput.class).last().doOnNext(scr -> {
            logger.info("Got response: " + JSON.json((Object)scr));
            // Only a single returned source gets the human-readable call description.
            if (scr.sources != null && scr.sources.size() == 1) {
                scr.sources.get(0).toolCallDescription = "Searched Knowledge Bank " + this.rk.name + " for '" + req.searchQuery + "'";
            }
        }).toFuture();
    }

    /** Builds a tool output whose payload is {@code {"documents": []}}. */
    private static AgentToolRunner.AgentToolOutput emptyDocumentsOutput() {
        AgentToolRunner.AgentToolOutput toolOutput = new AgentToolRunner.AgentToolOutput();
        JsonObject searchOutput = new JsonObject();
        searchOutput.add("documents", new JsonArray());
        toolOutput.output = searchOutput;
        return toolOutput;
    }

    /**
     * Prepares and starts the Python kernel, then sends it the {@link StartCommand}.
     *
     * <p>The kernel environment receives {@code DKU_CURRENT_PROJECT_KEY}, and also
     * {@code DKU_KB_CONNECTION_INFO} when the Knowledge Bank has a connection.
     *
     * @param devKernel forwarded to the log setup
     * @param devMode forwarded to the log setup
     * @param logBaseDir base directory for the rotated kernel logs
     * @param toolSmartId identifier of the tool; its full name is sent as {@code toolRef}
     * @throws Exception if kernel preparation, startup or the start request fails
     */
    public void init(boolean devKernel, boolean devMode, File logBaseDir, AnyLoc toolSmartId) throws Exception {
        this.kernel = SimplePythonKernelFactory.prepareKernel(this.authCtx, this.projectKey, GeneralSettingsDAO.CGrouppableProcessType.ML_KERNEL, this.envName, "dataiku.llm.agent_tools.vector_store_query_tool_server", false, this.containerConfName, this.kernelId);
        StartCommand command = new StartCommand();
        command.knowledgeBankFullId = this.rk.getFullId();
        command.toolRef = toolSmartId.getFullName();
        this.setLogsWithRotation(logBaseDir, devKernel, devMode);
        HashMap<String, String> extraEnv = new HashMap<String, String>();
        if (!Strings.isNullOrEmpty(this.rk.connection)) {
            extraEnv.put("DKU_KB_CONNECTION_INFO", JSON.json((Object)new CodeBasedRecipeDatasetInfoHelper().getConnectionInfoUnsafe_NT(this.authCtx, this.rk.connection, this.projectKey)));
        }
        extraEnv.put("DKU_CURRENT_PROJECT_KEY", this.projectKey);
        this.kernel.withExtraEnv(extraEnv);
        this.kernel.start();
        logger.info("Sending start command to server: " + JSON.json((Object)command));
        this.kernel.getLink().getAsyncLink().request((Object)command, JsonElement.class);
    }

    /** Returns a tail of the kernel's log. */
    @Override
    public SmartLogTail getKernelLog() {
        return this.kernel.getSmartLogTailBuilder().get();
    }

    /** Request sent to the kernel for each tool invocation (discriminated by {@code type}). */
    public static class RunToolCommand {
        public final String type = "run-tool";
        public String searchQuery;
        public VectorStoreQueryTool.VectorStoreQueryToolParams params;
        /** Caller security tokens; only set when document-level security is enforced. */
        public List<String> securityTokens;
        /** Caller-provided and/or agent-inferred filters; null when neither applies. */
        public List<LangchainBasedRAGServer.RagQueryFilter> callerFilters;
        /** Static filter from the tool params; only set when filtering is enabled. */
        public SimpleFilter filter;
    }

    /** Initial request sent to the kernel once it has started. */
    public static class StartCommand {
        public final String type = "start";
        public String knowledgeBankFullId;
        public String toolRef;
    }
}

