/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.server.intercom.t;

import com.dataiku.dip.analysis.coreservices.flow.SavedModelsCRUDService;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.SMStatus;
import com.dataiku.dip.analysis.ml.SMVersionHeader;
import com.dataiku.dip.analysis.ml.clustering.flow.ClusteringSMMgmtService;
import com.dataiku.dip.analysis.ml.llm.LLMSMMgmtService;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionSMMgmtService;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.dataflow.exec.CodeBasedRecipeDatasetInfoHelper;
import com.dataiku.dip.savedmodels.SavedModelsService;
import com.dataiku.dip.security.IPermissionsService;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.security.auth.AuthCtxUsage;
import com.dataiku.dip.security.auth.MetaAuthService;
import com.dataiku.dip.server.controllers.AuditedCall;
import com.dataiku.dip.server.controllers.DIPInternalControllerBase;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.PubSubService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.RWTransaction;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.google.common.collect.Lists;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.util.ArrayList;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.ResponseBody;

/**
 * Spring controller exposing the {@code /api/tintercom/savedmodels/*} endpoints.
 *
 * <p>Every endpoint checks a project-level privilege on {@code projectKey} before touching the
 * saved model: {@code WRITE_CONF} for the mutating endpoints (which run inside an
 * {@link RWTransaction} and commit with a descriptive message), {@code READ_CONF} for the
 * read-only ones (which run inside a read {@link Transaction}).
 */
@Controller
public class SavedModelsIntercomController extends DIPInternalControllerBase {
    @Autowired
    private MetaAuthService authService;
    @Autowired
    private ClusteringSMMgmtService csmmService;
    @Autowired
    private IPermissionsService permissionsService;
    @Autowired
    private PredictionSMMgmtService psmmService;
    @Autowired
    private LLMSMMgmtService llmsmService;
    @Autowired
    private SavedModelsCRUDService smcService;
    @Autowired
    private SavedModelsService savedModelsService;
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private AuditTrailService auditTrailService;
    @Autowired
    private ProjectsService projectsService;
    @Autowired
    private SavedModelsDAO smDao;
    @Autowired
    private PubSubService pubSub;
    private static final DKULogger logger = DKULogger.getLogger("dku.analysis");

    /**
     * Marks {@code versionId} as the active version of a saved model.
     *
     * <p>Dispatches on the model's task type: clustering and prediction models go through their
     * respective management services; LLM task types are further dispatched on
     * {@code sm.savedModelType}. Any other combination performs no activation, but the
     * transaction is still committed.
     *
     * @param req       incoming request, used to open the write transaction
     * @param resp      unused here; part of the handler signature
     * @param projectKey project owning the saved model
     * @param smId      saved model identifier
     * @param versionId version to activate
     * @throws Exception if the privilege check, lookup, activation or commit fails
     */
    @AuditedCall(value={"msgType", "savedmodel-set-active", "projectKey", "${projectKey}", "modelId", "${smId}", "version", "${versionId}"})
    @RequestMapping(value={"/api/tintercom/savedmodels/set-active"})
    public void setActive(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId) throws Exception {
        try (RWTransaction t = this.transactionService.beginWriteForAPI(req)) {
            this.permissionsService.checkProjectPrivileges(t.getUser(), projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            SavedModel sm = this.smcService.getMandatory(projectKey, smId);
            boolean schemaChanged = false;
            switch (sm.getType()) {
                case CLUSTERING -> schemaChanged = this.csmmService.setActive(sm, versionId);
                case PREDICTION -> schemaChanged = this.psmmService.setActive(sm, versionId);
                case LLM_GENERIC_RAW, LLM_GENERIC_PROMPTABLE_COMPLETION, LLM_CLASSIFICATION -> {
                    // Two independent checks, evaluated in this order; other saved model
                    // types under an LLM task type are left untouched.
                    if (sm.savedModelType == SavedModel.SavedModelType.PYTHON_AGENT) {
                        this.llmsmService.setAgentVersionActive(sm, versionId);
                    }
                    if (sm.savedModelType == SavedModel.SavedModelType.LLM_GENERIC) {
                        this.llmsmService.setLLMGenericVersionActive(sm, versionId);
                    }
                }
                default -> {
                    // No activation logic for other task types.
                    // NOTE(review): the commit below still happens in that case - confirm this is intended.
                }
            }
            if (schemaChanged) {
                logger.info((Object)"The active version does not have the same preparation script schema as the previous one");
            }
            t.commit("Set active version of " + projectKey + "." + smId + " to " + versionId);
        }
    }

    /**
     * Increments the "last train index" of a saved model for the given job and saves the model.
     *
     * @param req       incoming request, used to open the write transaction
     * @param resp      unused here; part of the handler signature
     * @param projectKey project owning the saved model
     * @param smId      saved model identifier
     * @param jobId     job identifier passed to {@code SavedModel.incrementLastTrainIndex}
     * @throws Exception if the privilege check, lookup, save or commit fails
     */
    @AuditedCall(value={"msgType", "savedmodel-increment-last-train", "projectKey", "${projectKey}", "modelId", "${smId}"})
    @RequestMapping(value={"/api/tintercom/savedmodels/increment-last-train"})
    public void increment(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String jobId) throws Exception {
        try (RWTransaction t = this.transactionService.beginWriteForAPI(req)) {
            this.permissionsService.checkProjectPrivileges(t.getUser(), projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            SavedModel sm = this.smcService.getMandatory(projectKey, smId);
            sm.incrementLastTrainIndex(jobId);
            // NOTE(review): the meaning of the two boolean flags is not visible from this file.
            this.smcService.save(sm, false, false);
            t.commit("Increment last train index of " + projectKey + "." + smId + " to " + sm.lastTrainIndex);
        }
    }

    /**
     * Writes, as JSON, an {@link SMDetails} describing one version of a saved model: the absolute
     * path of the version's model folder, the saved model itself and the string form of its
     * {@link FullModelId}.
     *
     * @param req       incoming request, used to resolve the authentication context
     * @param resp      response the JSON payload is written to
     * @param projectKey project owning the saved model
     * @param smId      saved model identifier
     * @param versionId version whose details are requested
     * @throws Exception if the privilege check or lookup fails
     */
    @AuditedCall(value={"msgType", "savedmodel-get-model-details", "projectKey", "${projectKey}", "modelId", "${smId}", "version", "${versionId}"})
    @RequestMapping(value={"/api/tintercom/savedmodels/get-model-details"})
    public void getModelDetails(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId) throws Exception {
        try (Transaction t = this.transactionService.beginRead();
             AuthCtxUsage authCtxUsage = this.authService.getTicketOrKeyAndContext(req)) {
            this.permissionsService.checkProjectPrivileges(authCtxUsage.getAuthCtx(), projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            SMDetails details = new SMDetails();
            FullModelId fmi = new FullModelId(projectKey, smId, versionId);
            details.model_folder = fmi.getModelFolder().getAbsolutePath();
            details.saved_model = (SavedModel)this.smDao.getMandatory(projectKey, smId);
            details.fmi = fmi.toString();
            SavedModelsIntercomController.writeJSON(resp, (Object)details);
        }
    }

    /**
     * Writes a saved model as JSON, with its {@code miniTask} and {@code conditionalOutputs}
     * fields nulled out before serialization.
     *
     * @param req          incoming request, used to resolve the authentication context
     * @param resp         response the JSON payload is written to
     * @param projectKey   project owning the saved model
     * @param savedModelId saved model identifier
     * @throws Exception if the privilege check or lookup fails
     */
    @AuditedCall(value={"msgType", "savedmodel-get", "projectKey", "${projectKey}", "modelId", "${savedModelId}"})
    @RequestMapping(value={"/api/tintercom/savedmodels/get"})
    public void getModel(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String savedModelId) throws Exception {
        try (Transaction t = this.transactionService.beginRead();
             AuthCtxUsage authCtxUsage = this.authService.getTicketOrKeyAndContext(req)) {
            this.permissionsService.checkProjectPrivileges(authCtxUsage.getAuthCtx(), projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            SavedModel savedModel = this.smcService.getMandatory(projectKey, savedModelId);
            // NOTE(review): this mutates the object returned by the CRUD service; presumably
            // getMandatory returns a private copy - confirm it is not a shared/cached instance.
            savedModel.miniTask = null;
            savedModel.conditionalOutputs = null;
            SavedModelsIntercomController.writeJSON(resp, (Object)savedModel);
        }
    }

    /**
     * Returns the version headers of a saved model.
     *
     * <p>The model is looked up inside a read transaction; its status is then computed after the
     * transaction has been closed, via the {@code getStatus_NT} method matching the model's task
     * type.
     *
     * @param req          incoming request, used to resolve the authentication context
     * @param projectKey   project owning the saved model
     * @param savedModelId saved model identifier
     * @return the {@code versions} list of the computed {@link SMStatus}
     * @throws IllegalArgumentException if the model's task type is not handled
     * @throws Exception if the privilege check, lookup or status computation fails
     */
    @AuditedCall(value={"msgType", "savedmodel-get-versions", "projectKey", "${projectKey}", "modelId", "${savedModelId}"})
    @RequestMapping(value={"/api/tintercom/savedmodels/list-versions"})
    @ResponseBody
    public List<? extends SMVersionHeader<?>> listVersions(HttpServletRequest req, @RequestParam String projectKey, @RequestParam String savedModelId) throws Exception {
        SavedModel sm;
        try (Transaction t = this.transactionService.beginRead();
             AuthCtxUsage authCtxUsage = this.authService.getTicketOrKeyAndContext(req)) {
            this.permissionsService.checkProjectPrivileges(authCtxUsage.getAuthCtx(), projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            sm = this.smcService.getMandatory(projectKey, savedModelId);
        }
        SMStatus smStatus = switch (sm.getType()) {
            case PREDICTION -> this.psmmService.getStatus_NT(sm);
            case CLUSTERING -> this.csmmService.getStatus_NT(sm);
            case LLM_GENERIC_RAW, LLM_GENERIC_PROMPTABLE_COMPLETION, LLM_CLASSIFICATION -> LLMSMMgmtService.getStatus_NT(sm);
            default -> throw new IllegalArgumentException(String.format("Unknown type %s", sm.getType()));
        };
        return smStatus.versions;
    }

    /**
     * Writes, as JSON, one {@code SavedModel.SavedModelListItem} per saved model of the project.
     *
     * @param req       incoming request, used to resolve the authentication context
     * @param resp      response the JSON payload is written to
     * @param projectKey project whose saved models are listed
     * @throws Exception if the privilege check or listing fails
     */
    @AuditedCall(value={"msgType", "savedmodel-list", "projectKey", "${projectKey}"})
    @RequestMapping(value={"/api/tintercom/savedmodels/list"})
    public void listModels(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey) throws Exception {
        try (Transaction t = this.transactionService.beginRead();
             AuthCtxUsage authCtxUsage = this.authService.getTicketOrKeyAndContext(req)) {
            this.permissionsService.checkProjectPrivileges(authCtxUsage.getAuthCtx(), projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            List<SavedModel> savedModels = this.smcService.list(projectKey);
            List<SavedModel.SavedModelListItem> infos = new ArrayList<>();
            for (SavedModel savedModel : savedModels) {
                infos.add(new SavedModel.SavedModelListItem(savedModel));
            }
            SavedModelsIntercomController.writeJSON(resp, (Object)infos);
        }
    }

    /**
     * Resolves a saved model from a lookup string and writes its location info as JSON.
     *
     * <p>The info carries the model's project key, id, name and task type. A
     * {@code savedmodel-read-meta} audit event is emitted with the resolved project key and
     * model id. The JSON response is written after the read transaction has been closed.
     *
     * @param req       incoming request, used to resolve the authentication context
     * @param resp      response the JSON payload is written to
     * @param projectKey project used for the privilege check and as lookup context
     * @param lookup    lookup string handed to {@code lookupMandatoryUnsafe}
     * @throws Exception if the privilege check or lookup fails
     */
    @AuditedCall(value={"msgType", "savedmodel-get", "projectKey", "${projectKey}", "modelId", "${lookup}"})
    @RequestMapping(value={"/api/tintercom/savedmodels/get-info"})
    public void getInfo(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String lookup) throws Exception {
        CodeBasedRecipeDatasetInfoHelper.LocationInfo info;
        try (Transaction t = this.transactionService.beginRead();
             AuthCtxUsage authCtxUsage = this.authService.getTicketOrKeyAndContext(req)) {
            this.permissionsService.checkProjectPrivileges(authCtxUsage.getAuthCtx(), projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            // NOTE(review): privileges are checked on the requested projectKey only; if the
            // "unsafe" lookup can resolve to a model in another project, that project is not
            // checked here - confirm against lookupMandatoryUnsafe.
            SavedModel sm = this.smcService.lookupMandatoryUnsafe(projectKey, lookup);
            logger.infoV("Searching for model %s/%s -> %s", new Object[]{projectKey, lookup, sm});
            info = CodeBasedRecipeDatasetInfoHelper.LocationInfo.makeInfo(CodeBasedRecipeDatasetInfoHelper.LocationInfoType.SAVEDMODEL, new Object[]{"projectKey", sm.projectKey, "id", sm.id, "name", sm.name, "type", sm.getType()});
            this.auditTrailService.generic("savedmodel-read-meta").with("projectKey", sm.projectKey).with("modelId", sm.id).emit();
        }
        SavedModelsIntercomController.writeJSON(resp, (Object)info);
    }

    /**
     * Creates a new fine-tuned LLM version on a saved model and commits the transaction.
     *
     * @param req       incoming request, used for the write transaction and the auth context
     * @param projectKey project owning the saved model
     * @param smId      saved model identifier
     * @return the {@link FullModelId} returned by {@code createFinetunedLLMVersion}
     * @throws Exception if the privilege check, lookup, creation or commit fails
     */
    @AuditedCall(value={"msgType", "savedmodel-create-finetuned-llm-version", "projectKey", "${projectKey}", "modelId", "${smId}"})
    @RequestMapping(value={"/api/tintercom/savedmodels/llm-generic/create-finetuned-version"}, method={RequestMethod.POST})
    @ResponseBody
    public FullModelId llmGenericCreateFinetunedLLMVersion(HttpServletRequest req, @RequestParam String projectKey, @RequestParam String smId) throws Exception {
        // Resources close in reverse order: the auth context first, then the transaction.
        try (RWTransaction t = this.transactionService.beginWriteForAPI(req);
             AuthCtxUsage authCtxUsage = this.authService.getTicketOrKeyAndContext(req)) {
            this.permissionsService.checkProjectPrivileges(t.getUser(), projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            SavedModel sm = this.smcService.getMandatory(projectKey, smId);
            FullModelId fmi = this.savedModelsService.createFinetunedLLMVersion(authCtxUsage.getAuthCtx(), sm);
            t.commit("Fine-tuned LLM version " + fmi.getSavedModelVersionID() + " created for model " + projectKey + "." + smId);
            return fmi;
        }
    }

    /**
     * Saves a fine-tuned LLM version on a saved model and commits the transaction.
     *
     * @param req              incoming request, used for the write transaction and the auth context
     * @param projectKey       project owning the saved model
     * @param smId             saved model identifier
     * @param versionId        version being saved
     * @param connectionName   connection name forwarded to {@code saveFinetunedLLMVersion}
     * @param finetuningConfig optional fine-tuning configuration; the request parameter defaults to {@code "{}"}
     * @return the {@link FullModelId} returned by {@code saveFinetunedLLMVersion}
     * @throws Exception if the privilege check, lookup, save or commit fails
     */
    @AuditedCall(value={"msgType", "savedmodel-save-finetuned-llm-version", "projectKey", "${projectKey}", "modelId", "${smId}", "version", "${versionId}"})
    @RequestMapping(value={"/api/tintercom/savedmodels/llm-generic/save-finetuned-version"}, method={RequestMethod.POST})
    @ResponseBody
    public FullModelId llmGenericSaveFinetunedLLMVersion(HttpServletRequest req, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId, @RequestParam String connectionName, @RequestParam(required=false, defaultValue="{}") SavedModelsService.FinetuningConfig finetuningConfig) throws Exception {
        // Resources close in reverse order: the auth context first, then the transaction.
        try (RWTransaction t = this.transactionService.beginWriteForAPI(req);
             AuthCtxUsage authCtxUsage = this.authService.getTicketOrKeyAndContext(req)) {
            this.permissionsService.checkProjectPrivileges(t.getUser(), projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            SavedModel sm = this.smcService.getMandatory(projectKey, smId);
            FullModelId fmi = this.savedModelsService.saveFinetunedLLMVersion(authCtxUsage.getAuthCtx(), sm, versionId, connectionName, finetuningConfig);
            t.commit("Fine-tuned LLM version " + versionId + " saved in " + projectKey + "." + smId + " on connection " + connectionName);
            return fmi;
        }
    }

    /**
     * Deletes a fine-tuned LLM version from a saved model and commits the transaction.
     *
     * @param req       incoming request, used to open the write transaction
     * @param resp      unused here; part of the handler signature
     * @param projectKey project owning the saved model
     * @param smId      saved model identifier
     * @param versionId version to delete
     * @throws Exception if the privilege check, lookup, deletion or commit fails
     */
    @AuditedCall(value={"msgType", "savedmodel-delete-finetuned-llm-version", "projectKey", "${projectKey}", "modelId", "${smId}", "version", "${versionId}"})
    @RequestMapping(value={"/api/tintercom/savedmodels/llm-generic/delete-finetuned-version"}, method={RequestMethod.POST})
    public void llmGenericDeleteFinetunedLLMVersion(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String smId, @RequestParam String versionId) throws Exception {
        try (RWTransaction t = this.transactionService.beginWriteForAPI(req)) {
            this.permissionsService.checkProjectPrivileges(t.getUser(), projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            SavedModel sm = this.smcService.getMandatory(projectKey, smId);
            this.savedModelsService.deleteFinetunedLLMVersion(sm, versionId);
            t.commit("Fine-tuned LLM version " + versionId + " deleted for model " + projectKey + "." + smId);
        }
    }

    /**
     * JSON payload written by {@link #getModelDetails}. Field names are part of the serialized
     * output and must not be renamed.
     */
    public static class SMDetails {
        /** Absolute path of the model version's folder. */
        public String model_folder;
        /** The saved model the version belongs to. */
        public SavedModel saved_model;
        /** String form of the version's {@link FullModelId}. */
        public String fmi;
    }
}

