/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.agents.tools.llmmesh;

import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolParams;
import com.dataiku.dip.agents.tools.AgentToolRunner;
import com.dataiku.dip.agents.tools.utils.JsonSchema;
import com.dataiku.dip.agents.tools.utils.JsonSchemaElement;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.exceptions.ForbiddenObjectException;
import com.dataiku.dip.futures.FutureProgress;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.llm.LLMAllocationTagsUtils;
import com.dataiku.dip.llm.LLMAuditHelper;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.llm.online.LLMMeshClientFactory;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsageReportingService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.ITaggingService;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.transactions.ifaces.IsolationLevel;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import com.google.common.collect.Lists;
import com.google.gson.JsonElement;
import com.google.gson.JsonSyntaxException;
import java.io.IOException;
import java.lang.invoke.LambdaMetafactory;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

public class LLMMeshLLMQueryTool {
    /**
     * Metadata of the "LLMMeshLLMQuery" agent tool, which sends a question to the LLM Mesh LLM (or saved-model
     * based agent / augmented LLM) referenced by {@link Params#llmId}.
     */
    public static final AgentToolMeta META = new AgentToolMeta(false) {

        @Override
        public String getType() {
            return "LLMMeshLLMQuery";
        }

        @Override
        public Class<? extends AgentToolParams> paramsClass() {
            return Params.class;
        }

        /**
         * Lists the saved model this tool depends on, if the configured LLM is backed by one.
         *
         * @return a mutable list with zero or one dependency
         */
        @Override
        public List<SavedModel.AgentDependency> getDependencies(AgentTool tool) {
            LLMStructuredRef ref = decodeConfiguredLLMRef(tool.getParamsCopyAs(Params.class));
            List<SavedModel.AgentDependency> dependencies = new ArrayList<>();
            if (ref != null && ref.savedModelSmartId != null) {
                dependencies.add(new SavedModel.AgentDependency(ITaggingService.TaggableType.SAVED_MODEL, ref.savedModelSmartId));
            }
            return dependencies;
        }

        /**
         * Fails if the saved model referenced by the tool is not available in the tool's project.
         * A tool without a selected LLM is skipped (with a warning).
         *
         * @throws IllegalArgumentException if an LLM id is set but cannot be decoded
         */
        @Override
        public void checkAccessDependency(AuthCtx authCtx, AgentTool tool) throws IOException, ForbiddenObjectException {
            Params params = tool.getParamsCopyAs(Params.class);
            if (params == null || StringUtils.isBlank(params.llmId)) {
                logger.warn("No LLM selected. Skipping access check to dependency.");
                return;
            }
            LLMStructuredRef llmRef = LLMStructuredRef.decodeId(params.llmId);
            if (llmRef == null) {
                // Fail closed: an id we cannot decode must not silently bypass the access check.
                throw new IllegalArgumentException("Cannot decode LLM id: " + params.llmId);
            }
            if (llmRef.savedModelSmartId != null) {
                AnyLoc llmLoc = AnyLoc.resolveSmart(tool.projectKey, llmRef.savedModelSmartId);
                ((ProjectsService) SpringUtils.getBean(ProjectsService.class)).failIfLocNotAvailableInProject(ITaggingService.TaggableType.SAVED_MODEL, llmLoc, tool.projectKey);
            }
        }

        /** Returns the connection used by the configured LLM, or an empty mutable set if there is none. */
        @Override
        public Set<String> listConnectionNames(AgentTool tool) {
            LLMStructuredRef ref = decodeConfiguredLLMRef(tool.getParamsCopyAs(Params.class));
            if (ref != null && ref.connection != null) {
                return Set.of(ref.connection);
            }
            return new HashSet<>();
        }

        /**
         * Rewrites the connection of the configured LLM id according to {@code replacements}.
         *
         * @return true if the tool params were modified
         */
        @Override
        public boolean remapConnections(AgentTool tool, Map<String, String> replacements) {
            Params p = tool.getParamsCopyAs(Params.class);
            LLMStructuredRef ref = decodeConfiguredLLMRef(p);
            if (ref == null || ref.connection == null) {
                return false;
            }
            String newConnection = replacements.get(ref.connection);
            if (newConnection == null) {
                return false;
            }
            ref.setConnection(newConnection);
            // p is non-null here: decodeConfiguredLLMRef returns null for null params.
            p.llmId = ref.encodeToId();
            tool.setParams(p);
            return true;
        }

        /** Describes the tool to the calling LLM: a single string "question" input. */
        @Override
        public AgentToolMeta.ToolDescriptor getResultingDescriptor(AuthCtx authCtx, String projectKey, AgentTool tool) throws IOException {
            TransactionContext.assertNoAttachedTransaction();
            AgentToolMeta.ToolDescriptor td = new AgentToolMeta.ToolDescriptor(tool.name);
            td.description = "Asks a question to an agent.";
            if (StringUtils.isNotBlank(tool.additionalDescriptionForLLM)) {
                td.description = td.description + "\n\n" + tool.additionalDescriptionForLLM;
            }
            td.inputSchema = JsonSchema.newObject("https://dataiku.com/agents/tools/llm-mesh/llm/query", "Query an LLM using the LLM Mesh");
            td.inputSchema.properties.put("question", JsonSchemaElement.string("the question to ask"));
            return td;
        }

        /**
         * Builds the confirmation message shown before the tool is called, naming the model kind, the model and
         * (if any) the connection.
         *
         * @throws IllegalArgumentException if no decodable LLM id is configured or the saved model type is unsupported
         */
        @Override
        public AgentToolMeta.ToolCallDescription getToolCallDescription_NT(AuthCtx authCtx, String projectKey, AgentTool tool, LLMClient.FunctionTool descriptor, AgentToolRunner.AgentToolInput input) throws Exception {
            TransactionContext.assertNoAttachedTransaction();
            Params p = tool.getParamsCopyAs(Params.class);
            LLMStructuredRef ref = decodeConfiguredLLMRef(p);
            if (ref == null) {
                throw new IllegalArgumentException("No LLM/Agent selected");
            }
            final String modelType;
            final String modelName;
            if (ref.savedModelSmartId != null) {
                SavedModel sm;
                TransactionService transactionService = (TransactionService) SpringUtils.getBean(TransactionService.class);
                AnyLoc modelLoc = AnyLoc.resolveSmart(tool.projectKey, ref.savedModelSmartId);
                try (Transaction readTx = transactionService.beginRead(IsolationLevel.YOLO)) {
                    sm = (SavedModel) ((SavedModelsDAO) SpringUtils.getBean(SavedModelsDAO.class)).getMandatory(modelLoc);
                }
                modelName = sm.name;
                modelType = switch (sm.savedModelType) {
                    case TOOLS_USING_AGENT -> "visual agent";
                    case PYTHON_AGENT -> "python agent";
                    case PLUGIN_AGENT -> "plugin python agent";
                    case RETRIEVAL_AUGMENTED_LLM -> "retrieval-augmented LLM";
                    case LLM_GENERIC -> "fine-tuned model";
                    default -> throw new IllegalArgumentException(String.format("Unexpected model type %s for LLMMeshLLMQueryTool", sm.savedModelType));
                };
            } else {
                modelType = "LLM";
                // Fall back to the raw id rather than rendering the literal "null" when model and deployment are blank.
                modelName = StringUtils.firstNonBlank(ref.model, ref.deployment, p.llmId);
            }
            // NOTE(review): modelName and the connection name are interpolated into HTML markup without escaping;
            // if this text is rendered as HTML, a saved model name containing markup would be injected. Consider escaping.
            StringBuilder description = new StringBuilder(String.format("I'm about to query %s <b>%s</b>", modelType, modelName));
            if (ref.connection != null) {
                description.append(String.format(" from connection <b>%s</b>", ref.connection));
            }
            description.append(" with the following question.\n");
            description.append("\n");
            description.append("Do you want to proceed?");
            return new AgentToolMeta.ToolCallDescription(description.toString());
        }

        @Override
        public AgentToolRunner buildRunner(AuthCtx authCtx, String projectKey, AgentTool tool, boolean devKernel) throws CodedException {
            return new Runner(authCtx, projectKey, tool.projectKey, tool.getParamsCopyAs(Params.class));
        }

        /**
         * Decodes the LLM reference configured on a tool.
         *
         * @return the decoded reference, or {@code null} when params are missing, no LLM id is set, or the id
         *         cannot be decoded
         */
        private LLMStructuredRef decodeConfiguredLLMRef(Params p) {
            if (p == null || StringUtils.isBlank(p.llmId)) {
                return null;
            }
            return LLMStructuredRef.decodeId(p.llmId);
        }
    };
    /** Logger shared by the tool metadata and its runner. */
    private static final DKULogger logger =
            DKULogger.getLogger("dku.agents.tools.llm");

    public static class Runner
    implements AgentToolRunner {
        @Autowired
        private ComputeResourceUsageReportingService cruReportingService;
        @Autowired
        private AuditTrailService auditTrailService;
        private final AuthCtx authCtx;
        private final String contextProjectKey;
        private final String sourceProjectKey;
        private final Params params;

        public Runner(AuthCtx authCtx, String contextProjectKey, String sourceProjectKey, Params p) {
            this.authCtx = authCtx;
            this.contextProjectKey = contextProjectKey;
            this.sourceProjectKey = sourceProjectKey;
            this.params = p;
        }

        @Override
        public void init() throws IOException {
            SpringUtils.getInstance().autowire((Object)this);
        }

        private LLMStructuredRef getLLMRef() {
            if (this.params.llmId == null) {
                throw new IllegalArgumentException("No LLM/Agent selected");
            }
            LLMStructuredRef llmRef = LLMStructuredRef.decodeId(this.params.llmId);
            if (!Objects.equals(this.contextProjectKey, this.sourceProjectKey) && llmRef.savedModelSmartId != null) {
                AnyLoc llmLoc = AnyLoc.resolveSmart(this.sourceProjectKey, llmRef.savedModelSmartId);
                llmRef.savedModelSmartId = llmLoc.getFullName();
                llmRef.id = llmRef.encodeToId();
            }
            return llmRef;
        }

        /*
         * Unable to fully structure code
         */
        @Override
        public AgentToolRunner.AgentToolOutput run(AgentToolRunner.AgentToolInput input) throws Exception {
            LLMMeshLLMQueryTool.logger.debug((Object)("Running with input " + JSON.log((Object)input)));
            question = this.safeReadStringArgument(input, new Object[]{"question"});
            llmRef = this.getLLMRef();
            connectionGuardrailsPipelineSettings = GuardrailsPipelineUtils.getConnectionAndLLMLevelSettings(this.authCtx, this.sourceProjectKey, llmRef);
            usageTimeGuardrailsPipelineSettings = null;
            guardrailsPipelineSettings = GuardrailsPipelineUtils.mergeEnforcementSettings(connectionGuardrailsPipelineSettings, usageTimeGuardrailsPipelineSettings);
            llmMeshClient = LLMMeshClientFactory.get(this.authCtx, this.contextProjectKey, llmRef, guardrailsPipelineSettings, null, 1);
            try {
                block40: {
                    block41: {
                        enrichedLLMRef = llmMeshClient.getEnrichedRef();
                        sm = null;
                        query = new LLMClient.SingleCompletionQuery();
                        if (this.params.forwardContext.booleanValue()) {
                            query.context = input.context;
                        }
                        if (StringUtils.isNotBlank((CharSequence)this.params.systemPromptPrepend)) {
                            query.messages.add(new LLMClient.ChatMessage("system", this.params.systemPromptPrepend));
                        }
                        query.messages.add(new LLMClient.ChatMessage("user", question));
                        restoredPartialOutput = null;
                        restoredStashedSources = null;
                        if (input.memoryFragment == null) break block40;
                        if (enrichedLLMRef.type != LLMStructuredRef.LLMType.SAVED_MODEL_AGENT) {
                            throw new IllegalArgumentException("Memory fragments are only supported for agents");
                        }
                        if (input.memoryFragment.messages == null || input.memoryFragment.messages.isEmpty() || input.memoryFragment.messages.size() > 2) {
                            throw new IllegalArgumentException("Invalid memory fragment structure");
                        }
                        wrappedPartialOutput = null;
                        wrappedMemoryFragment = null;
                        for (LLMClient.ChatMessage wrapped : input.memoryFragment.messages) {
                            if (wrapped.memoryFragmentTarget == null) {
                                wrappedPartialOutput = wrapped;
                                continue;
                            }
                            wrappedMemoryFragment = wrapped;
                        }
                        if (wrappedPartialOutput != null) {
                            if (!wrappedPartialOutput.role.equals("assistant")) {
                                throw new IllegalArgumentException(String.format("Expected the memory fragment to contain a message with role 'assistant', got %s", new Object[]{wrappedPartialOutput.role}));
                            }
                            restoredPartialOutput = wrappedPartialOutput.getText();
                        }
                        if (wrappedMemoryFragment == null) break block41;
                        transactionService = (TransactionService)SpringUtils.getBean(TransactionService.class);
                        sm = LLMClientFactory.loadSavedModelContext(this.contextProjectKey, llmRef, transactionService).sm();
                        var18_23 = wrappedMemoryFragment.memoryFragmentTarget;
                        if (!(var18_23 instanceof LLMClient.AgentHierarchyEntry)) ** GOTO lbl-1000
                        target = (LLMClient.AgentHierarchyEntry)var18_23;
                        if (Objects.equals(target.agentId, sm.id)) {
                            v0 = true;
                        } else lbl-1000:
                        // 2 sources

                        {
                            v0 = validTarget = false;
                        }
                        if (!validTarget) {
                            throw new IllegalArgumentException("Incorrect nested memory fragment target");
                        }
                        memoryFragmentMessage = new LLMClient.ChatMessage();
                        memoryFragmentMessage.role = "memoryFragment";
                        memoryFragmentMessage.memoryFragment = wrappedMemoryFragment.memoryFragment;
                        query.messages.add(memoryFragmentMessage);
                    }
                    restoredStashedSources = input.memoryFragment.stashedSources;
                }
                if (input.toolValidationResponses != null && !input.toolValidationResponses.isEmpty()) {
                    if (enrichedLLMRef.type != LLMStructuredRef.LLMType.SAVED_MODEL_AGENT) {
                        throw new IllegalArgumentException("Tool validation requests are only supported for agents");
                    }
                    toolValidationRequestsMessage = new LLMClient.ChatMessage();
                    toolValidationRequestsMessage.role = "toolValidationRequests";
                    toolValidationRequestsMessage.toolValidationRequests = input.toolValidationRequests;
                    query.messages.add(toolValidationRequestsMessage);
                    toolValidationResponsesMessage = new LLMClient.ChatMessage();
                    toolValidationResponsesMessage.role = "toolValidationResponses";
                    toolValidationResponsesMessage.toolValidationResponses = input.toolValidationResponses;
                    query.messages.add(toolValidationResponsesMessage);
                }
                responses = null;
                _ignored = FutureProgress.pushAutoCloseableState((String)"Querying LLM", (double)1.0, (FutureProgressState.StateUnit)FutureProgressState.StateUnit.RECORDS);
                try {
                    LLMAuditHelper.emitToolValidationAuditsIfNeeded(this.auditTrailService, enrichedLLMRef, llmMeshClient.getConnection(), query);
                    responses = llmMeshClient.completeQueries(Lists.newArrayList((Object[])new LLMClient.SingleCompletionQuery[]{query}), this.params.completionSettings);
                }
                finally {
                    if (_ignored != null) {
                        _ignored.close();
                    }
                }
                cru = llmMeshClient.getTotalCRU(ComputeResourceUsage.LLMUsageType.COMPLETION);
                if (cru != null) {
                    LLMAllocationTagsUtils.addAllocationTagsToCRU(query, cru);
                    this.cruReportingService.reportComplete(cru);
                }
                if (!Runner.$assertionsDisabled && responses.size() != 1) {
                    throw new AssertionError();
                }
                resp = responses.get(0);
                if (!resp.ok) {
                    throw new RuntimeException("LLM query failed: " + resp.errorMessage);
                }
                o = new AgentToolRunner.AgentToolOutput();
                fullOutput = null;
                if (restoredPartialOutput != null) {
                    fullOutput = restoredPartialOutput;
                }
                if (resp.text != null) {
                    fullOutput = fullOutput != null ? (String)fullOutput + resp.text : resp.text;
                }
                parsedSources = new ArrayList<E>();
                if (this.params.returnSources.booleanValue() && resp.additionalInformation != null && (sourcesObject = resp.additionalInformation.get("sources")) != null && sourcesObject.isJsonArray()) {
                    tempParsedSources = new ArrayList<E>();
                    sourcesObject.getAsJsonArray().forEach((Consumer<JsonElement>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;)V, lambda$run$0(java.util.List com.google.gson.JsonElement ), (Lcom/google/gson/JsonElement;)V)(tempParsedSources));
                    parsedSources = tempParsedSources;
                }
                allSources = new ArrayList<E>();
                if (restoredStashedSources != null) {
                    allSources = restoredStashedSources;
                }
                allSources.addAll(parsedSources);
                o.output = JF.obj().with("response", (String)fullOutput).get();
                o.sources = allSources;
                if (this.params.returnArtifacts.booleanValue() && resp.artifacts != null) {
                    o.artifacts = resp.artifacts;
                }
                if (resp.toolValidationRequests != null) {
                    if (enrichedLLMRef.type != LLMStructuredRef.LLMType.SAVED_MODEL_AGENT) {
                        throw new IllegalArgumentException("Tool validation requests are only supported for agents");
                    }
                    o.toolValidationRequests = resp.toolValidationRequests;
                    if (resp.memoryFragment != null) {
                        if (sm == null) {
                            transactionService = (TransactionService)SpringUtils.getBean(TransactionService.class);
                            sm = LLMClientFactory.loadSavedModelContext(this.contextProjectKey, llmRef, transactionService).sm();
                        }
                        target = new LLMClient.AgentHierarchyEntry();
                        target.agentName = sm.name;
                        target.agentId = sm.id;
                        o.createMemoryFragmentIfNotExists();
                        wrappedMemoryFragment = LLMClient.ChatMessage.wrapMemoryFragment(resp.memoryFragment, target);
                        o.memoryFragment.messages.add(wrappedMemoryFragment);
                    }
                    if (StringUtils.isNotBlank((CharSequence)resp.text)) {
                        o.createMemoryFragmentIfNotExists();
                        wrappedPartialOutput = new LLMClient.ChatMessage("assistant", resp.text);
                        o.memoryFragment.messages.add(wrappedPartialOutput);
                    }
                    if (!allSources.isEmpty()) {
                        o.createMemoryFragmentIfNotExists();
                        o.memoryFragment.stashedSources = allSources;
                    }
                }
                o.trace = resp.trace;
                LLMAuditHelper.emitLLMCompletionAuditFromBackendIfNeeded(this.auditTrailService, enrichedLLMRef, llmMeshClient.getConnection(), query, resp);
                var20_26 = o;
                return var20_26;
            }
            finally {
                if (llmMeshClient != null) {
                    llmMeshClient.close();
                }
            }
        }

        @Override
        public void close() throws Exception {
        }

        private static /* synthetic */ void lambda$run$0(List tempParsedSources, JsonElement jsonSource) {
            try {
                tempParsedSources.add((AgentToolRunner.Source)JSON.parse((String)jsonSource.toString(), AgentToolRunner.Source.class));
            }
            catch (JsonSyntaxException e) {
                logger.warn((Object)"Failed to parse source from LLM call", (Throwable)e);
            }
        }
    }

    /**
     * User-configurable settings of the "LLMMeshLLMQuery" tool.
     * Field order is kept stable since it may drive serialized output order.
     */
    public static class Params
    implements AgentToolParams {
        /** LLM Mesh id of the LLM or agent to query (decoded with {@code LLMStructuredRef.decodeId}). */
        public String llmId;
        /** Optional system prompt; when not blank it is sent as a system message before the question. */
        public String systemPromptPrepend;
        /** Completion settings sent along with the query. */
        public LLMClient.CompletionSettings completionSettings = new LLMClient.CompletionSettings();
        /** Whether the tool input's context is copied into the completion query. */
        public Boolean forwardContext = Boolean.FALSE;
        /** Whether artifacts returned by the LLM are included in the tool output. */
        public Boolean returnArtifacts = Boolean.FALSE;
        /** Whether "sources" found in the response's additional information are returned. */
        public Boolean returnSources = Boolean.FALSE;
    }
}

