/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.server.llm;

import com.dataiku.common.server.SerializedError;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.DSSMetrics;
import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolRunner;
import com.dataiku.dip.agents.tools.AgentToolsCRUDService;
import com.dataiku.dip.agents.tools.AgentToolsDAO;
import com.dataiku.dip.agents.tools.AgentToolsRegistry;
import com.dataiku.dip.coremodel.VersionTag;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.exceptions.ExceptionWithLogTail;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.scheduler.reports.TemplatesDAO;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.security.auth.UIAuthService;
import com.dataiku.dip.server.controllers.AuditInline;
import com.dataiku.dip.server.controllers.AuditNotNeeded;
import com.dataiku.dip.server.controllers.AuditedCall;
import com.dataiku.dip.server.controllers.DIPInternalControllerBase;
import com.dataiku.dip.server.services.ConflictCheckService;
import com.dataiku.dip.server.services.ITaggingService;
import com.dataiku.dip.server.services.InterestsService;
import com.dataiku.dip.server.services.NavigatorService;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TaggableObjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.server.services.licensing.LicenseEnforcementService;
import com.dataiku.dip.transactions.ifaces.RWTransaction;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.util.Id;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.SmartLogTail;
import com.google.gson.JsonObject;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.ResponseBody;

@Controller
public class AgentToolsController extends DIPInternalControllerBase {
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private AgentToolsDAO agentToolsDAO;
    @Autowired
    private AgentToolsCRUDService crudService;
    @Autowired
    private ProjectsService projectsService;
    @Autowired
    private AuditTrailService auditTrailService;
    @Autowired
    private UIAuthService authService;
    @Autowired
    private InterestsService interestsService;
    @Autowired
    private NavigatorService navigatorService;
    @Autowired
    private TaggableObjectsService taggableObjectsService;
    @Autowired
    private LicenseEnforcementService licenseEnforcementService;
    @Autowired
    private ConflictCheckService conflictCheckService;
    @Autowired
    private SavedModelsDAO savedModelsDAO;

    /**
     * Lists the full definitions of all agent tools of a project.
     *
     * <p>Requires the READ_CONF privilege on the project.
     *
     * @param req        current request, used to resolve the calling user
     * @param projectKey key of the project whose agent tools are listed
     * @return the agent tools returned by {@code AgentToolsDAO.listUnsafe}
     */
    @AuditedCall(value={"msgType", "agent-tools-list", "projectKey", "${projectKey}"})
    @RequestMapping(value={"/api/agent-tools/list"})
    @ResponseBody
    public List<AgentTool> list(HttpServletRequest req, @RequestParam String projectKey) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            return this.agentToolsDAO.listUnsafe(projectKey);
        }
    }

    /**
     * Lists lightweight list items for all agent tools of a project, enriched with the
     * calling user's interest information.
     *
     * <p>Requires the READ_DASHBOARDS privilege on the project. The interest enrichment runs
     * after the read transaction has been closed.
     *
     * @param req        current request, used to resolve the calling user
     * @param projectKey key of the project whose agent tools are listed
     * @return one list item per agent tool
     */
    @AuditedCall(value={"msgType", "agent-tools-list", "projectKey", "${projectKey}"})
    @RequestMapping(value={"/api/agent-tools/list-heads"})
    @ResponseBody
    public List<AgentTool.AgentToolListItem> listHeads(HttpServletRequest req, @RequestParam String projectKey) throws Exception {
        AuthCtx authCtx;
        List<AgentTool.AgentToolListItem> heads;
        try (Transaction t = this.transactionService.beginRead()) {
            authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.READ_DASHBOARDS);
            heads = this.buildListItems(projectKey);
        }
        this.interestsService.enrichListItems(authCtx.getAssociatedDSSUser(), projectKey, heads);
        return heads;
    }

    /**
     * Lists lightweight list items for all agent tools of a project, without interest enrichment.
     *
     * <p>Requires an authenticated user with the READ_DASHBOARDS privilege on the project.
     *
     * @param req        current request, used to resolve the calling user
     * @param projectKey key of the project whose agent tools are listed
     * @return one list item per agent tool
     */
    @AuditedCall(value={"msgType", "agent-tools-list", "projectKey", "${projectKey}"})
    @RequestMapping(value={"/api/agent-tools/list-available"})
    @ResponseBody
    public List<AgentTool.AgentToolListItem> listAvailable(HttpServletRequest req, @RequestParam String projectKey) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            // Called for its check only: fails when no user is attached to the request.
            this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.READ_DASHBOARDS);
            return this.buildListItems(projectKey);
        }
    }

    /**
     * Builds one list item per agent tool of the project and fills its edition info from the
     * tool's tags. Both callers invoke it while a read transaction is open.
     *
     * @param projectKey key of the project whose agent tools are listed
     * @return a new mutable list of list items, in DAO order
     */
    private List<AgentTool.AgentToolListItem> buildListItems(String projectKey) throws Exception {
        List<AgentTool> tools = this.agentToolsDAO.listUnsafe(projectKey);
        List<AgentTool.AgentToolListItem> items = new ArrayList<>(tools.size());
        for (AgentTool tool : tools) {
            AgentTool.AgentToolListItem item = new AgentTool.AgentToolListItem(tool);
            this.taggableObjectsService.setEditionInfoFromTags(tool, item);
            items.add(item);
        }
        return items;
    }

    /**
     * Returns the definition of a single agent tool.
     *
     * <p>Requires the READ_CONF privilege on the project.
     *
     * @param req        current request, used to resolve the calling user
     * @param projectKey key of the project owning the tool
     * @param id         id of the agent tool
     * @return the agent tool as returned by {@code AgentToolsCRUDService.getMandatoryUnsafe}
     */
    @AuditedCall(value={"msgType", "agent-tool-get", "projectKey", "${projectKey}", "id", "${id}"})
    @RequestMapping(value={"/api/agent-tools/get"}, method={RequestMethod.GET})
    @ResponseBody
    public AgentTool getAgentTool(HttpServletRequest req, @RequestParam String projectKey, @RequestParam String id) throws IOException, DKUSecurityException {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            return this.crudService.getMandatoryUnsafe(projectKey, id);
        }
    }

    /**
     * Returns the tool descriptor computed by the {@link AgentToolMeta} registered for the
     * tool's type.
     *
     * <p>Requires the READ_CONF privilege. The descriptor is computed after the read
     * transaction has been closed.
     *
     * @param req        current request, used to resolve the calling user
     * @param projectKey key of the project owning the tool
     * @param id         id of the agent tool
     * @return the descriptor from {@code AgentToolMeta.getResultingDescriptor}
     */
    @AuditedCall(value={"msgType", "agent-tool-get-descriptor", "projectKey", "${projectKey}", "id", "${id}"})
    @RequestMapping(value={"/api/agent-tools/get-descriptor"}, method={RequestMethod.GET})
    @ResponseBody
    public AgentToolMeta.ToolDescriptor getAgentToolDescriptor(HttpServletRequest req, @RequestParam String projectKey, @RequestParam String id) throws Exception {
        AgentTool tool;
        AuthCtx authCtx;
        try (Transaction t = this.transactionService.beginRead()) {
            authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            tool = this.crudService.getMandatoryUnsafe(projectKey, id);
        }
        AgentToolMeta meta = AgentToolsRegistry.getMeta(tool.type);
        return meta.getResultingDescriptor(authCtx, projectKey, tool);
    }

    /**
     * Reports whether the tool's type supports loading a sample query, as declared by
     * {@code AgentToolMeta.supportsLoadSampleQuery}.
     *
     * <p>Requires the READ_CONF privilege on the project.
     *
     * @param req        current request, used to resolve the calling user
     * @param resp       current response (unused)
     * @param projectKey key of the project owning the tool
     * @param id         id of the agent tool
     * @return the meta's {@code supportsLoadSampleQuery} flag
     */
    @AuditNotNeeded
    @RequestMapping(value={"/api/agent-tools/has-sample-query"}, method={RequestMethod.GET})
    @ResponseBody
    public boolean hasSampleQuery(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String id) throws Exception {
        AgentTool tool;
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            tool = this.crudService.getMandatoryUnsafe(projectKey, id);
        }
        return AgentToolsRegistry.getMeta(tool.type).supportsLoadSampleQuery;
    }

    /**
     * Loads a sample query for the tool through its type's {@link AgentToolMeta}.
     *
     * <p>Requires the READ_CONF privilege. The query is loaded after the read transaction has
     * been closed.
     *
     * <p>NOTE(review): unlike the other {@code @AuditedCall}s in this controller, this one
     * declares no {@code msgType}. Confirm whether that is intentional.
     *
     * @param req        current request, used to resolve the calling user
     * @param resp       current response (unused)
     * @param projectKey key of the project owning the tool
     * @param id         id of the agent tool
     * @return the result of {@code AgentToolMeta.loadSampleQuery}
     */
    @AuditedCall(value={"projectKey", "${projectKey}", "id", "${id}"})
    @RequestMapping(value={"/api/agent-tools/load-sample-query"}, method={RequestMethod.GET})
    @ResponseBody
    public JsonObject loadSampleQuery(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String id) throws Exception {
        AgentTool tool;
        AuthCtx authCtx;
        try (Transaction t = this.transactionService.beginRead()) {
            authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            tool = this.crudService.getMandatoryUnsafe(projectKey, id);
        }
        AgentToolMeta meta = AgentToolsRegistry.getMeta(tool.type);
        return meta.loadSampleQuery(authCtx, projectKey, tool);
    }

    /**
     * Returns the navigator "full info" for an agent tool. Extra info from
     * {@code NavigatorService.addInfo_NT} is added after the read transaction has been closed.
     *
     * <p>NOTE(review): the access check uses {@code TaggableType.PROMPT_STUDIO} for an agent
     * tool. Confirm this is the intended taggable type.
     *
     * @param req        current request, used to resolve the calling user
     * @param projectKey key of the project owning the tool
     * @param id         id of the agent tool, resolved with {@code AnyLoc.resolveSmart}
     * @return the full-info object
     */
    @AuditedCall(value={"msgType", "agent-tool-get-metadata", "projectKey", "${projectKey}", "agentToolId", "${id}"})
    @RequestMapping(value={"/api/agent-tools/get-full-info"})
    @ResponseBody
    public NavigatorService.AgentToolFullInfo getFullInfo(HttpServletRequest req, @RequestParam String projectKey, @RequestParam String id) throws Exception {
        NavigatorService.AgentToolFullInfo info;
        AuthCtx u;
        try (Transaction t = this.transactionService.beginRead()) {
            u = this.authService.getMandatoryUser(req);
            AnyLoc loc = AnyLoc.resolveSmart(projectKey, id);
            this.projectsService.failIfNoTaggableObjectReadUseAccess(u, ITaggingService.TaggableType.PROMPT_STUDIO, loc, projectKey);
            info = this.navigatorService.getAgentToolFullInfo(projectKey, id);
        }
        this.navigatorService.addInfo_NT(info, u);
        return info;
    }

    /**
     * Lists the saved models of type {@code TOOLS_USING_AGENT} in the project whose
     * <em>active</em> version references this agent tool.
     *
     * <p>Requires the READ_CONF privilege on the project.
     *
     * @param req        current request, used to resolve the calling user
     * @param projectKey key of the project owning the tool
     * @param id         id of the agent tool
     * @return the usage, with {@code agents} never null
     */
    @AuditedCall(value={"msgType", "agent-tool-get-usage", "projectKey", "${projectKey}", "agentToolId", "${id}"})
    @RequestMapping(value={"/api/agent-tools/get-usage"})
    @ResponseBody
    public AgentToolUsage getUsage(HttpServletRequest req, @RequestParam String projectKey, @RequestParam String id) throws Exception {
        AgentTool tool;
        List<SavedModel> visualAgents;
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            tool = this.crudService.getMandatory(projectKey, id);
            visualAgents = this.savedModelsDAO.listUnsafe(projectKey).stream()
                    .filter(sm -> sm.savedModelType == SavedModel.SavedModelType.TOOLS_USING_AGENT)
                    .toList();
        }
        AgentToolUsage usage = new AgentToolUsage();
        usage.agents = new ArrayList<SavedModel>();
        for (SavedModel agent : visualAgents) {
            if (AgentToolsController.activeVersionUsesTool(agent, tool.id)) {
                usage.agents.add(agent);
            }
        }
        return usage;
    }

    /**
     * Tells whether the active inline version of {@code agent} lists a tool whose
     * {@code toolRef} equals {@code toolId}. A missing active version, missing settings,
     * or a missing tool list counts as "not used".
     */
    private static boolean activeVersionUsesTool(SavedModel agent, String toolId) {
        SavedModel.SavedModelInlineVersion activeVersion = agent.getActiveSaveModelInlineVersion();
        if (activeVersion == null
                || activeVersion.toolsUsingAgentSettings == null
                || activeVersion.toolsUsingAgentSettings.tools == null) {
            return false;
        }
        for (SavedModel.UsedTool usedTool : activeVersion.toolsUsingAgentSettings.tools) {
            if (usedTool != null && StringUtils.equals((String) usedTool.toolRef, toolId)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Runs an agent tool once on a user-supplied query and returns its output, trace and log.
     *
     * <p>Requires the WRITE_CONF privilege. Tool errors do not propagate: they are logged and
     * returned in {@code error}, with the log tail attached when the exception is an
     * {@link ExceptionWithLogTail}. An {@code agent-tool-run} audit record is emitted in every
     * case, with {@code outcome} set to {@code success} or {@code failure}.
     *
     * @param req        current request, used to resolve the calling user
     * @param resp       current response (unused)
     * @param projectKey key of the project owning the tool
     * @param id         id of the agent tool
     * @param query      JSON-serialized {@code AgentToolRunner.AgentToolInput}
     * @return the quick-test response
     * @throws Exception on authentication, permission or lookup failure, or if {@code query}
     *                   cannot be parsed
     */
    @AuditedCall(value={"msgType", "agent-tool-test", "projectKey", "${projectKey}", "agentToolId", "${id}"})
    @RequestMapping(value={"/api/agent-tools/test"})
    @ResponseBody
    public ToolQuickTestResponse toolQuickTest(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String id, @RequestParam String query) throws Exception {
        AgentTool tool;
        AuthCtx authCtx;
        try (Transaction t = this.transactionService.beginRead()) {
            authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            tool = this.crudService.getMandatory(projectKey, id);
        }
        ToolQuickTestResponse ret = new ToolQuickTestResponse();
        AgentToolRunner.AgentToolInput input = (AgentToolRunner.AgentToolInput) JSON.parse(query, AgentToolRunner.AgentToolInput.class);
        GeneralSettingsDAO.GeneralSettings gs = ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN();
        AuditTrailService.EmittableAuditObj auditObj = this.auditTrailService.generic("agent-tool-run");
        auditObj.with("projectKey", projectKey).with("agentToolId", id);
        try (LLMClient.LLMMeshTraceSpan span = LLMClient.LLMMeshTraceSpan.start("DKU_MANAGED_TOOL_CALL")) {
            AgentToolMeta meta = AgentToolsRegistry.getMeta(tool.type);
            try (AgentToolRunner runner = meta.buildRunner(authCtx, projectKey, tool, true)) {
                SimpleDateFormat sdf = new SimpleDateFormat("[yyyy/MM/dd-HH:mm:ss.SSS] ");
                SmartLogTail slt = new SmartLogTail();
                slt.appendLine(sdf.format(new Date()) + "Initializing tool runner");

                long beforeInit = System.currentTimeMillis();
                try (DSSMetrics.TimeCtx tctx = DSSMetrics.timeCtx("dku.agents.tools.invoke.initRunner.byType." + tool.type)) {
                    runner.init();
                }
                long afterInit = System.currentTimeMillis();
                long initTimeMS = afterInit - beforeInit;
                slt.appendLine(sdf.format(new Date()) + "Runner initialized in " + initTimeMS + "ms");

                logger.debug("Tool input " + JSON.pretty((Object) input));
                if (gs.generativeAISettings.agentsToolsSettings.auditToolsInputs) {
                    auditObj.with("toolInput", JSON.toJsonObject((Object) input.input));
                }
                auditObj.with("initTimeMS", (Number) initTimeMS);
                slt.appendLine(sdf.format(new Date()) + "Running tool with input " + JSON.pretty((Object) input));

                AgentToolRunner.AgentToolOutput output;
                try (DSSMetrics.TimeCtx tctx = DSSMetrics.timeCtx("dku.agents.tools.invoke.run.byType." + tool.type)) {
                    output = runner.run(input);
                }
                logger.debug("Tool output " + JSON.pretty((Object) output));
                long runTimeMS = System.currentTimeMillis() - afterInit;
                auditObj.with("outcome", "success").with("runTimeMS", (Number) runTimeMS);
                if (gs.generativeAISettings.agentsToolsSettings.auditToolsOutputs) {
                    auditObj.with("toolOutput", JSON.toJsonObject((Object) output));
                }

                // Nest the runner's own trace (if any) under the call span. The complete trace is
                // returned in fullTrace, so the per-output trace is cleared from the response.
                if (output.trace != null) {
                    span.children.add(output.trace);
                }
                ret.fullTrace = span;
                output.trace = null;
                ret.response = output;

                SmartLogTail kernelSlt = runner.getKernelLog();
                if (kernelSlt != null) {
                    slt.append(kernelSlt);
                }
                slt.appendLine(sdf.format(new Date()) + "Tool ran in " + runTimeMS + "ms");
                ret.log = slt;
            }
        } catch (Exception e) {
            logger.error("Test failed", (Throwable) e);
            auditObj.with("outcome", "failure");
            ret.error = new SerializedError((Throwable) e, true);
            if (e instanceof ExceptionWithLogTail) {
                ret.error.logTail = ((ExceptionWithLogTail) e).getLogTail();
            }
        } finally {
            // The audit object used to be built but never emitted; emit it on both paths.
            auditObj.emit();
        }
        return ret;
    }

    /**
     * Creates an agent tool from a JSON prototype, commits, then emits an
     * {@code agent-tool-create} audit record.
     *
     * <p>Requires WRITE_CONF and is subject to the basic LLM Mesh license check.
     *
     * @param req        current request, used to open the write transaction and resolve the user
     * @param projectKey key of the target project
     * @param proto      JSON-serialized {@link ProtoAgentTool}
     * @return the id of the new agent tool
     */
    @AuditInline
    @RequestMapping(value={"/api/agent-tools/create"}, method={RequestMethod.POST})
    @ResponseBody
    public Id create(HttpServletRequest req, @RequestParam String projectKey, @RequestParam String proto) throws Exception {
        ProtoAgentTool protoObj = (ProtoAgentTool) JSON.parse(proto, ProtoAgentTool.class);
        try (RWTransaction t = this.transactionService.beginWriteForUI(req)) {
            AuthCtx user = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(user, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            this.licenseEnforcementService.checkBasicLLMMeshAllowed(user);
            Id newId = new Id(this.crudService.create(user, projectKey, protoObj));
            t.commit("Created agent tool " + newId.id);
            this.auditTrailService.generic("agent-tool-create").with("projectKey", projectKey).with("agentToolId", newId.id).emit();
            return newId;
        }
    }

    /**
     * Creates an agent tool from a JSON prototype and a knowledge bank reference, commits, then
     * emits an {@code agent-tool-create} audit record.
     *
     * <p>Requires WRITE_CONF and is subject to the basic LLM Mesh license check.
     *
     * @param req              current request, used to open the write transaction and resolve the user
     * @param projectKey       key of the target project
     * @param proto            JSON-serialized {@link ProtoAgentTool}
     * @param knowledgeBankRef reference of the knowledge bank, passed as-is to the CRUD service
     * @return the id of the new agent tool
     */
    @AuditInline
    @RequestMapping(value={"/api/agent-tools/create-from-kb"}, method={RequestMethod.POST})
    @ResponseBody
    public Id createFromKB(HttpServletRequest req, @RequestParam String projectKey, @RequestParam String proto, @RequestParam String knowledgeBankRef) throws Exception {
        ProtoAgentTool protoObj = (ProtoAgentTool) JSON.parse(proto, ProtoAgentTool.class);
        try (RWTransaction t = this.transactionService.beginWriteForUI(req)) {
            AuthCtx user = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(user, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            this.licenseEnforcementService.checkBasicLLMMeshAllowed(user);
            Id newId = new Id(this.crudService.createFromKB(user, projectKey, protoObj, knowledgeBankRef));
            t.commit("Created agent tool " + newId.id);
            this.auditTrailService.generic("agent-tool-create").with("projectKey", projectKey).with("agentToolId", newId.id).emit();
            return newId;
        }
    }

    /**
     * Checks whether the submitted agent tool can be saved without overwriting a concurrent
     * edit. The check compares its version tag with the stored one, and the resulting
     * {@code ConflictCheckResult} is written to the response as JSON.
     *
     * <p>Requires the READ_CONF privilege on the tool's project.
     *
     * @param req       current request, used to resolve the calling user
     * @param resp      response the conflict-check result is written to
     * @param agentTool the agent tool as currently edited by the client
     */
    @AuditNotNeeded
    @RequestMapping(value={"/api/agent-tools/check-save-conflict"})
    public void checkSaveConflict(HttpServletRequest req, HttpServletResponse resp, @RequestParam AgentTool agentTool) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, agentTool.projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            AgentTool existingAgentTool = this.crudService.getMandatoryUnsafe(agentTool.projectKey, agentTool.id);
            VersionTag.ConflictCheckResult ccr = this.conflictCheckService.checkConflict(existingAgentTool.versionTag, agentTool.versionTag);
            if (!ccr.canBeSaved) {
                ccr.message = "This agent tool is being edited by more than one user.";
            }
            AgentToolsController.writeJSON(resp, (Object) ccr);
        }
    }

    /**
     * Saves an agent tool, commits, then emits an {@code agent-tool-save} audit record.
     *
     * <p>Requires WRITE_CONF and is subject to the basic LLM Mesh license check.
     *
     * @param req       current request, used to open the write transaction
     * @param resp      current response (unused)
     * @param agentTool the agent tool to save
     * @param saveInfo  save options; only {@code summaryOnly} is read here
     * @return the same {@code agentTool} instance that was passed in
     */
    @AuditInline
    @RequestMapping(value={"/api/agent-tools/save"}, method={RequestMethod.POST})
    @ResponseBody
    public AgentTool save(HttpServletRequest req, HttpServletResponse resp, @RequestParam AgentTool agentTool, @RequestParam(required=false, defaultValue="{}") TaggableObjectsService.TaggableObjectSaveInfo saveInfo) throws Exception {
        try (RWTransaction t = this.transactionService.beginWriteForUI(req)) {
            this.projectsService.checkPerm(t.getUser(), agentTool.projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            this.licenseEnforcementService.checkBasicLLMMeshAllowed(t.getUser());
            this.crudService.save(agentTool, saveInfo.summaryOnly);
            t.commit("Saved agentTool " + agentTool.id);
            this.auditTrailService.generic("agent-tool-save").with("projectKey", agentTool.projectKey).with("agentToolId", agentTool.id).emit();
        }
        return agentTool;
    }

    /**
     * Duplicates an agent tool within the same project. The new id is written to the
     * response as JSON and an {@code agent-tool-duplicate} audit record is emitted.
     *
     * <p>Requires WRITE_CONF and is subject to the basic LLM Mesh license check.
     *
     * @param req         current request, used to open the write transaction and resolve the user
     * @param resp        response the new id is written to
     * @param projectKey  key of the project owning the tool
     * @param agentToolId id of the tool to duplicate
     */
    @AuditInline
    @RequestMapping(value={"/api/agent-tools/copy"}, method={RequestMethod.POST})
    public void copy(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String agentToolId) throws Exception {
        Id id;
        try (RWTransaction t = this.transactionService.beginWriteForUI(req)) {
            AuthCtx user = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            this.licenseEnforcementService.checkBasicLLMMeshAllowed(user);
            AgentTool agentTool = this.crudService.getMandatoryUnsafe(projectKey, agentToolId);
            id = new Id(this.crudService.copy(user, agentTool, projectKey));
            // Commit before answering, so the client never receives the id of a copy whose
            // transaction then failed to commit.
            t.commit("Duplicated agent tool " + projectKey + "." + agentToolId + " to " + projectKey + "." + id.id);
            AgentToolsController.writeJSON(resp, (Object) id);
        }
        this.auditTrailService.generic("agent-tool-duplicate").with("projectKey", projectKey).with("originalAgentToolId", agentToolId).with("newAgentToolId", id.id).emit();
    }

    /**
     * Writes the {@code DIRECT_USAGE} report templates to the response as JSON. This only
     * requires an authenticated user.
     *
     * @param req  current request, used to resolve the calling user
     * @param resp response the template list is written to
     */
    @AuditNotNeeded
    @RequestMapping(value={"/api/agent-tools/list-email-templates"})
    public void listEmailTemplates(HttpServletRequest req, HttpServletResponse resp) throws IOException {
        try (Transaction t = this.transactionService.beginRead()) {
            this.authService.getMandatoryUser(req);
        }
        TemplatesDAO templatesDAO = new TemplatesDAO();
        AgentToolsController.writeJSON(resp, templatesDAO.list(TemplatesDAO.TemplateType.DIRECT_USAGE));
    }

    /** Response of {@code get-usage}: the visual agents whose active version uses the tool. */
    private static class AgentToolUsage {
        List<SavedModel> agents;

        private AgentToolUsage() {
        }
    }

    /** Response of the tool quick test; {@code error} is set when the run failed. */
    private static class ToolQuickTestResponse {
        SerializedError error;
        AgentToolRunner.AgentToolOutput response;
        LLMClient.LLMMeshTraceSpan fullTrace;
        LLMClient.LLMMeshTraceSpan traceOfPython;
        SmartLogTail log;

        private ToolQuickTestResponse() {
        }
    }

    /** Client-supplied prototype used by the create endpoints, parsed from JSON. */
    public static class ProtoAgentTool {
        public String name;
        public String type;
        public String id;
        public JsonObject creationParams;
        public JsonObject quickTestQuery;
    }
}

