/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.llm;

import com.dataiku.dip.analysis.coreservices.flow.ISavedModelsCRUDService;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.SMMgmtService;
import com.dataiku.dip.analysis.ml.SMStatus;
import com.dataiku.dip.analysis.ml.SMVersionHeader;
import com.dataiku.dip.analysis.ml.llm.LLMModelReader;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.analysis.model.core.SavedModelOriginInfo;
import com.dataiku.dip.analysis.model.llm.LLMModelDetails;
import com.dataiku.dip.analysis.model.llm.LLMModelSnippetData;
import com.dataiku.dip.coremodel.VersionTag;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.llm.online.RemoteFineTuningClient;
import com.dataiku.dip.llm.savedmodels.SavedModelVersionDeploymentCRUDService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.notifications.DSSEvent;
import com.dataiku.dip.server.notifications.backend.ModelVersionDeletedEvent;
import com.dataiku.dip.server.services.PubSubService;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.StringTransmogrifier;
import com.dataiku.j2ts.annotations.UIModel;
import java.io.File;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class LLMSMMgmtService
extends SMMgmtService {
    @Autowired
    private ISavedModelsCRUDService crudService;
    @Autowired
    private SavedModelVersionDeploymentCRUDService smvDeploymentCRUDService;
    @Autowired
    private PubSubService pubSub;
    private static DKULogger logger = DKULogger.getLogger((String)"dku.models.llm");

    @Override
    public List<FullModelId> listUsableVersions(SavedModel sm) {
        return LLMSMMgmtService.staticListUsableVersions(sm);
    }

    private static List<FullModelId> staticListUsableVersions(SavedModel sm) {
        switch (sm.savedModelType.savedModelHandlingType) {
            case PYTHON_AGENT: 
            case PLUGIN_AGENT: 
            case TOOLS_USING_AGENT: 
            case RETRIEVAL_AUGMENTED_LLM: {
                return sm.inlineVersions.stream().map(v -> new FullModelId(sm.projectKey, sm.id, v.versionId)).toList();
            }
            case LLM_GENERIC: {
                return LLMSMMgmtService.listVersionFolders(sm).stream().map(f -> new FullModelId(sm.projectKey, sm.getId(), f.getName())).filter(FullModelId::isModelUsable).toList();
            }
            case INTERNAL: 
            case EXTERNAL_MLFLOW: {
                throw new IllegalArgumentException("unreachable");
            }
        }
        return Collections.emptyList();
    }

    public static LLMSMStatus getStatus_NT(SavedModel sm) {
        LLMSMStatus ret = new LLMSMStatus();
        ret.activeVersionId = sm.activeVersion;
        switch (sm.savedModelType.savedModelHandlingType) {
            case PYTHON_AGENT: 
            case PLUGIN_AGENT: 
            case TOOLS_USING_AGENT: 
            case RETRIEVAL_AUGMENTED_LLM: {
                for (SavedModel.SavedModelInlineVersion version : sm.inlineVersions) {
                    LLMSMVersionHeader vh = new LLMSMVersionHeader();
                    vh.versionId = version.versionId;
                    vh.active = version.versionId.equals(sm.activeVersion);
                    vh.snippet = new LLMModelSnippetData();
                    ((LLMModelSnippetData)vh.snippet).fullModelId = new FullModelId(sm.projectKey, sm.id, version.versionId).toString();
                    ((LLMModelSnippetData)vh.snippet).creationTag = version.creationTag;
                    ((LLMModelSnippetData)vh.snippet).versionTag = version.versionTag;
                    ((LLMModelSnippetData)vh.snippet).description = version.description;
                    ((LLMModelSnippetData)vh.snippet).trainInfo = new ModelTrainInfo();
                    ((LLMModelSnippetData)vh.snippet).trainInfo.state = ModelTrainInfo.ModelTrainState.DONE;
                    ((LLMModelSnippetData)vh.snippet).userMeta = new ModelUserMeta();
                    ((LLMModelSnippetData)vh.snippet).userMeta.name = version.versionId;
                    ((LLMModelSnippetData)vh.snippet).savedModelType = sm.savedModelType;
                    ret.versions.add(vh);
                }
                break;
            }
            case LLM_GENERIC: {
                for (FullModelId fmi : LLMSMMgmtService.staticListUsableVersions(sm)) {
                    try {
                        LLMSMVersionHeader vh = new LLMSMVersionHeader();
                        vh.versionId = fmi.getSavedModelVersionID();
                        vh.active = vh.versionId.equals(ret.activeVersionId);
                        vh.snippet = LLMModelReader.makeSnippet(fmi);
                        ((LLMModelSnippetData)vh.snippet).sessionDate = ((LLMModelSnippetData)vh.snippet).trainInfo == null ? 0L : ((LLMModelSnippetData)vh.snippet).trainInfo.endTime;
                        ((LLMModelSnippetData)vh.snippet).savedModelType = SavedModel.SavedModelType.LLM_GENERIC;
                        ret.versions.add(vh);
                    }
                    catch (Exception e) {
                        logger.warn((Object)"Failed to create model snippet", (Throwable)e);
                    }
                }
                break;
            }
            case INTERNAL: 
            case EXTERNAL_MLFLOW: {
                throw new IllegalArgumentException("unreachable");
            }
        }
        return ret;
    }

    public void setLLMGenericVersionActive(SavedModel sm, String versionId) throws Exception {
        boolean changed;
        boolean bl = changed = !StringUtils.equals((String)sm.activeVersion, (String)versionId);
        if (changed) {
            FullModelId newFMI = new FullModelId(sm.projectKey, sm.id, versionId);
            if (!newFMI.isModelUsable()) {
                throw ErrorContext.iaef((String)"The new active version to set (%s) is not valid", (Object)versionId, (Object[])new Object[0]);
            }
            sm.activeVersion = versionId;
            this.crudService.save(sm, false, false);
        }
    }

    public void deleteAllLLMGenericModelVersions(AuthCtx authCtx, SavedModel sm) throws Exception {
        assert (sm.savedModelType == SavedModel.SavedModelType.LLM_GENERIC);
        for (FullModelId version : this.listUsableVersions(sm)) {
            this.deleteLLMGenericModelVersion(authCtx, sm, version.getSavedModelVersionID(), true);
        }
    }

    public void deleteLLMGenericModelVersion(AuthCtx authCtx, SavedModel sm, String version, Boolean allowDeleteActive) throws Exception {
        assert (sm.savedModelType == SavedModel.SavedModelType.LLM_GENERIC);
        if (StringUtils.equals((String)sm.activeVersion, (String)version) && !allowDeleteActive.booleanValue()) {
            throw new IllegalArgumentException("Cannot delete the active version");
        }
        File versionFolder = MLPaths.savedModelVersionFolder(sm, version);
        if (!versionFolder.isDirectory()) {
            throw new IllegalArgumentException("Can't find model version: " + version);
        }
        FullModelId fmi = new FullModelId(sm.projectKey, sm.id, version);
        LLMModelDetails llmModelDetails = LLMModelDetails.fromFullModelId(fmi, false);
        if (llmModelDetails.deployment != null && llmModelDetails.deployment.deploymentId != null) {
            this.smvDeploymentCRUDService.deleteDeployment(authCtx, sm.projectKey, sm, fmi);
        }
        if (llmModelDetails.llmSMInfo.remoteModelId != null && llmModelDetails.llmSMInfo.remoteJobId != null) {
            LLMStructuredRef llmRef = LLMStructuredRef.forFinetunedSavedModelVersion(llmModelDetails.llmSMInfo.llmType, llmModelDetails.llmSMInfo.connection, sm.id, version);
            LLMClient client = LLMClientFactory.get(authCtx, sm.projectKey, llmRef);
            RemoteFineTuningClient ftClient = client.newFineTuningClient();
            ftClient.deleteFinetunedModel(llmModelDetails.llmSMInfo.remoteModelId, llmModelDetails.llmSMInfo.remoteJobId);
        }
        this.pubSub.publishAfterTransaction((DSSEvent)new ModelVersionDeletedEvent(new FullModelId(sm.getProjectKey(), sm.getId(), version).toString()));
        DKUFileUtils.forceDelete((File)versionFolder);
    }

    public void setAgentVersionActive(SavedModel sm, String versionId) throws Exception {
        boolean changed;
        boolean bl = changed = sm.activeVersion == null || !StringUtils.equals((String)sm.activeVersion, (String)versionId);
        if (changed) {
            if (StringUtils.isBlank((String)versionId)) {
                throw ErrorContext.iaef((String)"The new active version to set (%s) is not valid", (Object)versionId, (Object[])new Object[0]);
            }
            if (sm.getVersion(versionId).isEmpty()) {
                throw new IllegalArgumentException("Can't find model version: " + versionId);
            }
            sm.activeVersion = versionId;
            this.crudService.save(sm, false, false);
        }
    }

    public void deleteAgentVersion(SavedModel sm, String versionId) throws Exception {
        if (StringUtils.equals((String)sm.activeVersion, (String)versionId)) {
            throw new IllegalArgumentException("Cannot delete the active version");
        }
        int nbVersions = sm.inlineVersions.size();
        List filteredVersions = sm.inlineVersions.stream().filter(version -> versionId != null && !versionId.equals(version.versionId)).collect(Collectors.toList());
        if (filteredVersions.size() == nbVersions) {
            throw new IllegalArgumentException("Can't find model version: " + versionId);
        }
        sm.inlineVersions = filteredVersions;
        this.crudService.save(sm, false, false);
        this.pubSub.publishAfterTransaction((DSSEvent)new ModelVersionDeletedEvent(new FullModelId(sm.getProjectKey(), sm.getId(), versionId).toString()));
    }

    public FullModelId duplicateAgentVersion(AuthCtx authCtx, SavedModel sm, String versionIdToCopy, String newVersionId) throws Exception {
        SavedModel.SavedModelInlineVersion targetVersion = sm.getVersion(versionIdToCopy).orElseThrow(() -> new IllegalArgumentException("Can't find model version: " + versionIdToCopy));
        StringTransmogrifier transmogrifier = new StringTransmogrifier("_");
        List usedIds = sm.inlineVersions.stream().map(version -> version.versionId).collect(Collectors.toList());
        transmogrifier.addAllAlreadyTransmogrifiedAcceptDupes(usedIds);
        newVersionId = transmogrifier.transmogrify(newVersionId);
        SavedModel.SavedModelInlineVersion duplicatedVersion = new SavedModel.SavedModelInlineVersion(targetVersion);
        duplicatedVersion.versionId = newVersionId;
        duplicatedVersion.versionTag = duplicatedVersion.creationTag = new VersionTag(authCtx.getIdentifier());
        sm.inlineVersions.add(duplicatedVersion);
        this.crudService.save(sm, true, false);
        return new FullModelId(sm.projectKey, sm.id, newVersionId);
    }

    @UIModel
    public static class LLMSMStatus
    extends SMStatus<LLMSMVersionHeader> {
    }

    public static class LLMSMVersionHeader
    extends SMVersionHeader<LLMModelSnippetData> {
        public SavedModelOriginInfo smOrigin;
    }
}

