/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.externalml.mlflow;

import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.core.TrainExecutionParams;
import com.dataiku.dip.analysis.model.prediction.ClassicalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.MLflowOrigin;
import com.dataiku.dip.analysis.model.prediction.MetricParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.CatFeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.NumFeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.PredictionPreprocessingParams;
import com.dataiku.dip.code.CodeEnvSelection;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.coremodel.SimpleKeyValue;
import com.dataiku.dip.datasets.SamplingParam;
import com.dataiku.dip.externalinfras.sagemaker.SageMakerUtils;
import com.dataiku.dip.savedmodels.externalmodelidentifier.AzureMLEndpointIdentifier;
import com.dataiku.dip.savedmodels.externalmodelidentifier.DatabricksEndpointIdentifier;
import com.dataiku.dip.savedmodels.externalmodelidentifier.ExternalEndpointIdentifier;
import com.dataiku.dip.savedmodels.externalmodelidentifier.SagemakerEndpointIdentifier;
import com.dataiku.dip.savedmodels.externalmodelidentifier.VertexAIEndpointIdentifier;
import com.dataiku.dip.savedmodels.proxymodels.AzureMLProxyConfiguration;
import com.dataiku.dip.savedmodels.proxymodels.VertexAIProxyConfiguration;
import com.dataiku.dip.savedmodels.proxymodelversions.AzureMLProxyModelVersionConfiguration;
import com.dataiku.dip.savedmodels.proxymodelversions.DatabricksProxyModelVersionConfiguration;
import com.dataiku.dip.savedmodels.proxymodelversions.ProxyModelVersionConfiguration;
import com.dataiku.dip.savedmodels.proxymodelversions.VertexAIProxyModelVersionConfiguration;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.NotImplementedException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

public class MLFlowModelVersionInfo {
    public static final String FILENAME = "mlflow_imported_model.json";
    public long importedOn;
    public long timeCreated;
    public String targetColumnName;
    @JSON.FileTransient
    public String gatherFeaturesFromDataset;
    public String signatureAndFormatsGuessingDataset;
    public String inputFormat;
    public String outputFormat;
    public String pythonCodeEnvName;
    public String pythonVersion;
    public PredictionMLTask.PredictionType predictionType;
    public String fromDatabricksModelName;
    public String fromDatabricksConnection;
    public boolean useUnityCatalog;
    public List<ClassLabel> classLabels = new ArrayList<ClassLabel>();
    public List<SchemaColumn> features = new ArrayList<SchemaColumn>();
    public MetricParams metricParams;
    public List<SimpleKeyValue> flavorsLabels = new ArrayList<SimpleKeyValue>();
    public List<SimpleKeyValue> pyfuncLabels = new ArrayList<SimpleKeyValue>();
    public MLflowOrigin origin;
    @Nullable
    public ProxyModelVersionConfiguration proxyModelVersionConfiguration;
    @Nullable
    public ProxyModelVersionConfiguration.ConsolidatedEndpointInfo proxyModelEndpointInfo;
    public double binaryClassificationThreshold = 0.5;
    public SamplingParam evaluationSamplingParam;
    public String evaluationDatasetSmartName;

    public String getProxyModelConnection() {
        if (this.proxyModelVersionConfiguration != null && this.proxyModelVersionConfiguration.proxyModelConfiguration != null) {
            return this.proxyModelVersionConfiguration.proxyModelConfiguration.connection;
        }
        return null;
    }

    public void fillMinimalCoreParamsOfPredictionDetails(ClassicalPredictionModelDetails details) {
        details.coreParams = this.getMinimalClassicalPredictionCoreParams();
    }

    public ResolvedClassicalPredictionCoreParams getMinimalClassicalPredictionCoreParams() {
        ResolvedClassicalPredictionCoreParams coreParams = new ResolvedClassicalPredictionCoreParams();
        coreParams.prediction_type = this.predictionType;
        coreParams.target_variable = this.targetColumnName;
        coreParams.backendType = MLTask.BackendType.PY_MEMORY;
        coreParams.executionParams = new TrainExecutionParams();
        coreParams.executionParams.envSelection = CodeEnvSelection.explicitEnv(this.pythonCodeEnvName);
        coreParams.executionParams.envName = this.pythonCodeEnvName;
        return coreParams;
    }

    public void fillMinimalModelingParamsOfPredictionDetails(ClassicalPredictionModelDetails details) {
        details.modeling = this.getMinimalPretrainModelingParams();
    }

    public PreTrainPredictionModelingParams getMinimalPretrainModelingParams() {
        PreTrainPredictionModelingParams modeling = new PreTrainPredictionModelingParams();
        modeling.algorithm = this.isProxyModel() ? PreTrainPredictionModelingParams.Algorithm.VIRTUAL_PROXY_MODEL : PreTrainPredictionModelingParams.Algorithm.VIRTUAL_MLFLOW_PYFUNC;
        modeling.forcedClassifierThreshold = this.binaryClassificationThreshold;
        if (this.metricParams == null) {
            modeling.metrics = new MetricParams();
            modeling.metrics.costMatrixWeights = new MetricParams.CostMatrixWeights();
        } else {
            modeling.metrics = this.metricParams;
        }
        return modeling;
    }

    public void fillMinimalPreprocessingParamsOfPredictionDetails(ClassicalPredictionModelDetails details) {
        details.preprocessing = this.getMinimalClassicalPredictionPreprocessingParams();
    }

    public ResolvedClassicalPredictionPreprocessingParams getMinimalClassicalPredictionPreprocessingParams() {
        ResolvedClassicalPredictionPreprocessingParams preprocessing = new ResolvedClassicalPredictionPreprocessingParams();
        if (this.predictionType != null) {
            switch (this.predictionType) {
                case BINARY_CLASSIFICATION: 
                case MULTICLASS: {
                    preprocessing.target_remapping = IntStream.range(0, this.classLabels.size()).mapToObj(i -> {
                        PredictionPreprocessingParams.MappingValue mv = new PredictionPreprocessingParams.MappingValue();
                        mv.sourceValue = this.classLabels.get((int)i).label;
                        mv.mappedValue = i;
                        return mv;
                    }).collect(Collectors.toList());
                    break;
                }
            }
        }
        if (this.features != null) {
            preprocessing.per_feature = new LinkedHashMap();
            this.features.forEach(f -> {
                if (f.getName().equals(this.targetColumnName)) {
                    return;
                }
                if (f.getType().isNumeric() || f.getType().isTemporal()) {
                    NumFeaturePreprocessingParams p = new NumFeaturePreprocessingParams();
                    p.name = f.getName();
                    p.type = FeaturePreprocessingParams.FeatureType.NUMERIC;
                    p.role = FeaturePreprocessingParams.Role.INPUT;
                    p.numerical_handling = NumFeaturePreprocessingParams.NumericalHandlingMethod.REGULAR;
                    p.rescaling = NumFeaturePreprocessingParams.RescalingMethod.NONE;
                    preprocessing.per_feature.put(p.name, p);
                } else {
                    CatFeaturePreprocessingParams p = new CatFeaturePreprocessingParams();
                    p.name = f.getName();
                    p.type = FeaturePreprocessingParams.FeatureType.CATEGORY;
                    p.role = FeaturePreprocessingParams.Role.INPUT;
                    preprocessing.per_feature.put(p.name, p);
                }
            });
            if (StringUtils.isNotEmpty((String)this.targetColumnName) && this.predictionType != null) {
                FeaturePreprocessingParams p;
                switch (this.predictionType) {
                    case REGRESSION: {
                        p = new NumFeaturePreprocessingParams();
                        p.name = this.targetColumnName;
                        p.type = FeaturePreprocessingParams.FeatureType.NUMERIC;
                        p.role = FeaturePreprocessingParams.Role.TARGET;
                        break;
                    }
                    case BINARY_CLASSIFICATION: 
                    case MULTICLASS: {
                        p = new CatFeaturePreprocessingParams();
                        p.name = this.targetColumnName;
                        p.type = FeaturePreprocessingParams.FeatureType.CATEGORY;
                        p.role = FeaturePreprocessingParams.Role.TARGET;
                        break;
                    }
                    default: {
                        throw new NotImplementedException("Unhandled prediction type: " + String.valueOf((Object)this.predictionType));
                    }
                }
                preprocessing.per_feature.put(this.targetColumnName, p);
            }
        }
        return preprocessing;
    }

    @Nullable
    public ExternalEndpointIdentifier getExternalEndpointIdentifier() {
        if (this.proxyModelVersionConfiguration == null) {
            return null;
        }
        switch (this.proxyModelVersionConfiguration.protocol) {
            case "sagemaker": {
                if (this.proxyModelEndpointInfo == null || this.proxyModelEndpointInfo instanceof ProxyModelVersionConfiguration.ErrorEndpointInfo) {
                    return null;
                }
                return new SagemakerEndpointIdentifier(((SageMakerUtils.DSSSageMakerConsolidatedEndpointInfo)this.proxyModelEndpointInfo).arn);
            }
            case "azure-ml": {
                AzureMLProxyModelVersionConfiguration azureMLProxyModelVersionConfiguration = (AzureMLProxyModelVersionConfiguration)this.proxyModelVersionConfiguration;
                AzureMLProxyConfiguration azureMLProxyConfiguration = (AzureMLProxyConfiguration)azureMLProxyModelVersionConfiguration.proxyModelConfiguration;
                return new AzureMLEndpointIdentifier(azureMLProxyModelVersionConfiguration.endpointName, azureMLProxyConfiguration.workspace, azureMLProxyConfiguration.resourceGroup, azureMLProxyConfiguration.subscriptionId);
            }
            case "vertex-ai": {
                VertexAIProxyModelVersionConfiguration vertexAIProxyModelVersionConfiguration = (VertexAIProxyModelVersionConfiguration)this.proxyModelVersionConfiguration;
                VertexAIProxyConfiguration vertexAIProxyConfiguration = (VertexAIProxyConfiguration)vertexAIProxyModelVersionConfiguration.proxyModelConfiguration;
                return new VertexAIEndpointIdentifier(vertexAIProxyConfiguration.project_id, vertexAIProxyConfiguration.region, vertexAIProxyModelVersionConfiguration.endpoint_id);
            }
            case "databricks": {
                if (this.proxyModelEndpointInfo == null || this.proxyModelEndpointInfo instanceof ProxyModelVersionConfiguration.ErrorEndpointInfo) {
                    return null;
                }
                DatabricksProxyModelVersionConfiguration.DSSDatabricksConsolidatedEndpointInfo dbxProxyModelEndpointInfo = (DatabricksProxyModelVersionConfiguration.DSSDatabricksConsolidatedEndpointInfo)this.proxyModelEndpointInfo;
                return new DatabricksEndpointIdentifier(dbxProxyModelEndpointInfo.id, dbxProxyModelEndpointInfo.host, dbxProxyModelEndpointInfo.endpointName);
            }
        }
        return null;
    }

    public boolean isProxyModel() {
        return this.proxyModelVersionConfiguration != null && this.proxyModelVersionConfiguration.proxyModelConfiguration != null;
    }

    public static class ClassLabel {
        public String label;
    }

    public static enum ModelOutputStyle {
        AUTO_DETECT,
        PREDICTION_ONLY;

    }
}

