/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.flow;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.analysis.coreservices.flow.ISavedModelsCRUDService;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.SMMgmtService;
import com.dataiku.dip.analysis.ml.SMStatus;
import com.dataiku.dip.analysis.ml.SMVersionHeader;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionSavedModelStateService;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.shared.ResultsReaderBase;
import com.dataiku.dip.analysis.model.CompatibilityWithReason;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.ModelDetailsBase;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.ResolvedCoreParams;
import com.dataiku.dip.analysis.model.core.SavedModelOriginInfo;
import com.dataiku.dip.analysis.model.prediction.CausalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.ClassicalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.DeepHubPreTrainModelingParams;
import com.dataiku.dip.analysis.model.prediction.DeepHubPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.PartitionedModelExtract;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.PredictionModelSnippetData;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.TimeseriesForecastingModelDetails;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.datalayer.utils.SchemaComparator;
import com.dataiku.dip.externalml.mlflow.MLFlowModelVersionInfo;
import com.dataiku.dip.partitioning.PartitioningUtils;
import com.dataiku.dip.partitioning.StratifiedModelUtils;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.server.notifications.DSSEvent;
import com.dataiku.dip.server.notifications.backend.ModelVersionDeletedEvent;
import com.dataiku.dip.server.notifications.backend.TaggableObjectChangedEvent;
import com.dataiku.dip.server.services.ITaggingService;
import com.dataiku.dip.server.services.PubSubService;
import com.dataiku.dip.server.services.RecipeSchemaChangeService;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.transactions.ifaces.RWTransactionRef;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.Pair;
import com.dataiku.dip.utils.PathUtils;
import com.dataiku.dss.shadelib.org.apache.commons.io.FileUtils;
import com.dataiku.j2ts.annotations.UIModel;
import com.dataiku.scoring.builders.Build;
import com.dataiku.scoring.builders.DimensionType;
import com.google.common.collect.HashMultimap;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.io.InvalidClassException;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class PredictionSMMgmtService
extends SMMgmtService {
    // Saved-model CRUD service; used here to add the deployed-version suffix and update the version label.
    @Autowired
    private ISavedModelsCRUDService crudService;
    // Event bus; used here to publish a ModelVersionDeletedEvent after a version is deleted.
    @Autowired
    private PubSubService pubSub;
    // Computes saved model state snapshots (unknown / being trained / last modification time), used to decide
    // whether cleaning unreferenced partitioned models is safe.
    @Autowired
    private PredictionSavedModelStateService smStateService;
    // NOTE(review): not referenced in the visible part of this file; presumably used further down.
    @Autowired
    private RecipeSchemaChangeService recipeSchemaChangeService;
    // File name of the DSS pipeline metadata; NOTE(review): not referenced in the visible part of this file.
    private static final String DSS_PIPELINE_META = "dss_pipeline_meta.json";
    private static final Logger logger = Logger.getLogger((String)"dku.ml.savedmodels");

    /**
     * Best-effort removal of the partition model versions of a partitioned saved model that no saved model
     * version references anymore.
     *
     * <p>Does nothing for non-partitioned models. The saved model state is checked before listing the
     * candidates and again right before deleting them. Cleaning is skipped if the state is unknown, if the
     * model is being trained, or if it was modified in between. Individual deletion failures are logged, not
     * thrown.
     *
     * @param sm the saved model to clean
     */
    public void cleanUnreferencedPartitionedModelsNoFail(SavedModel sm) {
        if (!sm.isPartitioned()) {
            return;
        }
        PredictionSavedModelStateService.SavedModelState stateBeforeListing = this.smStateService.compute(sm);
        if (PredictionSMMgmtService.isUnsafeForPartitionCleaning(stateBeforeListing)) {
            return;
        }
        Set<File> candidates = this.computeUnreferencedPartitionedModels(sm);
        if (candidates.isEmpty()) {
            logger.info("No unreferenced partitioned model found");
            return;
        }
        // Listing may take a while: re-check the state so we never delete while the model is changing.
        PredictionSavedModelStateService.SavedModelState stateAfterListing = this.smStateService.compute(sm);
        if (PredictionSMMgmtService.isUnsafeForPartitionCleaning(stateAfterListing)) {
            return;
        }
        boolean modifiedMeanwhile = stateAfterListing.lastModifiedOn > stateBeforeListing.lastModifiedOn;
        if (modifiedMeanwhile) {
            logger.info(String.format("Saved model %s has changed, skipping cleaning of unreferenced partitioned models", sm.id));
            return;
        }
        logger.info(String.format("Actually performing deletion of unreferenced partitioned models for saved model %s", sm.id));
        candidates.forEach(this::forceDeleteNoFail);
    }

    /**
     * Returns true (and logs why) when the given state does not allow cleaning unreferenced partitioned models.
     */
    private static boolean isUnsafeForPartitionCleaning(PredictionSavedModelStateService.SavedModelState state) {
        boolean unsafe = state.unknown || state.isBeingTrained;
        if (unsafe) {
            logger.info(String.format("Saved model state %s is not safe for cleaning of unreferenced partitioned models, skipping", state));
        }
        return unsafe;
    }

    /**
     * Lists the on-disk partition model version folders of a partitioned saved model that no valid (final or
     * intermediate) saved model version references.
     *
     * <p>A partition model version is referenced when the partitioned model extract of some saved model version
     * maps its partition to it. The result is used for deletion. So if the extract of one version cannot be
     * read, the scan is aborted and an empty set is returned. Otherwise the partition versions that version
     * references would look unreferenced and be deleted.
     *
     * @param sm the partitioned saved model to scan
     * @return the partition version folders that can be deleted; empty when nothing can safely be deleted
     */
    private Set<File> computeUnreferencedPartitionedModels(SavedModel sm) {
        logger.info((Object)String.format("Start listing unreferenced partitioned models for deletion, for saved model %s", sm.id));
        HashSet<File> partVersionsToDelete = new HashSet<File>();
        // partition name -> referenced partition model version ids
        HashMultimap<String, String> usedPartitions = HashMultimap.create();
        for (SMVInfo smvInfo : this.listValidVersions(sm, false)) {
            PartitionedModelExtract pme;
            try {
                pme = smvInfo.fmi.getPartitionedModelExtract();
            }
            catch (Throwable e) {
                logger.warn((Object)String.format("Failed to parse model information for fmi '%s'", smvInfo.fmi), e);
                // Unknown references: deleting now could remove partition models this version still uses.
                logger.warn((Object)String.format("Cannot determine partitioned models referenced by fmi '%s', skipping cleaning of unreferenced partitioned models for saved model %s", smvInfo.fmi, sm.id));
                return new HashSet<File>();
            }
            if (pme == null) continue;
            for (Map.Entry<String, String> entry : pme.versions.entrySet()) {
                usedPartitions.put(entry.getKey(), entry.getValue());
            }
        }
        File pversionsFolder = MLPaths.savedModelPartVersionBaseFolder(sm);
        if (!pversionsFolder.isDirectory()) {
            return partVersionsToDelete;
        }
        // File.listFiles returns null on I/O error: treat as "nothing to clean" rather than crashing.
        File[] partitions = pversionsFolder.listFiles((FileFilter)DKUFileUtils.FileFilter.DIRECTORIES);
        if (partitions == null) {
            logger.warn((Object)("Failed to list partition folders in " + pversionsFolder.getAbsolutePath()));
            return partVersionsToDelete;
        }
        for (File partition : partitions) {
            File[] modelPartVersions = partition.listFiles((FileFilter)DKUFileUtils.FileFilter.DIRECTORIES);
            if (modelPartVersions == null) {
                logger.warn((Object)("Failed to list partition model versions in " + partition.getAbsolutePath()));
                continue;
            }
            if (modelPartVersions.length == 0) continue;
            // Partition folders are named with the encoded partition name.
            String partitionId = partition.getName();
            String partitionName = PartitioningUtils.decode(partitionId);
            for (File modelPartVersion : modelPartVersions) {
                FullModelId partitionFmi = new FullModelId(sm.projectKey, sm.getId(), sm.activeVersion, partitionName, modelPartVersion.getName());
                if (!partitionFmi.isModelUsable()) {
                    logger.debug((Object)("Ignoring non model directory " + modelPartVersion.getAbsolutePath()));
                    continue;
                }
                if (usedPartitions.containsEntry(partitionName, modelPartVersion.getName())) continue;
                partVersionsToDelete.add(modelPartVersion);
                logger.info((Object)("Registering version " + modelPartVersion.getName() + " for partition " + partitionName + " of model " + sm.id + " saved as " + partitionId + " on disk for deletion"));
            }
        }
        return partVersionsToDelete;
    }

    /**
     * Deletes the intermediate (non active, non final) versions of a partitioned saved model that were produced
     * by the given job.
     *
     * <p>Logs a warning and does nothing for non-partitioned models. Deletion failures are logged, not thrown.
     *
     * @param sm the partitioned saved model
     * @param jobId the id of the job whose temporary versions should be removed
     */
    public void cleanTemporaryVersionsForPartitionedSavedModels(SavedModel sm, String jobId) {
        if (!sm.isPartitioned()) {
            logger.warn((Object)"Cannot clean temporary version for not-partitioned model, skipping");
            return;
        }
        // Typed list: iterating a raw List as FullModelId does not compile.
        List<FullModelId> temporaryVersionsToDelete = this.listValidVersions(sm, false).stream()
                .filter(smvInfo -> !smvInfo.isActiveOrFinal && smvInfo.smo != null && StringUtils.equals((String)smvInfo.smo.jobId, (String)jobId))
                .map(smvInfo -> smvInfo.fmi)
                .collect(Collectors.toList());
        if (temporaryVersionsToDelete.isEmpty()) {
            return;
        }
        logger.info((Object)String.format("Deleting temporary versions of Saved model (%s), for jobId (%s) associated corresponding to following fmis: %s", sm.id, jobId, temporaryVersionsToDelete));
        for (FullModelId fmiToDelete : temporaryVersionsToDelete) {
            this.forceDeleteNoFail(fmiToDelete.getModelFolder());
        }
        logger.info((Object)String.format("Done deleting temporary versions of Saved model (%s), for jobId (%s)", sm.id, jobId));
    }

    /**
     * Recursively deletes the given directory; an IOException is logged as a warning instead of being thrown.
     *
     * @param directory the directory to delete
     */
    private void forceDeleteNoFail(File directory) {
        try {
            DKUFileUtils.forceDelete(directory);
        } catch (IOException failure) {
            String message = String.format("Failed deleting directory '%s'", directory);
            logger.warn(message, failure);
        }
    }

    /**
     * Returns the ids of the final versions of the saved model (intermediate versions are left out).
     *
     * @param sm the saved model
     * @return the full model ids of its final versions, in no particular order
     */
    @Override
    public List<FullModelId> listUsableVersions(SavedModel sm) {
        List<FullModelId> usable = new ArrayList<FullModelId>();
        for (SMVInfo finalVersion : this.listValidVersions(sm, true)) {
            usable.add(finalVersion.fmi);
        }
        return usable;
    }

    /**
     * Lists the usable versions of a saved model, each tagged with whether it is active or final.
     *
     * <p>For non-partitioned models, every usable version is final. For partitioned models, only versions with
     * an origin file are considered:
     * <ul>
     * <li>versions without a job id are final;</li>
     * <li>the active version is final;</li>
     * <li>among the versions produced by one job, the one with the highest {@code jobIdUpdate} is promoted to
     * final;</li>
     * <li>the others are intermediate and are only listed when {@code onlyFinalVersions} is false.</li>
     * </ul>
     * Versions whose origin cannot be read are logged and skipped.
     *
     * @param sm the saved model
     * @param onlyFinalVersions whether to leave intermediate versions out
     * @return the versions, in no particular order
     */
    private List<SMVInfo> listValidVersions(SavedModel sm, boolean onlyFinalVersions) {
        Map<FullModelId, SMVInfo> byFmi = new HashMap<FullModelId, SMVInfo>();
        // Most recent (highest jobIdUpdate) version produced by each job, already tagged as final.
        Map<String, SMVInfo> latestByJobId = new HashMap<String, SMVInfo>();
        for (File folder : PredictionSMMgmtService.listVersionFolders(sm)) {
            FullModelId fmi = new FullModelId(sm.projectKey, sm.id, folder.getName());
            try {
                if (!fmi.isModelUsable()) {
                    continue;
                }
                if (!sm.isPartitioned()) {
                    byFmi.put(fmi, new SMVInfo(fmi, true));
                    continue;
                }
                if (!fmi.getSmOriginFile().exists()) {
                    continue;
                }
                SavedModelOriginInfo smo = fmi.getSmOrigin();
                if (smo.jobId == null) {
                    byFmi.put(fmi, new SMVInfo(fmi, true, smo));
                    continue;
                }
                boolean isActive = StringUtils.equals(sm.activeVersion, fmi.getSavedModelVersionID());
                if (isActive || !onlyFinalVersions) {
                    byFmi.put(fmi, new SMVInfo(fmi, isActive, smo));
                }
                if (smo.jobIdUpdate == null) {
                    continue;
                }
                SMVInfo latest = latestByJobId.get(smo.jobId);
                if (latest == null || !(latest.smo.jobIdUpdate >= smo.jobIdUpdate)) {
                    latestByJobId.put(smo.jobId, new SMVInfo(fmi, true, smo));
                }
            }
            catch (IOException e) {
                logger.warn("Failed to parse model version " + fmi.getSavedModelVersionID(), e);
            }
        }
        // Promote the latest version of each job to final, overriding its intermediate entry if any.
        for (SMVInfo latest : latestByJobId.values()) {
            byFmi.put(latest.fmi, latest);
        }
        return new ArrayList<SMVInfo>(byFmi.values());
    }

    /**
     * Builds the status of a saved model, listing only its final versions.
     *
     * @param sm the saved model to describe
     * @return the saved model status
     */
    public PredictionSMStatus getStatus_NT(SavedModel sm) {
        final boolean onlyFinalVersions = true;
        return getStatus_NT(sm, onlyFinalVersions);
    }

    /**
     * Builds the status of a prediction saved model: its active version id, its mini task, and one header
     * (with snippet) per listed version.
     *
     * <p>A version whose snippet cannot be built is logged and left out of the result; the call itself does
     * not fail.
     *
     * @param sm the saved model to describe
     * @param onlyFinalVersions when true, intermediate versions of partitioned models are not listed
     * @return the saved model status
     * @throws IllegalArgumentException for agent and LLM handling types
     */
    public PredictionSMStatus getStatus_NT(SavedModel sm, boolean onlyFinalVersions) {
        TransactionContext.assertNoAttachedTransaction();
        PredictionSMStatus ret = new PredictionSMStatus();
        ret.activeVersionId = sm.activeVersion;
        PredictionMLTask miniTask = (PredictionMLTask)sm.miniTask;
        ret.task = miniTask;
        switch (sm.savedModelType.savedModelHandlingType) {
            case INTERNAL: {
                for (SMVInfo version : this.listValidVersions(sm, onlyFinalVersions)) {
                    PredictionSMVersionHeader vh = this.newVersionHeader(version.fmi, ret.activeVersionId);
                    try {
                        this.fillInternalSnippet(vh, version.fmi, miniTask);
                        ret.versions.add(vh);
                    }
                    catch (Exception e) {
                        logger.warn((Object)"Failed to create model snippet", (Throwable)e);
                    }
                }
                break;
            }
            case EXTERNAL_MLFLOW: {
                for (SMVInfo version : this.listValidVersions(sm, onlyFinalVersions)) {
                    PredictionSMVersionHeader vh = this.newVersionHeader(version.fmi, ret.activeVersionId);
                    try {
                        this.fillMlflowSnippet(vh, version.fmi);
                        ret.versions.add(vh);
                    }
                    catch (Exception e) {
                        logger.warn((Object)"Failed to create model snippet", (Throwable)e);
                    }
                }
                break;
            }
            case PYTHON_AGENT: 
            case PLUGIN_AGENT: 
            case TOOLS_USING_AGENT: 
            case LLM_GENERIC: 
            case RETRIEVAL_AUGMENTED_LLM: {
                throw new IllegalArgumentException("Should not be here");
            }
        }
        return ret;
    }

    /** Creates a version header for the given version, flagged active when it matches the active version id. */
    private PredictionSMVersionHeader newVersionHeader(FullModelId fmi, String activeVersionId) {
        PredictionSMVersionHeader vh = new PredictionSMVersionHeader();
        vh.versionId = fmi.getSavedModelVersionID();
        vh.active = vh.versionId.equals(activeVersionId);
        return vh;
    }

    /** Builds and attaches the snippet of a DSS-managed (internally trained) version. */
    private void fillInternalSnippet(PredictionSMVersionHeader vh, FullModelId fmi, PredictionMLTask miniTask) throws Exception {
        PredictionModelDetails pmd = PredictionResultsReader.makeModelDetails(fmi);
        if (pmd instanceof ClassicalPredictionModelDetails) {
            // Expose the saved model's current cost matrix weights on the version details.
            ((ClassicalPredictionModelDetails)pmd).headTaskCMW = ((PredictionMLTask.ClassicalPredictionMLTask)miniTask).modeling.metrics.costMatrixWeights;
        }
        PredictionModelSnippetData snippet = (PredictionModelSnippetData)PredictionResultsReader.makeSnippet(pmd);
        vh.snippet = snippet;
        snippet.savedModelType = SavedModel.SavedModelType.DSS_MANAGED;
        snippet.sessionDate = snippet.trainInfo == null ? 0L : snippet.trainInfo.endTime;
        snippet.partitions = fmi.isPartitionedBaseModel() ? fmi.getPartitionedModelExtract() : null;
    }

    /** Builds and attaches the snippet of an MLflow-imported (or proxy) version from its import metadata. */
    private void fillMlflowSnippet(PredictionSMVersionHeader vh, FullModelId fmi) throws Exception {
        MLFlowModelVersionInfo meta = fmi.getMLflowImportedModelMetadata();
        PredictionModelDetails pmd = PredictionResultsReader.makeModelDetails(fmi);
        PredictionModelSnippetData snippet = (PredictionModelSnippetData)PredictionResultsReader.makeSnippet(pmd);
        vh.snippet = snippet;
        snippet.savedModelType = meta.isProxyModel() ? SavedModel.SavedModelType.PROXY_MODEL : SavedModel.SavedModelType.MLFLOW_PYFUNC;
        snippet.sessionDate = pmd.sessionDate;
        // Imported models have no training session: identify and name them by their version id.
        snippet.trainInfo.modelId = fmi.getSavedModelVersionID();
        snippet.userMeta.name = fmi.getSavedModelVersionID();
        snippet.importedOn = meta.importedOn;
        snippet.mlflowOrigin = meta.origin;
        snippet.timeCreated = meta.timeCreated;
        snippet.flavorsLabels = meta.flavorsLabels;
        snippet.pyfuncLabels = meta.pyfuncLabels;
        snippet.fromDatabricksConnection = meta.fromDatabricksConnection;
        snippet.fromDatabricksModelName = meta.fromDatabricksModelName;
        snippet.useUnityCatalog = meta.useUnityCatalog;
        if (meta.isProxyModel()) {
            snippet.proxyModelConfiguration = meta.proxyModelVersionConfiguration.proxyModelConfiguration;
        }
        snippet.proxyModelEndpointInfo = meta.proxyModelEndpointInfo;
        snippet.inputFormat = meta.inputFormat;
        snippet.outputFormat = meta.outputFormat;
        snippet.mlflowClassLabels = meta.classLabels;
        snippet.mlflowEvaluationDatasetSmartName = meta.evaluationDatasetSmartName;
        snippet.mlflowSignatureAndFormatsGuessingDatasetSmartName = meta.signatureAndFormatsGuessingDataset;
        snippet.mlflowGuessingDatasetSamplingParam = meta.guessingDatasetSamplingParam;
        snippet.mlflowEvaluationSamplingParam = meta.evaluationSamplingParam;
        snippet.mlflowEvaluationTargetColumnName = meta.targetColumnName;
    }

    /**
     * Deletes a model version folder from disk after removing its restrictive ACL mask.
     * Asserts that no transaction is attached to the current context.
     *
     * @param versionFolder the version folder to delete
     * @throws IOException if the folder cannot be deleted
     */
    private void deleteModelVersion_NT(File versionFolder) throws IOException {
        TransactionContext.assertNoAttachedTransaction();
        // NOTE(review): the mask is presumably removed first so that the recursive delete is not denied.
        FilesystemACLUtils.removeACLRestrictiveMask(versionFolder);
        DKUFileUtils.forceDelete(versionFolder);
    }

    /**
     * Deletes a non-active version of a saved model and publishes a {@code ModelVersionDeletedEvent}.
     *
     * @param sm the saved model
     * @param version the id of the version to delete
     * @param removeIntermediate for partitioned models, also delete the intermediate versions produced by the
     *                           same training job
     * @throws IllegalArgumentException if {@code version} is the active version or is not a usable model
     * @throws IOException if a non-partitioned deletion fails
     */
    public void deleteVersion_NT(SavedModel sm, String version, boolean removeIntermediate) throws IOException {
        TransactionContext.assertNoAttachedTransaction();
        boolean isActiveVersion = StringUtils.equals(sm.activeVersion, version);
        if (isActiveVersion) {
            throw new IllegalArgumentException("Cannot delete a model's active version");
        }
        logger.info("Removing model " + sm.getId() + " version " + version);
        File versionFolder = MLPaths.savedModelVersionFolder(sm, version);
        FullModelId fmi = new FullModelId(sm.projectKey, sm.getId(), version);
        if (!fmi.isModelUsable()) {
            throw new IllegalArgumentException("Can't find valid model version: " + version);
        }
        boolean alsoIntermediate = removeIntermediate && sm.isPartitioned();
        if (alsoIntermediate) {
            this.removeModelAndIntermediateVersions_NT(sm, version);
        } else {
            this.deleteModelVersion_NT(versionFolder);
        }
        String deletedFmi = new FullModelId(sm.getProjectKey(), sm.getId(), version).toString();
        this.pubSub.publish((DSSEvent)new ModelVersionDeletedEvent(deletedFmi));
    }

    /**
     * Deletes a version of a partitioned saved model, together with the intermediate versions produced by the
     * same training job.
     *
     * <p>The requested version is always deleted. Its origin may say it was trained from a recipe; in that case,
     * the intermediate versions (as listed by {@code listValidVersions}) whose origin has the same job id are
     * deleted too. If the origin is missing or unreadable, only the requested version is deleted. Previously
     * nothing was deleted in that case, although the caller still reported the version as deleted. Deletion
     * failures are logged, not thrown.
     *
     * @param sm the partitioned saved model
     * @param version the id of the version to delete
     */
    private void removeModelAndIntermediateVersions_NT(SavedModel sm, String version) {
        TransactionContext.assertNoAttachedTransaction();
        assert (sm.isPartitioned());
        FullModelId mainFmi = new FullModelId(sm.projectKey, sm.id, version);
        HashSet<FullModelId> fmisToRemove = new HashSet<FullModelId>();
        fmisToRemove.add(mainFmi);
        String jobId = this.findRecipeTrainingJobId(mainFmi, version);
        if (StringUtils.isNotBlank((String)jobId)) {
            logger.info((Object)("Removing also intermediate versions of model " + sm.getId() + " version " + version));
            for (SMVInfo versionInfo : this.listValidVersions(sm, false)) {
                if (versionInfo.isActiveOrFinal) continue;
                FullModelId fmi = versionInfo.fmi;
                try {
                    File originFile = fmi.getSmOriginFile();
                    if (!originFile.isFile()) {
                        logger.warn((Object)("No origin recorded for model " + String.valueOf(fmi)));
                        continue;
                    }
                    SavedModelOriginInfo smOrigin = fmi.getSmOrigin();
                    if (!StringUtils.equals((String)smOrigin.jobId, (String)jobId)) continue;
                    // Report the intermediate version being queued, not the requested one.
                    logger.info((Object)("Adding intermediate model " + sm.getId() + " version " + fmi.getSavedModelVersionID() + " for removal"));
                    fmisToRemove.add(fmi);
                }
                catch (IOException e) {
                    logger.warn((Object)("Failed to process model version " + String.valueOf(fmi)), (Throwable)e);
                }
            }
        }
        for (FullModelId fmiToRemove : fmisToRemove) {
            try {
                logger.info((Object)("Actually deleting model with fmi: " + String.valueOf(fmiToRemove)));
                this.deleteModelVersion_NT(fmiToRemove.getModelFolder());
            }
            catch (IOException e) {
                logger.warn((Object)("Failed to delete fmi " + String.valueOf(fmiToRemove)), (Throwable)e);
            }
        }
    }

    /**
     * Returns the id of the job that trained the given version when its origin says it was trained from a
     * recipe. Returns null when the origin is of another kind, or when it is missing or unreadable (the last
     * two cases are logged).
     */
    private String findRecipeTrainingJobId(FullModelId mainFmi, String version) {
        try {
            File originFile = mainFmi.getSmOriginFile();
            if (!originFile.isFile()) {
                logger.warn((Object)(mainFmi.getModelFolder().getAbsolutePath() + " does not seem to be a model folder or lacks origin info"));
                return null;
            }
            SavedModelOriginInfo smOrigin = mainFmi.getSmOrigin();
            return smOrigin.origin == SavedModelOriginInfo.Origin.TRAINED_FROM_RECIPE ? smOrigin.jobId : null;
        }
        catch (IOException e) {
            logger.warn((Object)("Failed to parse model version " + version), (Throwable)e);
            return null;
        }
    }

    /**
     * Copies a trained model from an ML task into a new version of the target saved model and makes that
     * version active.
     *
     * <p>Model partitions get a timestamp-based version id; other models get the version id {@code "initial"}.
     *
     * @param fmi the trained model to copy
     * @param targetSM the destination saved model; its {@code activeVersion} is set to the new version id
     * @param suffixVersion whether to add the deployed-version suffix and update the label
     * @return the new version id
     * @throws Exception if copying the model fails
     */
    public String createFromMLTask(FullModelId fmi, SavedModel targetSM, boolean suffixVersion) throws Exception {
        // Typed as String: the method returns it (an Object-typed local does not compile here).
        String versionId = fmi.isModelPartition() ? String.valueOf(System.currentTimeMillis()) : "initial";
        targetSM.activeVersion = versionId;
        this.copyFromMLTask(fmi, targetSM, versionId);
        if (suffixVersion) {
            this.crudService.addDeployedSavedModelVersionSuffixAndUpdateLabel(targetSM, versionId);
        }
        return versionId;
    }

    /**
     * Copies a trained model from an ML task into a new, timestamp-named version of an existing saved model.
     *
     * @param fmi the trained model to copy
     * @param targetSM the destination saved model
     * @param addSuffix whether to add the deployed-version suffix and update the label
     * @return the new version id
     * @throws Exception if copying the model fails
     */
    public String updateFromMLTask(FullModelId fmi, SavedModel targetSM, boolean addSuffix) throws Exception {
        final String createdVersionId = String.valueOf(System.currentTimeMillis());
        this.copyFromMLTask(fmi, targetSM, createdVersionId);
        if (!addSuffix) {
            return createdVersionId;
        }
        this.crudService.addDeployedSavedModelVersionSuffixAndUpdateLabel(targetSM, createdVersionId);
        return createdVersionId;
    }

    /**
     * Deploys a partitioned ML task model into a new saved model.
     *
     * <p>Copies every partition whose training state is DONE. Failing or non-DONE partitions are logged and
     * skipped. It then copies the partitioned base model, makes it the active version, prepares its version
     * folder, and adds the deployed-version suffix and label.
     *
     * @param fmi the partitioned model in the ML task
     * @param targetSM the destination saved model
     * @throws Exception if copying the base model or preparing its folder fails
     */
    public void createFromPartitionedMLTask(FullModelId fmi, SavedModel targetSM) throws Exception {
        HashMap<String, String> versionByPartition = new HashMap<String, String>();
        for (FullModelId partitionFmi : StratifiedModelUtils.fetchPartitionFmis(fmi)) {
            String partition = partitionFmi.getPartitionName();
            try {
                ModelTrainInfo trainInfo = partitionFmi.parseModelFile("train_info.json", ModelTrainInfo.class);
                boolean trainingDone = ModelTrainInfo.ModelTrainState.DONE == trainInfo.state;
                if (trainingDone) {
                    versionByPartition.put(partition, this.createFromMLTask(partitionFmi, targetSM, false));
                } else {
                    logger.warn("Skipping non DONE partition " + partition);
                }
            } catch (Exception e) {
                logger.warn("Failed to turn model partition " + partition + " into a saved model", e);
            }
        }
        String baseVersionId = this.createFromMLTask(fmi.getPartitionedBaseModel(), targetSM, false);
        targetSM.activeVersion = baseVersionId;
        this.preparePartitionedSavedModelFolder(fmi, targetSM, versionByPartition, baseVersionId);
        this.crudService.addDeployedSavedModelVersionSuffixAndUpdateLabel(targetSM, baseVersionId);
    }

    /**
     * Adds a new version to an existing partitioned saved model from a partitioned ML task model.
     *
     * <p>It first copies the partitioned base model, with the deployed-version suffix. It then copies every
     * partition whose training state is DONE; failing or non-DONE partitions are logged and skipped. Finally it
     * prepares the new base version folder.
     *
     * @param fmi the partitioned model in the ML task
     * @param targetSM the destination saved model
     * @return the id of the new base version
     * @throws Exception if copying the base model or preparing its folder fails
     */
    public String updateFromPartitionedMLTask(FullModelId fmi, SavedModel targetSM) throws Exception {
        String baseVersionId = this.updateFromMLTask(fmi.getPartitionedBaseModel(), targetSM, true);
        HashMap<String, String> versionByPartition = new HashMap<String, String>();
        for (FullModelId partitionFmi : StratifiedModelUtils.fetchPartitionFmis(fmi)) {
            String partition = partitionFmi.getPartitionName();
            try {
                ModelTrainInfo trainInfo = partitionFmi.parseModelFile("train_info.json", ModelTrainInfo.class);
                boolean trainingDone = ModelTrainInfo.ModelTrainState.DONE == trainInfo.state;
                if (trainingDone) {
                    versionByPartition.put(partition, this.updateFromMLTask(partitionFmi, targetSM, false));
                } else {
                    logger.warn("Skipping non DONE partition " + partition);
                }
            } catch (Exception e) {
                logger.warn("Failed to turn model partition " + partition + " into a saved model", e);
            }
        }
        this.preparePartitionedSavedModelFolder(fmi, targetSM, versionByPartition, baseVersionId);
        return baseVersionId;
    }

    /**
     * Completes the base version folder of a partitioned saved model after its partitions have been copied.
     *
     * <p>Steps:
     * <ul>
     * <li>Writes the partitioned model extract restricted to the deployed partitions. Their snippets point to
     * the new partition model ids, and their states are marked as reused.</li>
     * <li>Copies {@code core_params.json} from the source session.</li>
     * <li>Seeds the base folder with the JSON files, and the {@code dss_pipeline_model.gz} if present, of one
     * deployed partition. Existing files are not replaced.</li>
     * </ul>
     *
     * @param fmi the source partitioned model
     * @param targetSM the destination saved model
     * @param partitionToVersion deployed partition name -> partition model version id
     * @param newVersionId the id of the new base version
     * @throws IllegalStateException if no partition was deployed (nothing to seed the base folder from)
     * @throws IOException if reading or writing model files fails
     */
    private void preparePartitionedSavedModelFolder(FullModelId fmi, SavedModel targetSM, Map<String, String> partitionToVersion, String newVersionId) throws IOException {
        // Fail before writing anything instead of hitting NoSuchElementException halfway through.
        if (partitionToVersion.isEmpty()) {
            throw new IllegalStateException("No successfully trained partition available to build version " + newVersionId + " of partitioned saved model " + targetSM.id);
        }
        FullModelId newBaseModelFmi = new FullModelId(targetSM.projectKey, targetSM.id, newVersionId);
        File versionFolder = newBaseModelFmi.getModelFolder();
        PartitionedModelExtract extract = fmi.getPartitionedModelExtract();
        extract.versions = partitionToVersion;
        Iterator<PartitionedModelExtract.PartitionedModelSummary> iterator = extract.summaries.values().iterator();
        while (iterator.hasNext()) {
            PartitionedModelExtract.PartitionedModelSummary summary = iterator.next();
            String partitionName = summary.snippet.partitionName;
            if (!partitionToVersion.containsKey(partitionName)) {
                // Partition was not deployed: drop its summary and update the state counters accordingly.
                extract.decreaseState(summary.state);
                iterator.remove();
                continue;
            }
            String partitionVersion = partitionToVersion.get(partitionName);
            summary.snippet.fullModelId = new FullModelId(targetSM.projectKey, targetSM.id, newVersionId, partitionName, partitionVersion).toString();
        }
        extract.setStatesToReused();
        newBaseModelFmi.savePartFile(extract);
        FileUtils.copyFileToDirectory((File)fmi.getSessionFile("core_params.json"), (File)versionFolder);
        Map.Entry<String, String> somePartAndVersion = partitionToVersion.entrySet().iterator().next();
        String encodedPartitionName = PartitioningUtils.encode(somePartAndVersion.getKey());
        File somePartitionFolder = MLPaths.savedModelPartVersionFolder(targetSM, encodedPartitionName, somePartAndVersion.getValue());
        DKUFileUtils.copyDirectory((File)somePartitionFolder, (File)versionFolder, EnumSet.of(DKUFileUtils.CopyDirectoryFlags.NoReplacing), (FileFilter[])new FileFilter[]{f -> f.getName().endsWith(".json")});
        File dssPipelineFile = DKUFileUtils.getWithinFollowLink((File)somePartitionFolder, (String[])new String[]{"dss_pipeline_model.gz"});
        if (dssPipelineFile.exists()) {
            FileUtils.copyFileToDirectory((File)dssPipelineFile, (File)versionFolder);
        }
    }

    /**
     * Builds the lightweight ("mini") prediction task stored with a saved model. It is built from the trained
     * model details and, except for causal models, from the session's {@code mltask.json}.
     *
     * @param fmi the trained model, used to read the session's {@code mltask.json}
     * @param details_ the trained model details; supported kinds are classical, deep hub, time series
     *                 forecasting and causal prediction
     * @return the mini task
     * @throws IOException if the session task file cannot be read
     * @throws IllegalArgumentException for unsupported model details
     */
    private PredictionMLTask createMiniTask(FullModelId fmi, ModelDetailsBase details_) throws IOException {
        if (details_ instanceof ClassicalPredictionModelDetails) {
            ClassicalPredictionModelDetails details = (ClassicalPredictionModelDetails)details_;
            PredictionMLTask.ClassicalPredictionMLTask pmlTask = new PredictionMLTask.ClassicalPredictionMLTask();
            // Set the prediction type BEFORE building the modeling params, which take it as an argument.
            pmlTask.predictionType = details.coreParams.prediction_type;
            pmlTask.modeling = new PredictionModelingParams(pmlTask.predictionType);
            pmlTask.targetVariable = details.coreParams.target_variable;
            pmlTask.partitionedModel = details.coreParams.partitionedModel;
            pmlTask.modeling.metrics = details.modeling.metrics;
            PredictionMLTask.ClassicalPredictionMLTask trainTask = (PredictionMLTask.ClassicalPredictionMLTask)fmi.parseSessionFile("mltask.json", MLTask.class);
            pmlTask.backendType = trainTask.backendType;
            pmlTask.sparkParams.sparkConf = trainTask.sparkParams.sparkConf;
            pmlTask.sparkParams.sparkExecutionEngine = trainTask.sparkParams.sparkExecutionEngine;
            pmlTask.sparkParams.sparkUseGlobalMetastore = trainTask.sparkParams.sparkUseGlobalMetastore;
            pmlTask.sparkParams.sparkPreparedDFStorageLevel = trainTask.sparkParams.sparkPreparedDFStorageLevel;
            pmlTask.envSelection = trainTask.envSelection;
            pmlTask.managedFolderSmartId = null;
            return pmlTask;
        }
        if (details_ instanceof DeepHubPredictionModelDetails) {
            DeepHubPredictionModelDetails details = (DeepHubPredictionModelDetails)details_;
            PredictionMLTask.DeepHubPredictionMLTask pmlTask = new PredictionMLTask.DeepHubPredictionMLTask();
            pmlTask.modeling = DeepHubPreTrainModelingParams.build(details.coreParams.prediction_type);
            pmlTask.predictionType = details.coreParams.prediction_type;
            pmlTask.targetVariable = details.coreParams.target_variable;
            pmlTask.modeling.metrics = details.modeling.metrics;
            PredictionMLTask trainTask = (PredictionMLTask)fmi.parseSessionFile("mltask.json", MLTask.class);
            pmlTask.backendType = trainTask.backendType;
            pmlTask.envSelection = trainTask.envSelection;
            return pmlTask;
        }
        if (details_ instanceof TimeseriesForecastingModelDetails) {
            TimeseriesForecastingModelDetails details = (TimeseriesForecastingModelDetails)details_;
            PredictionMLTask.TimeseriesForecastingMLTask pmlTask = new PredictionMLTask.TimeseriesForecastingMLTask();
            // Set the prediction type BEFORE building the modeling params, which take it as an argument.
            pmlTask.predictionType = details.coreParams.prediction_type;
            pmlTask.modeling = new PredictionModelingParams(pmlTask.predictionType);
            pmlTask.targetVariable = details.coreParams.target_variable;
            pmlTask.timeVariable = details.coreParams.timeVariable;
            pmlTask.timeseriesIdentifiers = details.coreParams.timeseriesIdentifiers;
            pmlTask.timestepParams = details.coreParams.timestepParams;
            pmlTask.predictionLength = details.coreParams.predictionLength;
            pmlTask.quantilesToForecast = details.coreParams.quantilesToForecast;
            pmlTask.customTrainTestSplit = details.coreParams.customTrainTestSplit;
            pmlTask.customTrainTestIntervals = details.coreParams.customTrainTestIntervals;
            pmlTask.evaluationParams = details.coreParams.evaluationParams;
            pmlTask.partitionedModel = details.coreParams.partitionedModel;
            pmlTask.modeling.metrics = details.modeling.metrics;
            PredictionMLTask trainTask = (PredictionMLTask)fmi.parseSessionFile("mltask.json", MLTask.class);
            pmlTask.backendType = trainTask.backendType;
            pmlTask.envSelection = trainTask.envSelection;
            return pmlTask;
        }
        if (details_ instanceof CausalPredictionModelDetails) {
            CausalPredictionModelDetails details = (CausalPredictionModelDetails)details_;
            PredictionMLTask.CausalPredictionMLTask pmlTask = new PredictionMLTask.CausalPredictionMLTask();
            pmlTask.predictionType = details.coreParams.prediction_type;
            pmlTask.targetVariable = details.coreParams.target_variable;
            pmlTask.positiveClass = details.coreParams.positive_class;
            pmlTask.treatmentVariable = details.coreParams.treatment_variable;
            pmlTask.controlValue = details.coreParams.control_value;
            pmlTask.modeling = new PredictionModelingParams(pmlTask.predictionType);
            pmlTask.modeling.metrics = details.modeling.metrics;
            pmlTask.modeling.propensityModeling = details.modeling.propensityModeling;
            return pmlTask;
        }
        throw new IllegalArgumentException("Unsupported model details: " + details_.getClass().getSimpleName());
    }

    /**
     * Writes the {@code DSS_PIPELINE_META} file of a trained classical prediction model, when both
     * its algorithm and its preprocessing report Java compatibility.
     *
     * <p>Nothing is written for Deep Hub models, time series forecasting models or causal prediction
     * models. For a partitioned base model, the meta of the first trained partition that has a meta
     * file is used as a template, and the partition list and partitioning dimensions are added to it.
     *
     * @param fmi identifier of the trained model
     * @throws InvalidClassException if the resolved core params are not prediction core params
     * @throws IOException if reading the model/preprocessing files or writing the meta file fails
     * @throws IllegalArgumentException if the prediction type is not REGRESSION, BINARY_CLASSIFICATION or MULTICLASS
     */
    public static void fillPipelineMeta(FullModelId fmi) throws IOException {
        ResolvedCoreParams coreParams_ = fmi.getResolvedCoreParams();
        if (!(coreParams_ instanceof ResolvedPredictionCoreParams)) {
            throw new InvalidClassException("Resolved core parameters don't match requirements for prediction task: found " + coreParams_.getClass().getName());
        }
        ResolvedPredictionCoreParams predictionCoreParams = (ResolvedPredictionCoreParams)coreParams_;
        // These model families are skipped entirely: no pipeline meta is written for them.
        if (MLTask.BackendType.DEEP_HUB.equals((Object)coreParams_.backendType)
                || PredictionMLTask.PredictionType.TIMESERIES_FORECAST.equals((Object)predictionCoreParams.prediction_type)
                || PredictionMLTask.PredictionType.CAUSAL_REGRESSION.equals((Object)predictionCoreParams.prediction_type)
                || PredictionMLTask.PredictionType.CAUSAL_BINARY_CLASSIFICATION.equals((Object)predictionCoreParams.prediction_type)) {
            return;
        }
        ResolvedClassicalPredictionCoreParams coreParams = (ResolvedClassicalPredictionCoreParams)coreParams_;
        PreTrainPredictionModelingParams rpmp = fmi.parseModelFile("rmodeling_params.json", PreTrainPredictionModelingParams.class);
        if (!rpmp.algorithm.meta.isJavaCompatible(coreParams)) {
            return;
        }
        ResolvedClassicalPredictionPreprocessingParams prep = (ResolvedClassicalPredictionPreprocessingParams)JSON.parseFile((File)fmi.getPreprocessingFile("rpreprocessing_params.json"), ResolvedClassicalPredictionPreprocessingParams.class);
        CompatibilityWithReason javaCompatibility = prep.getJavaCompatibility(rpmp.algorithm.backendType);
        if (!javaCompatibility.compatible) {
            logger.info((Object)("Not java compatible: " + javaCompatibility.reason));
            return;
        }
        Build.DssPipelineMeta meta;
        if (fmi.isPartitionedBaseModel()) {
            meta = PredictionSMMgmtService.readPartitionedBaseModelPipelineMeta(fmi, coreParams);
        } else {
            meta = fmi.parseModelFile(DSS_PIPELINE_META, Build.DssPipelineMeta.class);
            if (meta == null) {
                logger.warn((Object)("Couldn't find a pipeline meta for model " + fmi.toString()));
                meta = new Build.DssPipelineMeta();
            }
        }
        meta.trainedWithDSSVersion = ApplicationConfigurator.getDSSVersion().product_version;
        meta.isValid = true;
        meta.type = PredictionSMMgmtService.computePipelineModelType(coreParams, rpmp);
        JSON.prettyToFile((Object)meta, (File)fmi.getModelFile(DSS_PIPELINE_META));
    }

    /**
     * Builds the pipeline meta of a partitioned base model.
     *
     * <p>The template is the meta of the first trained partition (in the iteration order of
     * {@code listTrainedPartitions()}) whose meta file exists. If no such partition exists, an empty
     * meta is used. The trained partition names, the partitioning dimension names and (when present)
     * the dimension types are then set on it.
     */
    private static Build.DssPipelineMeta readPartitionedBaseModelPipelineMeta(FullModelId fmi, ResolvedClassicalPredictionCoreParams coreParams) throws IOException {
        assert (coreParams.isPartitioned()) : "Partitioned model FMI with disabled partitioning";
        Set<String> partitions = fmi.listTrainedPartitions();
        Build.DssPipelineMeta meta = null;
        for (String partition : partitions) {
            FullModelId pfmi = fmi.getModelPartition(partition);
            if (pfmi.getModelFile(DSS_PIPELINE_META).exists()) {
                meta = pfmi.parseModelFile(DSS_PIPELINE_META, Build.DssPipelineMeta.class);
                break;
            }
        }
        if (meta == null) {
            logger.warn((Object)("Couldn't find a partition meta for model " + fmi.toString()));
            meta = new Build.DssPipelineMeta();
        }
        meta.partitions = partitions.toArray(new String[0]);
        meta.partitioning = coreParams.partitionedModel.dimensionNames.toArray(new String[0]);
        if (null != coreParams.partitionedModel.dimensionTypes) {
            meta.partitioningTypes = coreParams.partitionedModel.dimensionTypes.toArray(new DimensionType[0]);
        }
        return meta;
    }

    /**
     * Maps the prediction type, together with whether the algorithm produces probabilities, to the
     * pipeline model type.
     *
     * @throws IllegalArgumentException for prediction types other than REGRESSION, BINARY_CLASSIFICATION and MULTICLASS
     */
    private static Build.DssPipelineMeta.ModelType computePipelineModelType(ResolvedClassicalPredictionCoreParams coreParams, PreTrainPredictionModelingParams rpmp) throws IOException {
        switch (coreParams.prediction_type) {
            case REGRESSION: {
                return Build.DssPipelineMeta.ModelType.REGRESSION;
            }
            case BINARY_CLASSIFICATION: {
                return rpmp.algorithm.meta.hasProbabilities(rpmp) ? Build.DssPipelineMeta.ModelType.BINARY_PROBABILISTIC : Build.DssPipelineMeta.ModelType.CLASSIFICATION_ONLY;
            }
            case MULTICLASS: {
                return rpmp.algorithm.meta.hasProbabilities(rpmp) ? Build.DssPipelineMeta.ModelType.MULTICLASS_PROBABILISTIC : Build.DssPipelineMeta.ModelType.CLASSIFICATION_ONLY;
            }
            default: {
                throw new IllegalArgumentException("Unsupported prediction type:" + String.valueOf((Object)coreParams.prediction_type));
            }
        }
    }

    /**
     * Copies the pipeline artifacts of a model trained in an ML task into {@code targetFolder}.
     *
     * <p>The split description is rewritten into {@code targetFolder/split/split.json}. When
     * {@code withSplitData} is true, the split data files are also copied and renamed to
     * {@code full.*} (stream-all / k-fold splits) or {@code train.*} / {@code test.*}, and the
     * description is updated to point at the new names. The following are then copied:
     * {@code script.json}, {@code input_dataset_schema.json} (if present), preprocessing files with
     * the listed extensions, model files with the listed extensions, {@code fold_*} model
     * subfolders, and {@code core_params.json}.
     *
     * @param fmi identifier of the trained model to copy from
     * @param targetFolder destination folder
     * @param withSplitData whether to copy the split data files in addition to the split description
     * @throws IOException if a copy fails or if the preprocessing or model folder cannot be listed
     */
    public static void copyPipelineFromMLTask(FullModelId fmi, File targetFolder, boolean withSplitData) throws IOException {
        SplitDesc.SplitRef sr = fmi.getSplitRef();
        File srcSplitFolder = fmi.getTaskLoc().getSplitsFolder();
        SplitDesc sd = (SplitDesc)JSON.parseFile((File)DKUFileUtils.getWithinFollowLink((File)srcSplitFolder, (String[])new String[]{sr.splitInstanceId + ".json"}), SplitDesc.class);
        File savedModelSplitFolder = DKUFileUtils.getWithinFollowLink((File)targetFolder, (String[])new String[]{"split"});
        DKUFileUtils.mkdirs((File)savedModelSplitFolder);
        if (withSplitData) {
            if ((sd.params.streamAll || sd.params.kfold) && sd.fullPath != null) {
                sd.fullPath = PredictionSMMgmtService.copyFileWithNewName("full", DKUFileUtils.getWithinFollowLink((File)srcSplitFolder, (String[])new String[]{sd.fullPath}), savedModelSplitFolder);
            } else if (sd.trainPath != null && sd.testPath != null) {
                sd.trainPath = PredictionSMMgmtService.copyFileWithNewName("train", DKUFileUtils.getWithinFollowLink((File)srcSplitFolder, (String[])new String[]{sd.trainPath}), savedModelSplitFolder);
                sd.testPath = PredictionSMMgmtService.copyFileWithNewName("test", DKUFileUtils.getWithinFollowLink((File)srcSplitFolder, (String[])new String[]{sd.testPath}), savedModelSplitFolder);
            }
        }
        JSON.prettyToFile((Object)sd, (File)DKUFileUtils.getWithinFollowLink((File)targetFolder, (String[])new String[]{"split/split.json"}));
        FileUtils.copyFileToDirectory((File)fmi.getSessionFile("script.json"), (File)targetFolder);
        File inputDatasetSchemaFile = fmi.getSessionFile("input_dataset_schema.json");
        if (inputDatasetSchemaFile.exists()) {
            FileUtils.copyFileToDirectory((File)inputDatasetSchemaFile, (File)targetFolder);
        }
        for (File f : PredictionSMMgmtService.listFolderContentOrThrow(fmi.getPreprocessingFolder())) {
            if (DKUFileUtils.isFile((File)f) && com.dataiku.dip.utils.StringUtils.endsWithAny((String)f.getName(), (String[])new String[]{".json", ".pkl", ".pkl.gz", ".csv", ".log"})) {
                FileUtils.copyFileToDirectory((File)f, (File)targetFolder);
            }
        }
        for (File f : PredictionSMMgmtService.listFolderContentOrThrow(fmi.getModelFolder())) {
            if (DKUFileUtils.isFile((File)f) && com.dataiku.dip.utils.StringUtils.endsWithAny((String)f.getName(), (String[])new String[]{".json", ".pkl", ".csv", ".dss", ".gz", ".h5", ".keras", ".bin"})) {
                FileUtils.copyFileToDirectory((File)f, (File)targetFolder);
            }
            // Per-fold model subfolders are copied whole, under the same name.
            if (DKUFileUtils.isDirectory((File)f) && f.getName().matches("fold_\\d*$")) {
                File savedModelFoldFolder = DKUFileUtils.getWithinFollowLink((File)targetFolder, (String[])new String[]{f.getName()});
                DKUFileUtils.mkdirs((File)savedModelFoldFolder);
                DKUFileUtils.copyDirectory((File)f, (File)savedModelFoldFolder);
            }
        }
        FileUtils.copyFileToDirectory((File)fmi.getSessionFile("core_params.json"), (File)targetFolder);
    }

    /**
     * Lists the entries of {@code folder}. Throws an explicit error instead of returning the null
     * that {@link File#listFiles()} gives for a missing or unreadable folder.
     */
    private static File[] listFolderContentOrThrow(File folder) throws IOException {
        File[] entries = folder.listFiles();
        if (entries == null) {
            throw new IOException("Cannot list content of folder " + folder);
        }
        return entries;
    }

    /**
     * Returns the id of the version with the greatest snippet {@code trainDate}. When several
     * versions share that date, the one seen first in the list wins.
     *
     * @param versions candidate saved model versions
     * @return the selected version id, or {@code null} if {@code versions} is empty
     */
    public static String getLatestSMVersionId(List<PredictionSMVersionHeader> versions) {
        PredictionSMVersionHeader newest = null;
        for (PredictionSMVersionHeader candidate : versions) {
            boolean replacesNewest = newest == null
                    || !(((PredictionModelSnippetData)candidate.snippet).trainDate <= ((PredictionModelSnippetData)newest.snippet).trainDate);
            if (replacesNewest) {
                newest = candidate;
            }
        }
        if (newest == null) {
            return null;
        }
        return newest.versionId;
    }

    /**
     * Copies a model trained in an ML task into version {@code versionId} of {@code targetSM}.
     *
     * <p>The method:
     * <ul>
     *   <li>creates the version folder (the per-partition folder for a model partition);</li>
     *   <li>sets {@code targetSM.miniTask};</li>
     *   <li>writes the pipeline meta and copies the pipeline, including split data;</li>
     *   <li>copies the post-operations folder (as {@code posttrain}) and the {@code residuals}
     *       folder when they exist;</li>
     *   <li>records the saved-model origin and {@code targetSM.lastExportedFrom}.</li>
     * </ul>
     * {@code targetSM} itself is not persisted here.
     */
    protected void copyFromMLTask(FullModelId fmi, SavedModel targetSM, String versionId) throws Exception {
        DKUFileUtils.mkdirs((File)MLPaths.savedModelCoreFolder(targetSM));
        File versionFolder = MLPaths.savedModelVersionFolder(targetSM, versionId);
        if (fmi.isModelPartition()) {
            // Partition versions live in a folder keyed by the encoded partition name.
            versionFolder = MLPaths.savedModelPartVersionFolder(targetSM, PartitioningUtils.encode(fmi.getPartitionName()), versionId);
        }
        assert (!versionFolder.isDirectory());
        DKUFileUtils.mkdirs((File)versionFolder);

        PredictionModelDetails modelDetails = PredictionResultsReader.makeModelDetails(fmi);
        targetSM.miniTask = this.createMiniTask(fmi, modelDetails);
        PredictionSMMgmtService.fillPipelineMeta(fmi);
        PredictionSMMgmtService.copyPipelineFromMLTask(fmi, versionFolder, true);

        File postOperationsFolder = fmi.getPostOperationsFolder();
        if (postOperationsFolder.isDirectory()) {
            File posttrainTarget = DKUFileUtils.getWithinFollowLink((File)versionFolder, (String[])new String[]{"posttrain"});
            DKUFileUtils.copyDirectory((File)postOperationsFolder, (File)posttrainTarget);
        }

        fmi.saveSmOriginFromAnalysis(versionFolder);
        if (fmi.isModelPartition()) {
            targetSM.lastExportedFrom = fmi.getPartitionedBaseModel().toString();
        } else {
            targetSM.lastExportedFrom = fmi.toString();
        }

        File residualsSource = DKUFileUtils.getWithin((File)fmi.getModelFolder(), (String[])new String[]{"residuals"});
        if (residualsSource.isDirectory()) {
            File residualsTarget = DKUFileUtils.getWithinFollowLink((File)versionFolder, (String[])new String[]{"residuals"});
            DKUFileUtils.copyDirectory((File)residualsSource, (File)residualsTarget);
        }
    }

    /**
     * Copies {@code fileToCopy} into {@code targetFolder} under the base name {@code newName}. The
     * extension is the one {@code PathUtils.getSmartExtension} returns for the original file name.
     *
     * @return the new file name, including its extension
     */
    private static String copyFileWithNewName(String newName, File fileToCopy, File targetFolder) throws IOException {
        String extension = PathUtils.getSmartExtension((String)fileToCopy.getName());
        String renamedFileName = newName + "." + extension;
        File destination = DKUFileUtils.getWithinFollowLink((File)targetFolder, (String[])new String[]{renamedFileName});
        FileUtils.copyFile((File)fileToCopy, (File)destination);
        return renamedFileName;
    }

    /**
     * Makes {@code versionId} the active version of {@code sm} and saves it.
     *
     * <p>A change event is published after the transaction on a best-effort basis: a publishing
     * failure is logged and ignored. When a previous version was active, schema changes between the
     * two versions are computed and propagated.
     *
     * @return {@code false} if {@code versionId} was already active. Otherwise {@code true} when any
     *     of the following holds: there was no previous active version; schema changes were
     *     propagated; a split description is missing for either version; or the split schemas have
     *     incompatibilities. Returns {@code false} when the split schemas are compatible.
     * @throws Exception the exception built by {@code ErrorContext.iaef} if the new version is not
     *     usable, or any error raised while saving or comparing
     */
    public boolean setActive(SavedModel sm, String versionId) throws Exception {
        if (StringUtils.equals((String)sm.activeVersion, (String)versionId)) {
            return false;
        }
        FullModelId newFMI = new FullModelId(sm.projectKey, sm.id, versionId);
        if (!newFMI.isModelUsable()) {
            throw ErrorContext.iaef((String)"The new active version to set (%s) is not valid", (Object)versionId, (Object[])new Object[0]);
        }
        String originalVersion = sm.activeVersion;
        try {
            RWTransactionRef t = TransactionContext.retrieveWrite();
            this.pubSub.publishAfterTransaction(new TaggableObjectChangedEvent(ITaggingService.TaggableType.SAVED_MODEL, sm.projectKey, sm.id, t.getUser(), TaggableObjectChangedEvent.ActionType.SAVED_MODEL_CHANGE_ACTIVE_VERSION));
        }
        catch (Throwable e) {
            // Best effort: the active version change must not fail because of event publishing.
            logger.warn((Object)"Failed to publish event", e);
        }
        sm.activeVersion = versionId;
        this.crudService.save(sm, false, false);
        if (originalVersion == null) {
            return true;
        }
        FullModelId originalFMI = new FullModelId(sm.projectKey, sm.id, originalVersion);
        boolean hasSchemaChanged = this.recipeSchemaChangeService.computeAndPropagateSchemaChanges(sm, originalFMI, newFMI);
        if (hasSchemaChanged) {
            return true;
        }
        SplitDesc originalSplitDesc = ResultsReaderBase.readSplitDesc(originalFMI);
        SplitDesc newSplitDesc = ResultsReaderBase.readSplitDesc(newFMI);
        // Without both split descriptions the schemas cannot be compared: conservatively report a change.
        if (originalSplitDesc == null || newSplitDesc == null) {
            return true;
        }
        return !SchemaComparator.findIncompatibilities(originalSplitDesc.schema, newSplitDesc.schema, true).isEmpty();
    }

    /**
     * Holds a saved model version identifier together with an "active or final" flag and optional
     * saved-model origin information (which may be {@code null}).
     */
    private static class SMVInfo {
        final FullModelId fmi;
        final boolean isActiveOrFinal;
        @Nullable
        final SavedModelOriginInfo smo;

        private SMVInfo(FullModelId modelId, boolean activeOrFinal, SavedModelOriginInfo originInfo) {
            this.fmi = modelId;
            this.isActiveOrFinal = activeOrFinal;
            this.smo = originInfo;
        }

        /** Creates an instance without origin information. */
        private SMVInfo(FullModelId modelId, boolean activeOrFinal) {
            this(modelId, activeOrFinal, null);
        }
    }

    /**
     * Saved model status for prediction models. It extends the generic status with the prediction
     * ML task held in {@code task}.
     */
    @UIModel
    public static class PredictionSMStatus extends SMStatus<PredictionSMVersionHeader> {
        public PredictionMLTask task;
    }

    /**
     * Version header of a prediction saved model. Its snippet is a {@code PredictionModelSnippetData},
     * and it also carries the version's saved-model origin information.
     */
    public static class PredictionSMVersionHeader extends SMVersionHeader<PredictionModelSnippetData> {
        public SavedModelOriginInfo smOrigin;
    }
}

