/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.dao;

import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.code.CodeEnvSelection;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.coremodel.ConditionalOutput;
import com.dataiku.dip.coremodel.Partitionable;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.coremodel.VersionTag;
import com.dataiku.dip.dao.SavedModelHandler;
import com.dataiku.dip.datasets.PartitionableHandler;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.retrieval.RAGLLMSettings;
import com.dataiku.dip.metrics.ChecksSet;
import com.dataiku.dip.metrics.ProbesSet;
import com.dataiku.dip.partitioning.PartitioningScheme;
import com.dataiku.dip.projects.importexport.RetrievalAugmentedLLMConnectionsUtils;
import com.dataiku.dip.savedmodels.agents.AgentTypesRegistry;
import com.dataiku.dip.savedmodels.agents.CustomAgentDependenciesCollector;
import com.dataiku.dip.savedmodels.agents.CustomAgentMeta;
import com.dataiku.dip.savedmodels.proxymodels.ProxyModelConfiguration;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.services.ITaggingService;
import com.dataiku.dip.server.services.TaggableObjectsService;
import com.dataiku.dip.util.JsonUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.j2ts.annotations.UIModel;
import com.google.gson.JsonObject;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;

/**
 * Definition of a saved model: its kind ({@link SavedModelType}), inline versions, partitioning,
 * metrics/checks, flow options and training counters.
 *
 * <p>{@link #projectKey} and {@link #id} are annotated {@code @JSON.FileTransient}; presumably they
 * are not written into the persisted settings and are restored by the loader — TODO confirm.
 */
@UIModel
public class SavedModel
extends TaggableObjectsService.TaggableObject
implements Partitionable {
    public boolean needsInputDataFolder = false;
    @JSON.FileTransient
    public String projectKey;
    @JSON.FileTransient
    public String id;
    public SavedModelType savedModelType = SavedModelType.DSS_MANAGED;
    public ModelPublishPolicy publishPolicy = ModelPublishPolicy.UNCONDITIONAL;
    public SerializedDataset.RebuildBehavior rebuildBehavior = SerializedDataset.RebuildBehavior.EXPLICIT;
    public boolean cleanTemporaryVersionsPostJob;
    public String name;
    public String contentType;
    /** Content type string for fine-tuned LLMs; presumably compared against {@link #contentType} by callers — confirm. */
    public static String llmFinetuningContentType = "llm/fine-tuning";
    /** Id of the active version, matched against {@code versionId} in {@link #inlineVersions}; may be null (no default). */
    public String activeVersion;
    @Nullable
    public ProxyModelConfiguration proxyModelConfiguration;
    public List<SavedModelInlineVersion> inlineVersions = new ArrayList<SavedModelInlineVersion>();
    public List<ConditionalOutput> conditionalOutputs = new ArrayList<ConditionalOutput>();
    public PartitioningScheme partitioning = new PartitioningScheme();
    public SerializedDataset.FlowOptions flowOptions = new SerializedDataset.FlowOptions();
    public String lastExportedFrom;
    public ProbesSet metrics = new ProbesSet();
    public ChecksSet metricsChecks = new ChecksSet();
    /** ML task summary; its task type drives {@link #getType()} for INTERNAL models. May be null (list items null-check it). */
    public MLTask miniTask;
    /** Training counter, bumped once per distinct job by {@link #incrementLastTrainIndex(String)}. */
    public long lastTrainIndex = 1L;
    /** Id of the job that last bumped {@link #lastTrainIndex}. */
    public String lastTrainJobId = "";

    public SavedModel() {
    }

    public SavedModel(String projectKey, String id) {
        this.projectKey = projectKey;
        this.id = id;
    }

    @Override
    public ITaggingService.TaggableType getTaggableType() {
        return ITaggingService.TaggableType.SAVED_MODEL;
    }

    /** Returns the ML task type (see {@link #getType()}) as a string. */
    @Override
    public String getSubtype() {
        return this.getType().toString();
    }

    @Override
    public PartitioningScheme getPartitioningSchema() {
        return this.partitioning;
    }

    @Override
    public SerializedDataset.FlowOptions getFlowOptions() {
        return this.flowOptions;
    }

    @Override
    public String getProjectKey() {
        return this.projectKey;
    }

    /** Returns the model id (not the display name; see {@link #getDisplayName()}). */
    @Override
    public String getName() {
        return this.id;
    }

    @Override
    public String getFullName() {
        return this.getFullId();
    }

    @Override
    public PartitionableHandler buildHandler(AuthCtx authCtx) {
        return new SavedModelHandler(this);
    }

    @Override
    public String getId() {
        return this.id;
    }

    public String getActiveVersion() {
        return this.activeVersion;
    }

    /**
     * Returns the inline version whose id equals {@link #activeVersion}.
     *
     * @return the matching version, or {@code null} when no version is active or none matches
     */
    public SavedModelInlineVersion getActiveSaveModelInlineVersion() {
        // activeVersion has no default value; dereferencing it unguarded would throw NPE.
        if (this.activeVersion == null) {
            return null;
        }
        return this.getVersion(this.activeVersion).orElse(null);
    }

    @Override
    public String getDisplayName() {
        return this.name;
    }

    @Override
    public void setProjectKey(String projectKey) {
        this.projectKey = projectKey;
    }

    @Override
    public void setId(String id) {
        this.id = id;
    }

    /**
     * Returns the ML task type of this saved model.
     *
     * <p>INTERNAL models report their mini task's type; MLflow/proxy models are always
     * PREDICTION; agents and LLM-based models are LLM_GENERIC_RAW.
     *
     * @throws NullPointerException if {@link #savedModelType} is null (also logged as a warning)
     */
    public MLTask.MLTaskType getType() {
        if (this.savedModelType == null) {
            Logger.getLogger("dku").warn("Bad SM type in SM " + this.id);
            throw new NullPointerException("Bad SM type in SM " + this.id);
        }
        switch (this.savedModelType.savedModelHandlingType) {
            case INTERNAL:
                return this.miniTask.taskType;
            case EXTERNAL_MLFLOW:
                return MLTask.MLTaskType.PREDICTION;
            case PYTHON_AGENT:
            case PLUGIN_AGENT:
            case TOOLS_USING_AGENT:
            case LLM_GENERIC:
            case RETRIEVAL_AUGMENTED_LLM:
                return MLTask.MLTaskType.LLM_GENERIC_RAW;
            default:
                // Only reachable if a new handling type is added without updating this switch.
                throw new IllegalStateException("Unhandled saved model handling type "
                        + this.savedModelType.savedModelHandlingType + " in SM " + this.id);
        }
    }

    /** Returns whether the partitioning scheme declares at least one dimension. */
    public boolean isPartitioned() {
        return this.partitioning != null && !this.partitioning.getDimensionNames().isEmpty();
    }

    /**
     * Increments {@link #lastTrainIndex} at most once per job: nothing happens when
     * {@code jobId} equals the job that performed the previous increment.
     */
    public void incrementLastTrainIndex(String jobId) {
        if (!StringUtils.equals(this.lastTrainJobId, jobId)) {
            ++this.lastTrainIndex;
            this.lastTrainJobId = jobId;
        }
    }

    /** Returns the suffix {@code " - v<lastTrainIndex + 1>"}. */
    public String getNextVersionSuffix() {
        return " - v" + (this.lastTrainIndex + 1L);
    }

    public String getInitialVersionSuffix() {
        return " - v1";
    }

    /**
     * Returns the proxy model protocol, or {@code null} when there is no proxy configuration
     * or no protocol is set. Used to avoid an NPE from switching on a null String.
     */
    @Nullable
    private String proxyModelProtocolOrNull() {
        return this.proxyModelConfiguration == null ? null : this.proxyModelConfiguration.protocol;
    }

    /**
     * Returns a category key for this model. Specific saved-model kinds and known proxy protocols
     * map directly; otherwise the key is derived from the ML task. The fallback is "prediction".
     */
    public String getMLCategory() {
        switch (this.savedModelType) {
            case MLFLOW_PYFUNC:
                return "mlflow";
            case LLM_GENERIC:
                return "fine_tuning";
            case PYTHON_AGENT:
                return "python_agent";
            case PLUGIN_AGENT:
                return "plugin_agent";
            case TOOLS_USING_AGENT:
                return "tools_using_agent";
            case RETRIEVAL_AUGMENTED_LLM:
                return "retrieval_augmented_llm";
            case PROXY_MODEL: {
                String protocol = this.proxyModelProtocolOrNull();
                if (protocol == null) {
                    break;
                }
                switch (protocol) {
                    case "sagemaker":
                        return "sagemaker";
                    case "vertex-ai":
                        return "vertex";
                    case "azure-ml":
                        return "azure-ml";
                    case "databricks":
                        return "databricks";
                    default:
                        // Unknown protocol: fall back to the task-based category below.
                        break;
                }
                break;
            }
            default:
                break;
        }
        switch (this.getType()) {
            case CLUSTERING:
                return "clustering";
            case PREDICTION: {
                if (this.miniTask.backendType == MLTask.BackendType.KERAS) {
                    return "keras";
                }
                switch (((PredictionMLTask) this.miniTask).predictionType) {
                    case TIMESERIES_FORECAST:
                        return "timeseries";
                    case CAUSAL_REGRESSION:
                    case CAUSAL_BINARY_CLASSIFICATION:
                        return "causal";
                    case DEEP_HUB_IMAGE_CLASSIFICATION:
                        return "deephub_image_classification";
                    case DEEP_HUB_IMAGE_OBJECT_DETECTION:
                        return "deephub_object_detection";
                    default:
                        break;
                }
                break;
            }
            default:
                break;
        }
        return "prediction";
    }

    /**
     * Returns the relative path of the flow icon for this model. Resolution follows the same
     * order as {@link #getMLCategory()}; the fallback is the regression icon.
     */
    public String getObjectIconPath() {
        switch (this.savedModelType) {
            case MLFLOW_PYFUNC:
                return "ml/dku-icon-flow-colored-model-mlflow.png";
            case LLM_GENERIC:
                return "llm/dku-icon-flow-colored-finetuned-savedmodel.png";
            case PYTHON_AGENT:
                return "llm/dku-icon-flow-colored-agent-code.png";
            case PLUGIN_AGENT:
                return "llm/dku-icon-flow-colored-agent-plugin.png";
            case TOOLS_USING_AGENT:
                return "llm/dku-icon-flow-colored-agent-visual.png";
            case RETRIEVAL_AUGMENTED_LLM:
                return "llm/dku-icon-flow-colored-retrieval-augmented.png";
            case PROXY_MODEL: {
                String protocol = this.proxyModelProtocolOrNull();
                if (protocol == null) {
                    break;
                }
                switch (protocol) {
                    case "sagemaker":
                        return "ml/dku-icon-flow-colored-model-external-sagemaker.png";
                    case "vertex-ai":
                        return "ml/dku-icon-flow-colored-model-external-vertex.png";
                    case "azure-ml":
                        return "ml/dku-icon-flow-colored-model-external-azure-ml.png";
                    case "databricks":
                        return "ml/dku-icon-flow-colored-model-external-databricks.png";
                    default:
                        break;
                }
                break;
            }
            default:
                break;
        }
        switch (this.getType()) {
            case CLUSTERING:
                return "ml/dku-icon-flow-colored-machine-learning-clustering.png";
            case PREDICTION: {
                if (this.miniTask.backendType == MLTask.BackendType.KERAS) {
                    return "ml/dku-icon-flow-colored-machine-learning-deep-learning.png";
                }
                switch (((PredictionMLTask) this.miniTask).predictionType) {
                    case TIMESERIES_FORECAST:
                        return "ml/dku-icon-flow-colored-machine-learning-timeseries.png";
                    case CAUSAL_REGRESSION:
                    case CAUSAL_BINARY_CLASSIFICATION:
                        return "ml/dku-icon-flow-colored-machine-learning-causal.png";
                    case DEEP_HUB_IMAGE_CLASSIFICATION:
                    case DEEP_HUB_IMAGE_OBJECT_DETECTION:
                        return "ml/dku-icon-flow-colored-machine-computer-vision.png";
                    default:
                        break;
                }
                break;
            }
            default:
                break;
        }
        return "ml/dku-icon-flow-colored-machine-learning-regression.png";
    }

    /**
     * Returns the agent settings of {@code smiv} matching this model's agent type.
     *
     * @throws IllegalStateException if this saved model is not an agent
     */
    public AgentSettings getAgentSettings(SavedModelInlineVersion smiv) {
        switch (this.savedModelType) {
            case PYTHON_AGENT:
                return smiv.pythonAgentSettings;
            case PLUGIN_AGENT:
                return smiv.pluginAgentSettings;
            case TOOLS_USING_AGENT:
                return smiv.toolsUsingAgentSettings;
            default:
                throw new IllegalStateException("getAgentSettings should only be used with agents, used with "
                        + this.savedModelType + " instead.");
        }
    }

    /**
     * Decodes {@code llmId} and adds its connection to {@code connections}. Null or empty ids,
     * ids that decode to {@code null}, and references without a connection are ignored.
     */
    private static void addConnectionOfLlmId(Set<String> connections, @Nullable String llmId) {
        if (llmId == null || llmId.isEmpty()) {
            return;
        }
        LLMStructuredRef ref = LLMStructuredRef.decodeId(llmId);
        if (ref != null && ref.connection != null) {
            connections.add(ref.connection);
        }
    }

    /**
     * Collects the LLM connection names referenced by all inline versions of this model.
     * Only tools-using agents, plugin agents and retrieval-augmented LLMs reference connections;
     * other kinds yield an empty set.
     */
    public Set<String> getLlmConnections() {
        Set<String> llmConnections = new HashSet<>();
        switch (this.savedModelType) {
            case TOOLS_USING_AGENT:
                for (SavedModelInlineVersion smiv : this.inlineVersions) {
                    ToolsUsingAgentSettings agentSettings = (ToolsUsingAgentSettings) this.getAgentSettings(smiv);
                    addConnectionOfLlmId(llmConnections, agentSettings.llmId);
                }
                break;
            case PLUGIN_AGENT:
                for (SavedModelInlineVersion smiv : this.inlineVersions) {
                    CustomAgentMeta meta = AgentTypesRegistry.getMeta(smiv.pluginAgentType);
                    for (LLMStructuredRef llmRef : CustomAgentDependenciesCollector.collectLlmDependencies(meta.getAgentDesc().params, smiv.pluginAgentConfig).values()) {
                        if (llmRef != null && llmRef.connection != null) {
                            llmConnections.add(llmRef.connection);
                        }
                    }
                }
                break;
            case RETRIEVAL_AUGMENTED_LLM:
                for (SavedModelInlineVersion smiv : this.inlineVersions) {
                    RAGLLMSettings settings = smiv.ragllmSettings;
                    if (settings.llmId == null) {
                        // No main LLM: guardrail LLMs of this version are not collected either.
                        continue;
                    }
                    addConnectionOfLlmId(llmConnections, settings.llmId);
                    if (!settings.hasGuardrailsEnabled()) {
                        continue;
                    }
                    // Each guardrail ref is checked against its OWN connection (previously the main
                    // LLM's connection was tested, which could add null or miss connections).
                    addConnectionOfLlmId(llmConnections, settings.ragSpecificGuardrails.embeddingModelId);
                    addConnectionOfLlmId(llmConnections, settings.ragSpecificGuardrails.llmId);
                }
                break;
            default:
                break;
        }
        return llmConnections;
    }

    /**
     * Rewrites LLM references in every inline version so that connections present as keys in
     * {@code replacements} point to the mapped connection. References whose connection has no
     * (non-null) replacement are left untouched.
     */
    public void remapLLmConnections(Map<String, String> replacements) {
        switch (this.savedModelType) {
            case TOOLS_USING_AGENT:
                for (SavedModelInlineVersion smiv : this.inlineVersions) {
                    ToolsUsingAgentSettings agentSettings = (ToolsUsingAgentSettings) this.getAgentSettings(smiv);
                    if (agentSettings.llmId == null || agentSettings.llmId.isEmpty()) {
                        continue;
                    }
                    LLMStructuredRef oldRef = LLMStructuredRef.decodeId(agentSettings.llmId);
                    // decodeId may return null (see getLlmConnections); guard before dereferencing.
                    if (oldRef == null || oldRef.connection == null) {
                        continue;
                    }
                    String newConnection = replacements.get(oldRef.connection);
                    if (newConnection == null) {
                        continue;
                    }
                    agentSettings.llmId = oldRef.withOtherConnection(newConnection).encodeToId();
                }
                break;
            case PLUGIN_AGENT:
                for (SavedModelInlineVersion smiv : this.inlineVersions) {
                    CustomAgentMeta meta = AgentTypesRegistry.getMeta(smiv.pluginAgentType);
                    for (Map.Entry<String, LLMStructuredRef> entry : CustomAgentDependenciesCollector.collectLlmDependencies(meta.getAgentDesc().params, smiv.pluginAgentConfig).entrySet()) {
                        LLMStructuredRef llmRef = entry.getValue();
                        if (llmRef == null || llmRef.connection == null) {
                            continue;
                        }
                        String newConnection = replacements.get(llmRef.connection);
                        if (newConnection == null) {
                            continue;
                        }
                        LLMStructuredRef newLlmRef = llmRef.withOtherConnection(newConnection);
                        JsonUtils.addFieldIfNotNull(smiv.pluginAgentConfig, entry.getKey(), newLlmRef.encodeToId());
                    }
                }
                break;
            case RETRIEVAL_AUGMENTED_LLM:
                for (SavedModelInlineVersion smiv : this.inlineVersions) {
                    RetrievalAugmentedLLMConnectionsUtils.remapConnections(smiv.ragllmSettings, replacements);
                }
                break;
            default:
                break;
        }
    }

    /**
     * Returns the first inline version whose {@code versionId} equals {@code versionId}, or an
     * empty Optional when none matches or {@code versionId} is null.
     */
    public Optional<SavedModelInlineVersion> getVersion(String versionId) {
        if (versionId == null) {
            return Optional.empty();
        }
        return this.inlineVersions.stream().filter(version -> versionId.equals(version.versionId)).findFirst();
    }

    /** Kind of saved model; each kind maps to a {@link SavedModelHandlingType}. */
    @UIModel
    public static enum SavedModelType {
        DSS_MANAGED(SavedModelHandlingType.INTERNAL),
        MLFLOW_PYFUNC(SavedModelHandlingType.EXTERNAL_MLFLOW),
        PROXY_MODEL(SavedModelHandlingType.EXTERNAL_MLFLOW),
        PYTHON_AGENT(SavedModelHandlingType.PYTHON_AGENT),
        PLUGIN_AGENT(SavedModelHandlingType.PLUGIN_AGENT),
        TOOLS_USING_AGENT(SavedModelHandlingType.TOOLS_USING_AGENT),
        LLM_GENERIC(SavedModelHandlingType.LLM_GENERIC),
        RETRIEVAL_AUGMENTED_LLM(SavedModelHandlingType.RETRIEVAL_AUGMENTED_LLM);

        public final SavedModelHandlingType savedModelHandlingType;

        private SavedModelType(SavedModelHandlingType savedModelHandlingType) {
            this.savedModelHandlingType = savedModelHandlingType;
        }

        /** True for the three agent kinds (python, plugin, tools-using). */
        public boolean isAgent() {
            return this == PYTHON_AGENT || this == PLUGIN_AGENT || this == TOOLS_USING_AGENT;
        }

        public boolean isRetrievalAugmentedLlm() {
            return this == RETRIEVAL_AUGMENTED_LLM;
        }

        /** True for DSS-managed, MLflow and proxy models. */
        public boolean isML() {
            return this == DSS_MANAGED || this == MLFLOW_PYFUNC || this == PROXY_MODEL;
        }
    }

    public static enum ModelPublishPolicy {
        MANUAL,
        UNCONDITIONAL;

    }

    /** One version of a saved model, with per-kind settings (agent, RAG, guardrails). */
    public static class SavedModelInlineVersion {
        public String versionId;
        public String description;
        @Nullable
        public VersionTag versionTag;
        @Nullable
        public VersionTag creationTag;
        @JSON.FileTransient
        @Nullable
        public String code;
        public PythonAgentSettings pythonAgentSettings = new PythonAgentSettings();
        @Nullable
        public String pluginAgentType;
        @Nullable
        public JsonObject pluginAgentConfig;
        public PluginAgentSettings pluginAgentSettings = new PluginAgentSettings();
        public ToolsUsingAgentSettings toolsUsingAgentSettings = new ToolsUsingAgentSettings();
        public RAGLLMSettings ragllmSettings = new RAGLLMSettings();
        public JsonObject quickTestQuery;
        public GuardrailsPipelineSettings guardrailsPipelineSettings = new GuardrailsPipelineSettings();

        public SavedModelInlineVersion() {
        }

        /** Deep-copy constructor: mutable members are duplicated rather than shared with {@code other}. */
        public SavedModelInlineVersion(SavedModelInlineVersion other) {
            this.versionId = other.versionId;
            this.description = other.description;
            this.code = other.code;
            // Round-trip through JSON to obtain independent copies of the tags.
            this.versionTag = (VersionTag) JSON.parse(JSON.json((Object) other.versionTag), VersionTag.class);
            this.creationTag = (VersionTag) JSON.parse(JSON.json((Object) other.creationTag), VersionTag.class);
            this.pythonAgentSettings = new PythonAgentSettings(other.pythonAgentSettings);
            this.pluginAgentType = other.pluginAgentType;
            if (other.pluginAgentConfig != null) {
                this.pluginAgentConfig = other.pluginAgentConfig.deepCopy();
            }
            this.pluginAgentSettings = new PluginAgentSettings(other.pluginAgentSettings);
            this.toolsUsingAgentSettings = new ToolsUsingAgentSettings(other.toolsUsingAgentSettings);
            // Previously omitted: the copy silently got default RAG settings instead of other's.
            if (other.ragllmSettings != null) {
                this.ragllmSettings = (RAGLLMSettings) JSON.deepCopy((Object) other.ragllmSettings);
            }
            if (other.quickTestQuery != null) {
                this.quickTestQuery = other.quickTestQuery.deepCopy();
            }
            this.guardrailsPipelineSettings = new GuardrailsPipelineSettings(other.guardrailsPipelineSettings);
        }
    }

    public static enum SavedModelHandlingType {
        INTERNAL,
        EXTERNAL_MLFLOW,
        PYTHON_AGENT,
        PLUGIN_AGENT,
        TOOLS_USING_AGENT,
        LLM_GENERIC,
        RETRIEVAL_AUGMENTED_LLM;

    }

    /** Settings specific to Python-code agents. */
    public static class PythonAgentSettings
    extends AgentSettings {
        public CodeEnvSelection codeEnvSelection = new CodeEnvSelection();
        public List<AgentDependency> dependencies = new ArrayList<AgentDependency>();
        public Boolean supportsImageInputs = false;

        public PythonAgentSettings() {
        }

        /** Copy constructor; dependencies are copied element by element. */
        public PythonAgentSettings(PythonAgentSettings other) {
            super(other);
            this.codeEnvSelection = new CodeEnvSelection(other.codeEnvSelection);
            for (AgentDependency otherDependency : other.dependencies) {
                this.dependencies.add(new AgentDependency(otherDependency));
            }
            this.supportsImageInputs = other.supportsImageInputs;
        }
    }

    /** Settings for plugin agents; currently only the common {@link AgentSettings}. */
    public static class PluginAgentSettings
    extends AgentSettings {
        public PluginAgentSettings() {
        }

        public PluginAgentSettings(PluginAgentSettings other) {
            super(other);
        }
    }

    /** Settings for visual (tools-using) agents: tools, LLM id, prompt append and completion settings. */
    public static class ToolsUsingAgentSettings
    extends AgentSettings {
        public List<UsedTool> tools = new ArrayList<UsedTool>();
        public String llmId;
        public String systemPromptAppend;
        public LLMClient.CompletionSettings completionSettings = new LLMClient.CompletionSettings();

        public ToolsUsingAgentSettings() {
        }

        /** Copy constructor; tools are copied element by element, completion settings via JSON deep copy. */
        public ToolsUsingAgentSettings(ToolsUsingAgentSettings other) {
            super(other);
            for (UsedTool otherTool : other.tools) {
                this.tools.add(new UsedTool(otherTool));
            }
            this.llmId = other.llmId;
            this.systemPromptAppend = other.systemPromptAppend;
            this.completionSettings = (LLMClient.CompletionSettings) JSON.deepCopy((Object) other.completionSettings);
        }
    }

    /** Settings shared by all agent kinds. */
    public static class AgentSettings {
        public List<AbstractSQLConnection.CustomDatabaseProperty> dkuProperties = new ArrayList<AbstractSQLConnection.CustomDatabaseProperty>();
        public int maxParallelRequestsPerProcess = 4;
        public ContainerExecSelection containerExecSelection = new ContainerExecSelection(ContainerExecSelection.ContainerExecMode.NONE);

        public AgentSettings() {
        }

        @SuppressWarnings("unchecked") // JSON.deepCopy yields a copy of the same list type it is given
        public AgentSettings(AgentSettings other) {
            this.dkuProperties = (List<AbstractSQLConnection.CustomDatabaseProperty>) JSON.deepCopy(other.dkuProperties);
            this.maxParallelRequestsPerProcess = other.maxParallelRequestsPerProcess;
            this.containerExecSelection = new ContainerExecSelection(other.containerExecSelection);
        }
    }

    /** A tool referenced by a tools-using agent, with optional extra description. */
    public static class UsedTool {
        public String toolRef;
        public String additionalDescription;

        public UsedTool() {
        }

        public UsedTool(UsedTool other) {
            this.toolRef = other.toolRef;
            this.additionalDescription = other.additionalDescription;
        }
    }

    /** A typed reference to another object that a Python agent depends on. */
    public static class AgentDependency {
        public ITaggingService.TaggableType type;
        public String ref;

        public AgentDependency() {
        }

        public AgentDependency(ITaggingService.TaggableType type, String ref) {
            this.type = type;
            this.ref = ref;
        }

        public AgentDependency(AgentDependency other) {
            this.type = other.type;
            this.ref = other.ref;
        }
    }

    /** Lightweight summary of a {@link SavedModel} for listings. */
    public static class SavedModelListItem
    extends TaggableObjectsService.TaggableListItem {
        public final MLTask.MLTaskType type;
        public final MLTask.BackendType backendType;
        @Nullable
        public PredictionMLTask.PredictionType predictionType;
        public SavedModelType savedModelType;
        public String pluginAgentType;
        @Nullable
        public ProxyModelConfiguration proxyModelConfiguration;
        public long versionsCount = 0L;

        public SavedModelListItem(SavedModel model) {
            super(model);
            this.type = model.getType();
            this.backendType = model.miniTask != null ? model.miniTask.backendType : null;
            this.savedModelType = model.savedModelType;
            // A plugin agent's type is taken from its first inline version.
            if (this.savedModelType == SavedModelType.PLUGIN_AGENT && !model.inlineVersions.isEmpty()) {
                this.pluginAgentType = model.inlineVersions.get(0).pluginAgentType;
            }
            this.proxyModelConfiguration = model.proxyModelConfiguration;
            // miniTask may be null (see backendType above); only read the prediction type when the
            // task is actually a PredictionMLTask instead of risking an NPE/ClassCastException.
            if (this.type == MLTask.MLTaskType.PREDICTION && model.miniTask instanceof PredictionMLTask) {
                this.predictionType = ((PredictionMLTask) model.miniTask).predictionType;
            }
        }

        @Override
        public ITaggingService.TaggableType getTaggableType() {
            return ITaggingService.TaggableType.SAVED_MODEL;
        }
    }
}

