/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.server.controllers.analysis;

import com.dataiku.dip.analysis.coreservices.flow.SavedModelsCRUDService;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.coremodel.VersionTag;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMRefEnricherService;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.langchain.PythonLLMServerKernelPool;
import com.dataiku.dip.llm.retrieval.RAGKernelPool;
import com.dataiku.dip.savedmodels.SavedModelsRetrievalAugmentedLlmService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.security.PermissionsService;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.security.auth.UIAuthService;
import com.dataiku.dip.server.controllers.AuditInline;
import com.dataiku.dip.server.controllers.AuditNotNeeded;
import com.dataiku.dip.server.controllers.AuditedCall;
import com.dataiku.dip.server.controllers.DIPInternalControllerBase;
import com.dataiku.dip.server.controllers.NotFoundException;
import com.dataiku.dip.server.services.ConflictCheckService;
import com.dataiku.dip.server.services.LLMQuickChatService;
import com.dataiku.dip.server.services.LLMQuickTestService;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.streaming.endpoints.httpsse.MiniSSEEmitter;
import com.dataiku.dip.transactions.ifaces.RWTransaction;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.DKUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.apache.commons.io.FilenameUtils;
import com.google.gson.reflect.TypeToken;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.ResponseBody;

/**
 * Spring MVC controller serving the {@code /api/savedmodels/retrieval-augmented-llm/*} endpoints.
 *
 * <p>Covers creation and inline saving of retrieval-augmented LLM saved models, version management
 * (create, set active, delete, duplicate), save-conflict detection, listing, quick test / streamed
 * chat against a version, stopping the dev kernel, and reading/deleting per-version logs.
 *
 * <p>Permission checks are performed inside short transactions; LLM calls and log streaming run
 * after those transactions have been closed.
 */
@Controller
public class SavedModelsRetrievalAugmentedLlmController
extends DIPInternalControllerBase {
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private ProjectsService projectsService;
    @Autowired
    private PermissionsService permissionsService;
    @Autowired
    private UIAuthService authService;
    @Autowired
    private AuditTrailService auditTrailService;
    @Autowired
    private SavedModelsRetrievalAugmentedLlmService augmentedLlmService;
    @Autowired
    private SavedModelsCRUDService service;
    @Autowired
    private SavedModelsDAO savedModelsDAO;
    @Autowired
    private PythonLLMServerKernelPool pythonLLMServerKernelPool;
    @Autowired
    private LLMQuickTestService llmQuickTestService;
    @Autowired
    private ConflictCheckService conflictCheckService;
    @Autowired
    private LLMQuickChatService llmQuickChatService;
    @Autowired
    private LLMRefEnricherService llmRefEnricherService;
    @Autowired
    private RAGKernelPool ragKernelPool;
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.rag.controller");

    /**
     * Creates a new retrieval-augmented LLM saved model through
     * {@code SavedModelsRetrievalAugmentedLlmService.createRetrievalAugmentedLlmAndVersion}.
     *
     * @param projectKey project that will own the saved model; the caller needs WRITE_CONF on it
     * @param name name of the saved model; must not be empty
     * @param knowledgeBankRef knowledge bank reference forwarded to the service
     * @param llmId LLM identifier forwarded to the service
     * @return the created saved model
     * @throws Exception on validation, permission or creation failure; failures raised after the
     *         empty-name check are recorded in the audit trail before being rethrown
     */
    @ResponseBody
    @AuditInline
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/create"}, method={RequestMethod.POST})
    public SavedModel createRetrievalAugmentedLlm(HttpServletRequest req, @RequestParam String projectKey, @RequestParam String name, @RequestParam String knowledgeBankRef, @RequestParam String llmId) throws Exception {
        // Validated before the try block, so an empty name is not audited as a failure.
        this.checkNotEmpty(new String[]{name, "Name of Saved Model can not be empty"});
        try {
            AuthCtx authCtx;
            try (Transaction ignored = this.transactionService.beginRead()) {
                authCtx = this.authService.getMandatoryUser(req);
                this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            }
            SavedModel sm = this.augmentedLlmService.createRetrievalAugmentedLlmAndVersion(authCtx, projectKey, name, knowledgeBankRef, llmId);
            this.auditTrailService.generic("retrieval-augmented-llm-create").with("projectKey", projectKey).with("modelId", sm.id).with("type", SavedModel.SavedModelType.RETRIEVAL_AUGMENTED_LLM.toString()).emit();
            return sm;
        }
        catch (Exception e) {
            this.auditTrailService.failure("retrieval-augmented-llm-create", (Throwable)e).with("projectKey", projectKey).with("name", name).with("type", SavedModel.SavedModelType.RETRIEVAL_AUGMENTED_LLM.toString()).emit();
            throw e;
        }
    }

    /**
     * Saves an edited retrieval-augmented LLM saved model and commits the write transaction.
     *
     * <p>On success, the Python LLM server kernels for this saved model are invalidated.
     *
     * @param savedModel the edited saved model; the caller needs WRITE_CONF on its project
     * @return the same {@code savedModel} instance that was passed in
     * @throws Exception on permission or save failure (audited, then rethrown)
     */
    @AuditInline
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/save"}, method={RequestMethod.POST})
    @ResponseBody
    public SavedModel retrievalAugmentedLlmSaveInline(HttpServletRequest req, HttpServletResponse resp, @RequestParam SavedModel savedModel) throws Exception {
        try (RWTransaction rwt = this.transactionService.beginWriteForUI(req)) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, savedModel.projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            // Existence check: the result is unused. Presumably getMandatory throws when the model is
            // missing, making this endpoint update-only — TODO confirm.
            this.service.getMandatory(savedModel.projectKey, savedModel.id);
            this.augmentedLlmService.save(authCtx, savedModel);
            rwt.commit("Saved Retrieval Augmented LLM: " + savedModel.id);
            this.auditTrailService.generic("retrieval-augmented-llm-save").with("projectKey", savedModel.projectKey).with("modelId", savedModel.id).emit();
        }
        catch (Exception e) {
            this.auditTrailService.failure("retrieval-augmented-llm-save", (Throwable)e).with("projectKey", savedModel.projectKey).with("modelId", savedModel.id).emit();
            throw e;
        }
        // Only reached when the save succeeded, because the catch block above always rethrows.
        this.pythonLLMServerKernelPool.invalidateKernels(savedModel.projectKey, savedModel.id);
        return savedModel;
    }

    /**
     * Creates a new version of an existing retrieval-augmented LLM saved model.
     *
     * @param smId id of the saved model; the caller needs WRITE_CONF on {@code projectKey}
     * @param versionId id of the version to create
     * @param initialData JSON-serialized {@link SavedModel.SavedModelInlineVersion} used as the initial content
     * @return the full id of the created version, as returned by the service
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-create-version", "projectKey", "${projectKey}", "modelId", "${smId}", "versionId", "${versionId}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/create-version"}, method={RequestMethod.POST})
    @ResponseBody
    public FullModelId createVersion(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId, @RequestParam String initialData) throws Exception {
        SavedModel sm;
        AuthCtx authCtx;
        try (Transaction t = this.transactionService.beginRead()) {
            authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            sm = (SavedModel)this.savedModelsDAO.getMandatory(projectKey, smId);
        }
        SavedModel.SavedModelInlineVersion initialVersion = (SavedModel.SavedModelInlineVersion)JSON.parse(initialData, SavedModel.SavedModelInlineVersion.class);
        return this.augmentedLlmService.createVersion(authCtx, sm, versionId, initialVersion);
    }

    /**
     * Makes {@code newActiveVersion} the active version of the saved model and commits the change.
     *
     * <p>NOTE(review): this mapping does not restrict the HTTP method, so a GET also mutates state.
     * The mapping is left as-is so that existing clients keep working.
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-set-active", "projectKey", "${projectKey}", "modelId", "${savedModelId}", "version", "${newActiveVersion}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/set-active"})
    public void setActive(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String savedModelId, @RequestParam String newActiveVersion) throws Exception {
        try (RWTransaction t = this.transactionService.beginWriteForUI(req)) {
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            SavedModel sm = this.service.getMandatory(projectKey, savedModelId);
            String oldActiveVersion = sm.getActiveVersion();
            this.augmentedLlmService.setVersionActive(sm, newActiveVersion);
            t.commit("Set active version of " + projectKey + "." + savedModelId + " to " + newActiveVersion + " (previously " + oldActiveVersion + ")");
        }
    }

    /**
     * Deletes several versions of a saved model in a single committed transaction.
     *
     * @param versions JSON array of version ids, e.g. {@code ["v1","v2"]}
     *
     * <p>NOTE(review): like {@code set-active}, this mapping accepts any HTTP method.
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-delete-versions", "projectKey", "${projectKey}", "modelId", "${savedModelId}", "versions", "${versions}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/delete-versions"})
    public void deleteVersions(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String savedModelId, @RequestParam String versions) throws Exception {
        try (RWTransaction t = this.transactionService.beginWriteForUI(req)) {
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            SavedModel sm = this.service.getMandatory(projectKey, savedModelId);
            // Typed parse: iterating the raw ArrayList produced by the decompiler as String does not compile.
            List<String> versionsList = (List<String>)JSON.parse(versions, new TypeToken<ArrayList<String>>(){});
            for (String version : versionsList) {
                this.augmentedLlmService.deleteVersion(sm, version);
            }
            t.commit("Deleted " + versionsList.size() + " versions from " + projectKey + "." + savedModelId);
        }
    }

    /**
     * Copies version {@code versionIdToCopy} of a saved model into a new version and commits the change.
     *
     * @return the full id of the newly created version
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-duplicate-versions", "projectKey", "${projectKey}", "modelId", "${savedModelId}", "versionIdToCopy", "${versionIdToCopy}", "newVersionId", "${newVersionId}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/duplicate-version"}, method={RequestMethod.POST})
    @ResponseBody
    public FullModelId duplicateVersion(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String savedModelId, @RequestParam String versionIdToCopy, @RequestParam String newVersionId) throws Exception {
        try (RWTransaction t = this.transactionService.beginWriteForUI(req)) {
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            SavedModel sm = this.service.getMandatory(projectKey, savedModelId);
            FullModelId fullModelId = this.augmentedLlmService.duplicateVersion(authCtx, sm, versionIdToCopy, newVersionId);
            t.commit("Duplicated " + versionIdToCopy + " to " + fullModelId.getSavedModelVersionID() + " in " + projectKey + "." + savedModelId);
            return fullModelId;
        }
    }

    /**
     * Checks whether {@code savedModel} can be saved over the stored version without clobbering
     * another user's edits.
     *
     * <p>The check compares the version tags of the stored and submitted models. The
     * {@link VersionTag.ConflictCheckResult} is written as JSON to the response.
     */
    @AuditNotNeeded
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/check-save-conflict"})
    public void checkSaveConflict(HttpServletRequest req, HttpServletResponse resp, @RequestParam SavedModel savedModel) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, savedModel.projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            SavedModel existingRALLM = this.service.getMandatory(savedModel.projectKey, savedModel.id);
            VersionTag.ConflictCheckResult ccr = this.conflictCheckService.checkConflict(existingRALLM.versionTag, savedModel.versionTag);
            if (!ccr.canBeSaved) {
                ccr.message = "This retrieval-augmented LLM is being edited by more than one user.";
            }
            SavedModelsRetrievalAugmentedLlmController.writeJSON(resp, (Object)ccr);
        }
    }

    /**
     * Lists the retrieval-augmented LLM saved models of a project.
     *
     * <p>The caller needs READ_CONF on {@code projectKey}.
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-list", "projectKey", "${projectKey}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/list"}, method={RequestMethod.GET})
    @ResponseBody
    public List<SavedModel> listRetrievalAugmentedLlm(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            return this.service.listUnsafe(projectKey, SavedModel.SavedModelType.RETRIEVAL_AUGMENTED_LLM);
        }
    }

    /**
     * Runs a single quick-test completion of {@code query} against one version of the saved model.
     *
     * <p>The caller needs READ_CONF on the project. The LLM call happens after the read
     * transaction is closed.
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-test", "projectKey", "${projectKey}", "modelId", "${smId}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/test"})
    @ResponseBody
    public LLMQuickTestService.LLMQuickTestResponse retrievalAugmentedLlmQuickTest(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId, @RequestParam String query) throws Exception {
        AuthCtx authCtx;
        try (Transaction t = this.transactionService.beginRead()) {
            authCtx = this.authService.getMandatoryUserNoXSRF(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
        }
        LLMStructuredRef llmRef = LLMStructuredRef.forRetrievalAugmentedLLMVersion(new AnyLoc(projectKey, smId).getFullName(), versionId);
        return this.llmQuickTestService.llmQuickTestCompletion(authCtx, projectKey, llmRef, query);
    }

    /**
     * Stops the RAG dev kernel tied to the knowledge bank of the saved model's active version.
     *
     * @throws IllegalArgumentException if the saved model has no active version
     * @throws NotFoundException if {@code RAGKernelPool.stopDevKernel} reports that no dev kernel was stopped
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-stop-dev-kernel", "projectKey", "${projectKey}", "modelId", "${smId}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/stop-dev-kernel"}, method={RequestMethod.POST})
    public void retrievalAugmentedLlmStopDevKernel(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId) throws Exception {
        LLMStructuredRef llmRef;
        String kbId;
        AuthCtx authCtx;
        try (RWTransaction rwt = this.transactionService.beginWriteForUI(req)) {
            authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            SavedModel sm = this.service.getMandatory(projectKey, smId);
            SavedModel.SavedModelInlineVersion smiv = sm.getVersion(sm.activeVersion).orElseThrow(() -> new IllegalArgumentException(String.format("Active version not found for retrieval-augmented LLM %s", sm.id)));
            kbId = smiv.ragllmSettings.kbRef;
            llmRef = LLMStructuredRef.forRetrievalAugmentedLLM(AnyLoc.resolveSmart(projectKey, smId).getFullName());
        }
        if (!this.ragKernelPool.stopDevKernel((DSSAuthCtx)authCtx, projectKey, kbId, llmRef)) {
            throw new NotFoundException("Dev kernel not running");
        }
    }

    /**
     * Streams a chat completion for one version of the saved model as Server-Sent Events.
     *
     * <p>When streaming finishes, the final response is sent as a {@code completion-response}
     * event containing its JSON serialization.
     *
     * <p>NOTE(review): every exception raised while streaming, including LLM errors, is logged as
     * "Client disconnected" and swallowed instead of being rethrown. Presumably this is because the
     * SSE response may already be committed at that point — confirm before changing.
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-chat", "projectKey", "${projectKey}", "modelId", "${smId}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/chat"}, method={RequestMethod.POST})
    public void retrievalAugmentedLlmChat(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId, @RequestParam LLMQuickChatService.LLMQuickChatInput chatInput) throws Exception {
        AuthCtx authCtx;
        try (Transaction t = this.transactionService.beginRead()) {
            authCtx = this.authService.getMandatoryUserNoXSRF(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
        }
        LLMStructuredRef llmRef = LLMStructuredRef.forRetrievalAugmentedLLMVersion(new AnyLoc(projectKey, smId).getFullName(), versionId);
        EnrichedLLMStructuredRef enrichedLLMRef = this.llmRefEnricherService.getEnrichedLLMRefFromRetrievalAugmentedLLM(authCtx, projectKey, llmRef);
        try (MiniSSEEmitter emitter = new MiniSSEEmitter(resp, 0L, true)) {
            LLMQuickChatService.LLMQuickChatResponse response = this.llmQuickChatService.streamChatResponse(authCtx, projectKey, enrichedLLMRef, chatInput, emitter);
            emitter.sendEventWithData("completion-response", JSON.json((Object)response), false);
        }
        catch (Exception e) {
            logger.error((Object)"Client disconnected", (Throwable)e);
        }
        finally {
            logger.info((Object)"streamed completion: done");
        }
    }

    /**
     * Writes, as JSON, the log listing returned by the service for one version of the saved model.
     *
     * <p>The caller needs READ_CONF on the project.
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-logs-list", "projectKey", "${projectKey}", "modelId", "${smId}", "versionId", "${versionId}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/list-logs"})
    public void retrievalAugmentedListLogs(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
        }
        SavedModelsRetrievalAugmentedLlmController.writeJSON(resp, this.augmentedLlmService.listLogs(projectKey, smId, versionId));
    }

    /**
     * Streams a log of one version of the saved model as a file attachment.
     *
     * <p>The download name is built as follows:
     * <ul>
     *   <li>with {@code logName}: {@code <basename>_yyyy_MM_dd[.<ext>].gz}</li>
     *   <li>with a blank {@code logName}: {@code dssLogs_yyyy_MM_dd.zip}</li>
     * </ul>
     *
     * <p>{@code logName} comes straight from the request. It is sanitized before being placed in
     * the {@code Content-Disposition} header.
     *
     * <p>NOTE(review): {@code logName} is also passed unmodified to {@code streamLog}. Path
     * validation is presumably done by the service — confirm.
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-log-get", "projectKey", "${projectKey}", "modelId", "${smId}", "versionId", "${versionId}", "logName", "${logName}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/stream-log"})
    public void retrievalAugmentedStreamLog(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId, @RequestParam String logName) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUserNoXSRF(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
        }
        String now = DKUtils.getDateFormatter("_yyyy_MM_dd").print(System.currentTimeMillis());
        String dlName;
        if (StringUtils.isNotBlank(logName)) {
            String basename = FilenameUtils.getBaseName(logName);
            String extension = FilenameUtils.getExtension(logName);
            // Without this guard, an extension-less name became "<basename>_yyyy_MM_dd..gz".
            String extensionSuffix = extension.isEmpty() ? "" : "." + extension;
            dlName = basename + now + extensionSuffix + ".gz";
        } else {
            dlName = "dssLogs" + now + ".zip";
        }
        String cd = String.format("attachment; filename=\"%s\"", sanitizeAttachmentFilename(dlName));
        resp.setHeader("Content-Disposition", cd);
        this.augmentedLlmService.streamLog(projectKey, smId, versionId, logName, (OutputStream)resp.getOutputStream());
    }

    /**
     * Makes a file name safe to embed in a quoted {@code Content-Disposition} filename parameter.
     *
     * <p>Double quotes, backslashes, CR and LF are replaced with {@code _}. These characters could
     * otherwise end the quoted string or start a new header line.
     */
    private static String sanitizeAttachmentFilename(String filename) {
        return filename.replace('"', '_').replace('\\', '_').replace('\r', '_').replace('\n', '_');
    }

    /**
     * Writes, as JSON, the content of one named log of a saved-model version, as returned by the service.
     *
     * <p>The caller needs READ_CONF on the project.
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-log-get", "projectKey", "${projectKey}", "modelId", "${smId}", "versionId", "${versionId}", "logName", "${logName}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/get-log"})
    public void retrievalAugmentedGetLog(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId, @RequestParam String logName) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUserNoXSRF(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
        }
        SavedModelsRetrievalAugmentedLlmController.writeJSON(resp, (Object)this.augmentedLlmService.getLog(projectKey, smId, versionId, logName));
    }

    /**
     * Deletes one named log of a saved-model version.
     *
     * <p>The caller needs WRITE_CONF on the project. Because this is a state-changing POST, the
     * caller is authenticated with the XSRF-checked {@code getMandatoryUser}, not the
     * {@code NoXSRF} variant.
     *
     * <p>NOTE(review): the method name ("agent…") looks copy-pasted from another controller. It is
     * kept unchanged to preserve the public interface.
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-log-delete", "projectKey", "${projectKey}", "modelId", "${smId}", "versionId", "${versionId}", "logName", "${logName}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/delete-log"}, method={RequestMethod.POST})
    public void agentDeleteLog(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId, @RequestParam String logName) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
        }
        this.augmentedLlmService.deleteLog(projectKey, smId, versionId, logName);
    }

    /**
     * Clears all logs of a saved-model version.
     *
     * <p>The caller needs WRITE_CONF on the project. As a state-changing POST, it uses the
     * XSRF-checked {@code getMandatoryUser}.
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-logs-clear", "projectKey", "${projectKey}", "modelId", "${smId}", "versionId", "${versionId}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/clear-logs"}, method={RequestMethod.POST})
    public void retrievalAugmentedClearLogs(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
        }
        this.augmentedLlmService.clearLogs(projectKey, smId, versionId);
    }
}

