/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.server.controllers.analysis;

import com.dataiku.dip.analysis.coreservices.flow.SavedModelsCRUDService;
import com.dataiku.dip.coremodel.VersionTag;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMRefEnricherService;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.langchain.PythonLLMServerKernelPool;
import com.dataiku.dip.llm.retrieval.RAGKernelPool;
import com.dataiku.dip.savedmodels.SavedModelsRetrievalAugmentedLlmService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.security.PermissionsService;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.security.auth.UIAuthService;
import com.dataiku.dip.server.controllers.AuditInline;
import com.dataiku.dip.server.controllers.AuditNotNeeded;
import com.dataiku.dip.server.controllers.AuditedCall;
import com.dataiku.dip.server.controllers.DIPInternalControllerBase;
import com.dataiku.dip.server.controllers.NotFoundException;
import com.dataiku.dip.server.services.ConflictCheckService;
import com.dataiku.dip.server.services.LLMQuickChatService;
import com.dataiku.dip.server.services.LLMQuickTestService;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.streaming.endpoints.httpsse.MiniSSEEmitter;
import com.dataiku.dip.transactions.ifaces.RWTransaction;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.ResponseBody;

/**
 * Internal UI endpoints for retrieval-augmented LLM saved models: creation, inline save,
 * save-conflict check, listing, quick test, streamed chat (SSE) and stopping the dev kernel
 * of the knowledge bank referenced by the model's active version.
 */
@Controller
public class SavedModelsRetrievalAugmentedLlmController
extends DIPInternalControllerBase {
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private ProjectsService projectsService;
    @Autowired
    private PermissionsService permissionsService;
    @Autowired
    private UIAuthService authService;
    @Autowired
    private AuditTrailService auditTrailService;
    @Autowired
    private SavedModelsRetrievalAugmentedLlmService augmentedLlmService;
    @Autowired
    private SavedModelsCRUDService service;
    @Autowired
    private SavedModelsDAO savedModelsDAO;
    @Autowired
    private PythonLLMServerKernelPool pythonLLMServerKernelPool;
    @Autowired
    private LLMQuickTestService llmQuickTestService;
    @Autowired
    private ConflictCheckService conflictCheckService;
    @Autowired
    private LLMQuickChatService llmQuickChatService;
    @Autowired
    private LLMRefEnricherService llmRefEnricherService;
    @Autowired
    private RAGKernelPool ragKernelPool;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.rag.controller");

    /**
     * Creates a retrieval-augmented LLM saved model (through
     * {@code createRetrievalAugmentedLlmAndVersion}) in the given project.
     *
     * <p>Requires WRITE_CONF on the project. On success emits a
     * "retrieval-augmented-llm-create" audit event; on any failure after the name check, emits a
     * failure audit event and rethrows.
     *
     * @param projectKey       target project
     * @param name             saved model name; must not be empty
     * @param knowledgeBankRef reference of the knowledge bank to retrieve from
     * @param llmId            identifier of the underlying LLM
     * @return the created saved model
     */
    @ResponseBody
    @AuditInline
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/create"}, method={RequestMethod.POST})
    public SavedModel createRetrievalAugmentedLlm(HttpServletRequest req, @RequestParam String projectKey, @RequestParam String name, @RequestParam String knowledgeBankRef, @RequestParam String llmId) throws Exception {
        this.checkNotEmpty(new String[]{name, "Name of Saved Model can not be empty"});
        try {
            AuthCtx authCtx;
            // Auth + permission check only need a read transaction; creation happens afterwards.
            try (Transaction ignored = this.transactionService.beginRead()) {
                authCtx = this.authService.getMandatoryUser(req);
                this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            }
            SavedModel sm = this.augmentedLlmService.createRetrievalAugmentedLlmAndVersion(authCtx, projectKey, name, knowledgeBankRef, llmId);
            this.auditTrailService.generic("retrieval-augmented-llm-create").with("projectKey", projectKey).with("modelId", sm.id).with("type", SavedModel.SavedModelType.RETRIEVAL_AUGMENTED_LLM.toString()).emit();
            return sm;
        }
        catch (Exception e) {
            this.auditTrailService.failure("retrieval-augmented-llm-create", (Throwable)e).with("projectKey", projectKey).with("name", name).with("type", SavedModel.SavedModelType.RETRIEVAL_AUGMENTED_LLM.toString()).emit();
            throw e;
        }
    }

    /**
     * Saves an edited retrieval-augmented LLM definition in a UI write transaction.
     *
     * <p>Requires WRITE_CONF on {@code savedModel.projectKey}; the model must already exist.
     * After a successful commit, the python LLM server kernels for this model are invalidated.
     * Success and failure are both audited; failures are rethrown.
     *
     * @param savedModel the full saved model to persist
     * @return the same {@code savedModel} instance that was passed in
     */
    @AuditInline
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/save"}, method={RequestMethod.POST})
    @ResponseBody
    public SavedModel retrievalAugmentedLlmSaveInline(HttpServletRequest req, HttpServletResponse resp, @RequestParam SavedModel savedModel) throws Exception {
        try (RWTransaction rwt = this.transactionService.beginWriteForUI(req)) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, savedModel.projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            // Existence check only: fails if the model being saved does not already exist.
            this.service.getMandatory(savedModel.projectKey, savedModel.id);
            this.augmentedLlmService.save(authCtx, savedModel);
            rwt.commit("Saved Retrieval Augmented LLM: " + savedModel.id);
            this.auditTrailService.generic("retrieval-augmented-llm-save").with("projectKey", savedModel.projectKey).with("modelId", savedModel.id).emit();
        }
        catch (Exception e) {
            this.auditTrailService.failure("retrieval-augmented-llm-save", (Throwable)e).with("projectKey", savedModel.projectKey).with("modelId", savedModel.id).emit();
            throw e;
        }
        // Done outside the transaction, only once the save has been committed.
        this.pythonLLMServerKernelPool.invalidateKernels(savedModel.projectKey, savedModel.id);
        return savedModel;
    }

    /**
     * Checks whether {@code savedModel} can be saved without overwriting a concurrent edit,
     * by comparing its version tag with the stored one, and writes the
     * {@link VersionTag.ConflictCheckResult} as JSON to the response.
     *
     * <p>Requires WRITE_CONF on the project.
     */
    @AuditNotNeeded
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/check-save-conflict"})
    public void checkSaveConflict(HttpServletRequest req, HttpServletResponse resp, @RequestParam SavedModel savedModel) throws Exception {
        try (Transaction ignored = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, savedModel.projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            SavedModel existingRALLM = this.service.getMandatory(savedModel.projectKey, savedModel.id);
            VersionTag.ConflictCheckResult ccr = this.conflictCheckService.checkConflict(existingRALLM.versionTag, savedModel.versionTag);
            if (!ccr.canBeSaved) {
                ccr.message = "This retrieval-augmented LLM is being edited by more than one user.";
            }
            SavedModelsRetrievalAugmentedLlmController.writeJSON((HttpServletResponse)resp, (Object)ccr);
        }
    }

    /**
     * Lists the retrieval-augmented LLM saved models of a project. Requires READ_CONF.
     *
     * @param projectKey project to list
     * @return saved models of type RETRIEVAL_AUGMENTED_LLM
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-list", "projectKey", "${projectKey}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/list"}, method={RequestMethod.GET})
    @ResponseBody
    public List<SavedModel> listRetrievalAugmentedLlm(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey) throws Exception {
        try (Transaction ignored = this.transactionService.beginRead()) {
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            // "Unsafe" listing is acceptable here because the project permission was just checked.
            return this.service.listUnsafe(projectKey, SavedModel.SavedModelType.RETRIEVAL_AUGMENTED_LLM);
        }
    }

    /**
     * Runs a single quick-test completion against a specific version of a retrieval-augmented LLM.
     *
     * <p>Requires READ_CONF on the project. Authentication uses the no-XSRF variant. The
     * completion runs outside the transaction.
     *
     * @param smId      saved model id within {@code projectKey}
     * @param versionId saved model version to query
     * @param query     prompt text
     * @return the quick-test response
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-test", "projectKey", "${projectKey}", "modelId", "${smId}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/test"})
    @ResponseBody
    public LLMQuickTestService.LLMQuickTestResponse retrievalAugmentedLlmQuickTest(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId, @RequestParam String query) throws Exception {
        AuthCtx authCtx;
        try (Transaction ignored = this.transactionService.beginRead()) {
            authCtx = this.authService.getMandatoryUserNoXSRF(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
        }
        LLMStructuredRef llmRef = LLMStructuredRef.forRetrievalAugmentedLLMVersion(new AnyLoc(projectKey, smId).getFullName(), versionId);
        return this.llmQuickTestService.llmQuickTestCompletion(authCtx, projectKey, llmRef, query);
    }

    /**
     * Stops the RAG dev kernel for the knowledge bank used by the model's active version.
     *
     * <p>Requires WRITE_CONF on the project.
     *
     * @throws IllegalArgumentException if the model has no active version, or if that version
     *                                  has no retrieval-augmented LLM settings
     * @throws NotFoundException        if no dev kernel was running for that knowledge bank
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-stop-dev-kernel", "projectKey", "${projectKey}", "modelId", "${smId}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/stop-dev-kernel"}, method={RequestMethod.POST})
    public void retrievalAugmentedLlmStopDevKernel(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId) throws Exception {
        String kbId;
        AuthCtx authCtx;
        // NOTE(review): this write transaction is never committed and is only used to read the
        // saved model. It is kept in case beginWriteForUI enforces UI-specific checks — confirm
        // whether beginRead() would suffice.
        try (RWTransaction ignored = this.transactionService.beginWriteForUI(req)) {
            authCtx = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            SavedModel sm = this.service.getMandatory(projectKey, smId);
            SavedModel.SavedModelInlineVersion smiv = sm.getVersion(sm.activeVersion).orElseThrow(() -> new IllegalArgumentException(String.format("Active version not found for retrieval-augmented LLM %s", sm.id)));
            // Without this guard, a model whose active version has no RAG settings (e.g. smId
            // pointing at another saved-model type) would fail with an opaque NullPointerException.
            if (smiv.ragllmSettings == null) {
                throw new IllegalArgumentException(String.format("Retrieval-augmented LLM settings not found on active version of saved model %s", sm.id));
            }
            kbId = smiv.ragllmSettings.kbRef;
        }
        if (!this.ragKernelPool.stopDevKernel((DSSAuthCtx)authCtx, projectKey, kbId)) {
            throw new NotFoundException("Dev kernel not running");
        }
    }

    /**
     * Streams a chat completion from a specific version of a retrieval-augmented LLM as
     * server-sent events. The final response is sent as a "completion-response" event
     * containing its JSON.
     *
     * <p>Requires READ_CONF on the project. Authentication uses the no-XSRF variant. Errors
     * raised while authenticating or enriching the LLM ref propagate. Errors raised while
     * creating, writing to or closing the emitter are logged and swallowed.
     */
    @AuditedCall(value={"msgType", "retrieval-augmented-llm-chat", "projectKey", "${projectKey}", "modelId", "${smId}"})
    @RequestMapping(value={"/api/savedmodels/retrieval-augmented-llm/chat"}, method={RequestMethod.POST})
    public void retrievalAugmentedLlmChat(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId, @RequestParam LLMQuickChatService.LLMQuickChatInput chatInput) throws Exception {
        AuthCtx authCtx;
        try (Transaction ignored = this.transactionService.beginRead()) {
            authCtx = this.authService.getMandatoryUserNoXSRF(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
        }
        String fullSmId = new AnyLoc(projectKey, smId).getFullName();
        LLMStructuredRef llmRef = LLMStructuredRef.forRetrievalAugmentedLLMVersion(fullSmId, versionId);
        EnrichedLLMStructuredRef enrichedLLMRef = this.llmRefEnricherService.getEnrichedLLMRefFromRetrievalAugmentedLLM(authCtx, projectKey, llmRef);
        try (MiniSSEEmitter emitter = new MiniSSEEmitter(resp, 0L, true)) {
            LLMQuickChatService.LLMQuickChatResponse response = this.llmQuickChatService.streamChatResponse(authCtx, projectKey, enrichedLLMRef, chatInput, emitter);
            emitter.sendEventWithData("completion-response", JSON.json((Object)response), false);
        }
        catch (Exception e) {
            // Swallowed on purpose: once the SSE stream has started, the HTTP status presumably
            // can no longer be changed (NOTE(review): confirm against MiniSSEEmitter).
            // The cause may be a client disconnect or a failure in streamChatResponse itself,
            // so the log message does not assume either.
            logger.error((Object)("Error while streaming retrieval-augmented LLM chat response for " + fullSmId + " (client may have disconnected)"), (Throwable)e);
        }
        finally {
            logger.info((Object)"streamed completion: done");
        }
    }
}

