(function(){
'use strict';

let app = angular.module('dataiku.analysis.mlcore');

app.service('PMLTrainTestPolicies', function() {
    /**
     * Train/test policies for prediction ML tasks, as [policyId, label] pairs.
     * Keep this list in the same order as the descriptions built by describePolicies().
     */
    const policies = [
        ["SPLIT_MAIN_DATASET", "Split the dataset"],
        ["EXPLICIT_FILTERING_SINGLE_DATASET_MAIN", "Explicit extracts from the dataset"],
        ["EXPLICIT_FILTERING_TWO_DATASETS", "Explicit extracts from two datasets"],
        ["SPLIT_OTHER_DATASET", "Split another dataset"],
        ["EXPLICIT_FILTERING_SINGLE_DATASET_OTHER", "Explicit extracts from another dataset"]
    ];

    /**
     * Builds one long description per train/test policy, in the same order as `policies`.
     * @param {string} inputDatasetSmartName - name of the input dataset, interpolated into some descriptions
     * @returns {string[]} a new array of descriptions on every call
     */
    function describePolicies(inputDatasetSmartName) {
        const dataset = inputDatasetSmartName;
        const descriptions = [];
        descriptions.push(`Split a subset of ${dataset}`);
        descriptions.push(`Use two extracts from ${dataset}, one for train, one for test`);
        descriptions.push(`Use two extracts from two different datasets, one for train, one for test`);
        descriptions.push(`Split a subset of another dataset, compatible with ${dataset}`);
        descriptions.push(`Use two extracts from another dataset, one for train, one for test`);
        return descriptions;
    }

    return {
        trainTestPolicies: policies,
        trainTestPoliciesDesc: describePolicies
    };
});

app.service("EmbeddingService", function(ActiveProjectKey, DataikuAPI) {

    /**
     * Lists the LLMs available in the active project for the given purpose.
     *
     * @param {string} purpose - LLM purpose forwarded to the API (this service uses
     *     "TEXT_EMBEDDING_EXTRACTION" and "IMAGE_EMBEDDING_EXTRACTION").
     * @param {Function} [onError] - called with the failure when the API call, or reading its response, fails.
     * @returns {Promise} resolves with `response.data` on success, or with null once `onError` has been called.
     *     When no `onError` is given, the failure is re-thrown rather than silently dropped.
     */
    const _getAvailableModels = (purpose, onError) => {
        // Chain on the API promise instead of wrapping it in `new Promise(...)`: the former wrapper
        // never settled when the call failed, so anything chained on it could wait forever.
        return DataikuAPI.pretrainedModels.listAvailableLLMs(ActiveProjectKey.get(), purpose)
            .then(response => response.data)
            .catch(error => {
                if (typeof onError !== 'function') {
                    throw error;
                }
                onError(error);
                return null;
            });
    };

    /**
     * Starts loading the available text and image embedding LLMs and returns synchronous lookups over them.
     * Each lookup returns null until its list has been loaded, and keeps returning null if loading failed.
     *
     * @param {Function} [onError] - forwarded to both API calls
     * @returns {{getSentenceEmbedding: Function, getImageEmbedding: Function}}
     */
    const metadataFetcher = (onError) => {
        let availableTextEmbeddingModels = null;
        let availableImageEmbeddingModels = null;
        _getAvailableModels("TEXT_EMBEDDING_EXTRACTION", onError).then(available => {
            availableTextEmbeddingModels = available;
        });
        _getAvailableModels("IMAGE_EMBEDDING_EXTRACTION", onError).then(available => {
            availableImageEmbeddingModels = available;
        });

        /**
         * Describes the sentence embedding model configured for a text feature.
         * @param {Object} featurePreprocessing - reads `sentenceEmbeddingModel` and `maxSequenceLength`
         * @returns {?{friendlyName: string, tokensLimit: *}} null while the text embedding LLM list is not loaded
         */
        function getSentenceEmbedding(featurePreprocessing) {
            if (!availableTextEmbeddingModels) { // Not initialized
                return null;
            }
            const modelId = featurePreprocessing.sentenceEmbeddingModel;
            const llm = availableTextEmbeddingModels.identifiers.find(candidate => candidate.id === modelId);
            if (llm) {
                return {
                    friendlyName : llm.friendlyName,
                    tokensLimit : llm.maxTokensLimit
                };
            }
            // legacy / code env model
            return {
                friendlyName : modelId + " (Code env based)",
                tokensLimit : featurePreprocessing.maxSequenceLength // for code env resource models the effective tokens limit might have been decreased compared to the model maximum limit
            };
        }

        /**
         * Describes the pretrained model configured for an image feature.
         * Falls back to the raw model id when it is not among the available image embedding LLMs.
         * @param {Object} featurePreprocessing - reads `pretrainedModelsParams.structureRefId`
         * @returns {?{friendlyName: string}} null while the image embedding LLM list is not loaded
         */
        function getImageEmbedding(featurePreprocessing) {
            if (!availableImageEmbeddingModels) { // Not initialized
                return null;
            }
            const modelId = featurePreprocessing.pretrainedModelsParams.structureRefId;
            const model = availableImageEmbeddingModels.identifiers.find(candidate => candidate.id === modelId);
            return {
                friendlyName: model ? model.friendlyName : modelId
            };
        }

        return {
            getSentenceEmbedding,
            getImageEmbedding
        };
    };

    return {
        metadataFetcher
    };
});

app.service('PMLSettings', function($filter, PMLTrainTestPolicies, DeepHubMetricsService, CustomMetricIDService, Logger) {
    const cst = {
        taskF: function(backend) { return {
            bcEvaluationMetrics: [
                ["ACCURACY", "Accuracy"],
                ["PRECISION", "Precision"],
                ["RECALL", "Recall"],
                ["F1", "F1-score"],
                ["COST_MATRIX", "Cost Matrix"],
                ["ROC_AUC", "ROC AUC"],
                ["CUMULATIVE_LIFT", "Lift"],
                ["AVERAGE_PRECISION", "Average Precision"],
                ["LOG_LOSS", "Log Loss"]
            ],
            mcEvaluationMetrics: [
                ["ACCURACY", "Accuracy"],
                ["PRECISION", "Precision"],
                ["RECALL", "Recall"],
                ["F1", "F1-score"],
                ["ROC_AUC", "ROC AUC"],
                ["AVERAGE_PRECISION", "Average Precision"],
                ["LOG_LOSS", "Log Loss"]
            ],
            regressionEvaluationMetrics: [
                ["EVS", "Explained Variance Score"],
                ["MAPE", "Mean Absolute Percentage Error"],
                ["MAE", "Mean Absolute Error"],
                ["MSE", "Mean Squared Error"],
                ["RMSE", "Root Mean Square Error"],
                ["RMSLE", "Root Mean Square Logarithmic Error"],
                ["R2", "R2 Score"]
            ],
            timeseriesEvaluationMetrics: [
                ["MASE", "Mean Absolute Scaled Error"],
                ["MAPE", "Mean Absolute Percentage Error"],
                ["SMAPE", "Symmetric Mean Absolute Percentage Error"],
                ["MAE", "Mean Absolute Error"],
                ["MEAN_ABSOLUTE_QUANTILE_LOSS", "Mean Absolute Quantile Loss"],
                ["MEAN_WEIGHTED_QUANTILE_LOSS", "Mean Weighted Quantile Loss"],
                ["MSE", "Mean Squared Error"],
                ["RMSE", "Root Mean Square Error"],
                ["MSIS", "Mean Scaled Interval Score"],
                ["ND", "Normalized Deviation"],
            ],
            causalEvaluationMetrics: [
                ["AUUC", "Area Under the Uplift Curve"],
                ["NET_UPLIFT", "Net Uplift at Specified Level"],
                ["QINI", "Qini Score"]
            ],
            crossvalModesRandom: function(backend){
                var base = [
                    ["SHUFFLE", "Simple train/validation split"],
                    ["KFOLD", "K-fold"]
                ]
                if (backend === "PY_MEMORY") {
                    base.push(["CUSTOM", "Custom code"]);
                }
                return base;
            }(backend),
            getCrossvalModesRandomForDocumentation(crossValidationStrategy, mlTaskDesign) {
                let crossvalModeDoc;

                switch (crossValidationStrategy) {
                    case 'SHUFFLE':
                        crossvalModeDoc = 'Simple train/validation split';
                    break;

                    case 'KFOLD':
                        crossvalModeDoc = `${mlTaskDesign.modeling.gridSearchParams.nFolds}-fold cross-validation`;
                    break;

                    case 'CUSTOM':
                        crossvalModeDoc = 'Custom Code';
                    break;
                }

                return crossvalModeDoc;
            },
            crossvalModesWithTime: [["TIME_SERIES_SINGLE_SPLIT", "Time-based train/validation split"],
                                    ["TIME_SERIES_KFOLD", "Time-based K-fold (with overlap)"],
                                    ["CUSTOM", "Custom code"]],
            getCrossvalModesWithTimeForDocumentation(crossValidationStrategy, mlTaskDesign) {
                let crossvalModeDoc;

                switch (crossValidationStrategy) {
                    case 'TIME_SERIES_SINGLE_SPLIT':
                        crossvalModeDoc = 'Simple train/validation split';
                    break;

                    case 'TIME_SERIES_KFOLD':
                        crossvalModeDoc = `Time-based ${mlTaskDesign.modeling.gridSearchParams.nFolds}-fold (with overlap)`;
                    break;

                    case 'CUSTOM':
                        crossvalModeDoc = 'Custom code';
                    break;
                }

                return crossvalModeDoc;
            }
        }}, task: {
            predictionTypes: [
                // tabular types only
                {
                    type: "BINARY_CLASSIFICATION",
                    fullName: "Two-class classification",
                    shortName: "binary",
                    classical: true,
                },
                {
                    type: "MULTICLASS",
                    fullName: "Multiclass classification",
                    shortName: "multiclass",
                    classical: true,
                },
                {
                    type: "REGRESSION",
                    fullName: "Regression",
                    shortName: "regression",
                    classical: true,
                },
                {
                    type: "TIMESERIES_FORECAST",
                    fullName: "Time series forecast",
                    shortName: "forecast",
                    forecast: true,
                },
                {
                    type: "CAUSAL_BINARY_CLASSIFICATION",
                    fullName: "Causal Classification",
                    shortName: "causal classification",
                    causal: true,
                },
                {
                    type: "CAUSAL_REGRESSION",
                    fullName: "Causal Regression",
                    shortName: "causal regression",
                    causal: true,
                },
                {
                    type: null,
                    fullName: "Other prediction",
                    shortName: "other",
                    other: true,
                }
            ],
            isClassification: function(predictionType) {
                return ["BINARY_CLASSIFICATION", "MULTICLASS", "CAUSAL_BINARY_CLASSIFICATION"].includes(predictionType);
            },
            getDefaultEvaluationMetric: function(predictionType) {
                switch (predictionType) {
                    case 'BINARY_CLASSIFICATION':
                    case 'MULTICLASS':
                        return "ROC_AUC";
                    case 'REGRESSION':
                        return 'R2';
                    case 'TIMESERIES_FORECAST':
                        return 'MASE';
                }
            },
            thresholdOptimizationMetrics: [
                ["ACCURACY", "Accuracy"],
                ["F1", "F1-score"],
                ["COST_MATRIX", "Cost Matrix"]
            ],
            trainTestPolicies: PMLTrainTestPolicies.trainTestPolicies,
            trainTestPoliciesDesc: PMLTrainTestPolicies.trainTestPoliciesDesc,
            splitModes: [ ["RANDOM", "Randomly"],
                          ["SORTED", "Based on time variable"] ],
            calibrationMethods: [['NO_CALIBRATION', 'None'], ['SIGMOID', 'Sigmoid (Platt scaling)'], ['ISOTONIC', 'Isotonic Regression']],
        }, names: {
            evaluationMetrics: function() {
                const metrics = {
                    PRECISION: "Precision",
                    RECALL: "Recall",
                    F1: "F1-score",
                    ACCURACY: "Accuracy",
                    EVS : "EVS",
                    MAPE : "MAPE",
                    MAE : "MAE",
                    MSE : "MSE",
                    RMSE: "RMSE",
                    RMSLE: "RMSLE",
                    R2: "R2 Score",
                    PEARSON: "Correlation",
                    COST_MATRIX: "Cost Matrix Gain",
                    LOG_LOSS: "Log Loss",
                    ROC_AUC: "ROC AUC",
                    AVERAGE_PRECISION: "Average Precision",
                    CALIBRATION_LOSS : "Calibration Loss",
                    CUMULATIVE_LIFT : "Lift",
                    DATA_DRIFT: 'Data drift',
                    DATA_DRIFT_PVALUE: 'Data drift p-value',
                    AUUC: "Area under Uplift Curve",
                    QINI: "Qini Score",
                    NET_UPLIFT: "Net Uplift at Specified Level",
                    MASE: "MASE",
                    SMAPE: "Symmetric MAPE",
                    MEAN_ABSOLUTE_QUANTILE_LOSS: "Mean Absolute Quantile Loss",
                    MEAN_WEIGHTED_QUANTILE_LOSS: "Mean Weighted Quantile Loss",
                    MSIS: "MSIS",
                    ND: "Normalized Deviation",
                    WORST_MASE: "Worst MASE",
                    WORST_MAPE: "Worst MAPE",
                    WORST_SMAPE: "Worst SMAPE",
                    WORST_MAE: "Worst MAE",
                    WORST_MSE: "Worst MSE",
                    WORST_MSIS: "Worst MSIS",
                    MIN_KS: "Min KS",
                    MIN_CHISQUARE: "Min Chi-square",
                    MAX_PSI: "Max PSI",
                    PREDICTION_DRIFT_PSI: "Prediction drift PSI",
                    PREDICTION_DRIFT_KS: "Prediction drift KS",
                    PREDICTION_DRIFT_CHISQUARE: "Prediction drift Chi-square",
                    FAITHFULNESS: "Faithfulness",
                    MULTIMODAL_FAITHFULNESS: "Multimodal faithfulness",
                    ANSWER_RELEVANCY: "Answer relevancy",
                    MULTIMODAL_RELEVANCY: "Multimodal relevancy",
                    ANSWER_SIMILARITY: "Answer similarity",
                    ANSWER_CORRECTNESS: "Answer correctness",
                    CONTEXT_RECALL: "Context recall",
                    CONTEXT_PRECISION: "Context precision",
                    BERT_SCORE_PRECISION: 'BERT Score - Precision',
                    BERT_SCORE_RECALL: 'BERT Score - Recall',
                    BERT_SCORE_F1: 'BERT Score - F1 Score',
                    BLEU: 'BLEU',
                    ROUGE_1_PRECISION: 'ROUGE-1 - Precision',
                    ROUGE_1_RECALL: 'ROUGE-1 - Recall',
                    ROUGE_1_F1: 'ROUGE-1 - F1 Score',
                    ROUGE_2_PRECISION: 'ROUGE-2 - Precision',
                    ROUGE_2_RECALL: 'ROUGE-2 - Recall',
                    ROUGE_2_F1: 'ROUGE-2 - F1 Score',
                    ROUGE_L_PRECISION: 'ROUGE-L - Precision',
                    ROUGE_L_RECALL: 'ROUGE-L - Recall',
                    ROUGE_L_F1: 'ROUGE-L - F1 Score',
                    INPUT_TOKENS_PER_ROW: 'Average Input tokens',
                    OUTPUT_TOKENS_PER_ROW: 'Average Output tokens',
                    AVERAGE_TOOL_EXECUTIONS_PER_ROW: 'Average number of tool calls',                                  // Global metric
                    AVERAGE_FAILED_TOOL_EXECUTIONS_PER_ROW: 'Average number of failed tool calls',                    // Global metric
                    AVERAGE_TOOL_EXECUTION_TIME_SECONDS_PER_ROW: 'Average tool call execution time (s)',              // Global metric
                    P95_TOTAL_AGENT_CALL_EXECUTION_TIME_SECONDS_PER_ROW: 'P95 total agent call execution time (s)',   // Global metric
                    SAMPLE_ROW_COUNT: 'Sample row count',                                                             // Global metric
                    TOTAL_TOOL_EXECUTIONS: 'Total tool calls',                                                        // Row-by-row metric
                    TOTAL_FAILED_TOOL_EXECUTIONS: 'Total failed tool calls',                                          // Row-by-row metric
                    TOTAL_TOOL_EXECUTION_TIME_SECONDS: 'Total tool calls execution time (s)',                         // Row-by-row metric
                    TOTAL_AGENT_CALL_EXECUTION_TIME_SECONDS: 'Total agent execution time',                            // Row-by-row metric
                    TOOL_CALL_EXACT_MATCH: 'Median of Tool Call Exact match',
                    TOOL_CALL_PARTIAL_MATCH: 'Median of Tool Call Partial match',
                    TOOL_CALL_PRECISION: 'Median of Tool Call Precision',
                    TOOL_CALL_RECALL: 'Median of Tool Call Recall',
                    TOOL_CALL_F1: 'Median of Tool Call F1 Score',
                    AGENT_GOAL_ACCURACY_WITH_REFERENCE: 'Average Agent Goal Accuracy with reference',
                    AGENT_GOAL_ACCURACY_WITHOUT_REFERENCE: 'Average Agent Goal Accuracy without reference',
                };
                // Adding Deep Hub metrics
                Object.entries(DeepHubMetricsService.metricNameToDescriptionMap()).forEach( element => {
                    if (element[0] in metrics) { // Do not override existing metrics
                        if (metrics[element[0]] !== element[1]) {
                            Logger.warn(`Deephub description for ${element[0]} (${element[1]}) does not match metrics (${metrics[element[0]]})`);
                        }
                    } else {
                        metrics[element[0]] = element[1];
                    }
                });

                return metrics;
            } ()
        }, sort: {
            lowerBetter: [
                'MAE',
                'MSE',
                'RMSE',
                'RMSLE',
                'LOG_LOSS',
                'HAMMINGLOSS',
                'MAPE',
                'MASE',
                'SMAPE',
                'MEAN_ABSOLUTE_QUANTILE_LOSS',
                'MEAN_WEIGHTED_QUANTILE_LOSS',
                'MSIS',
                'ND',
                'CALIBRATION_LOSS'
            ]
        }, normalizedMetrics: [ // metrics that are between 0 and 1
            "ROC_AUC", "AVERAGE_PRECISION", "PRECISION", "RECALL", "F1", "ACCURACY", "EVS", "R2",  "CALIBRATION_LOSS"
        ], algorithmCategories: function(predictionType) {
            if (["BINARY_CLASSIFICATION", "MULTICLASS", "REGRESSION", "CAUSAL_BINARY_CLASSIFICATION", "CAUSAL_REGRESSION"].includes(predictionType)) {
                return {
                    "Linear Models": ['RIDGE_REGRESSION', 'LASSO_REGRESSION', 'LEASTSQUARE_REGRESSION', 'LOGISTIC_REGRESSION', 'GLM_H2O', 'MLLIB_LOGISTIC_REGRESSION', 'MLLIB_LINEAR_REGRESSION', 'SPARKLING_GLM'],
                    "Random Forests": ['RANDOM_FOREST_REGRESSION', 'RANDOM_FOREST_CLASSIFICATION', 'DISTRIBUTED_RF_H2O', 'MLLIB_RANDOM_FOREST', 'SPARKLING_RF'],
                    "Support Vector Machines": ['SVC_CLASSIFICATION', 'SVM_REGRESSION'],
                    "Stochastic Gradient Descent": ['SGD_CLASSIFICATION', 'SGD_REGRESSION'],
                    "Gradient Boosting": ['GBM_H2O', 'GBT_CLASSIFICATION', 'GBT_REGRESSION', 'LIGHTGBM_CLASSIFICATION', 'LIGHTGBM_REGRESSION', 'XGBOOST_CLASSIFICATION', 'XGBOOST_REGRESSION', 'SPARKLING_GBM', 'MLLIB_GBT'],
                    "Decision Tree": ['DECISION_TREE_CLASSIFICATION', 'DECISION_TREE_REGRESSION', 'MLLIB_DECISION_TREE'],
                    "Others": true
                }
            }
            if (predictionType === "TIMESERIES_FORECAST") {
                return {
                    "Baseline Models": ['TRIVIAL_IDENTITY_TIMESERIES', 'SEASONAL_NAIVE'],
                    "Statistical Models": ['AUTO_ARIMA', 'CROSTON', 'SEASONAL_LOESS', 'PROPHET', 'GLUONTS_NPTS_FORECASTER'],
                    "Deep Learning Models": ['GLUONTS_SIMPLE_FEEDFORWARD', 'GLUONTS_DEEPAR', 'GLUONTS_TRANSFORMER', 'GLUONTS_MQCNN'],
                }
            }
        }, noDollarKey: function(k) {
            return !k.startsWith('$') && k != "_name" && k != "datasetColumnId" && k != "userModified";
        }, defaultCustomCode: function(backendType, isRegression, targetVariable) {
            let code;
            if (backendType === "PY_MEMORY") {
                if (isRegression) {
                    code = "# This sample code uses a standard scikit-learn algorithm, the Adaboost regressor.\n\n" +
                        "# Your code must create a 'clf' variable. This clf must be a scikit-learn compatible\n" +
                        "# model, ie, it should:\n" +
                        "#  1. have at least fit(X,y) and predict(X) methods\n" +
                        "#  2. inherit sklearn.base.BaseEstimator\n" +
                        "#  3. handle the attributes in the __init__ function\n" +
                        "#     See: https://doc.dataiku.com/dss/latest/machine-learning/custom-models.html\n\n" +
                        "from sklearn.ensemble import AdaBoostRegressor\n\n"+
                        "clf = AdaBoostRegressor(n_estimators=20)\n"
                } else {
                    code = "# This sample code uses a standard scikit-learn algorithm, the Adaboost classifier.\n\n" +
                        "# Your code must create a 'clf' variable. This clf must be a scikit-learn compatible\n" +
                        "# classifier, ie, it should:\n" +
                        "#  1. have at least fit(X,y) and predict(X) methods\n" +
                        "#  2. inherit sklearn.base.BaseEstimator\n" +
                        "#  3. handle the attributes in the __init__ function\n" +
                        "#  4. have a classes_ attribute\n" +
                        "#  5. have a predict_proba method (optional)\n" +
                        "#     See: https://doc.dataiku.com/dss/latest/machine-learning/custom-models.html\n\n" +
                        "from sklearn.ensemble import AdaBoostClassifier\n\n" +
                        "clf = AdaBoostClassifier(n_estimators=20)\n"
                }
            } else if (backendType === "MLLIB") {
                if (isRegression) {
                    code = "// This sample code uses a standard MLlib algorithm, the RandomForestRegressor.\n\n" +
                        "// import the Estimator from spark.ml\n" +
                        "import org.apache.spark.ml.regression.RandomForestRegressor\n\n" +
                        "// instantiate the Estimator\n" +
                        "new RandomForestRegressor()\n" +
                        "   .setLabelCol(\"" + targetVariable + "\")  // Must be the target column\n" +
                        "   .setFeaturesCol(\"__dku_features\")  // Must always be __dku_features\n" +
                        "   .setPredictionCol(\"prediction\")  // Must always be prediction\n" +
                        "   .setNumTrees(50)\n" +
                        "   .setMaxDepth(8)";
                } else {
                    code = "// This sample code uses a standard MLlib algorithm, the RandomForestClassifier.\n\n" +
                        "// import the Estimator from spark.ml\n" +
                        "import org.apache.spark.ml.classification.RandomForestClassifier\n\n" +
                        "// instantiate the Estimator\n" +
                        "new RandomForestClassifier()\n" +
                        "   .setLabelCol(\"" + targetVariable + "\")  // Must be the target column\n" +
                        "   .setFeaturesCol(\"__dku_features\")  // Must always be __dku_features\n" +
                        "   .setPredictionCol(\"prediction\")    // Must always be prediction\n" +
                        "   .setNumTrees(50)\n" +
                        "   .setMaxDepth(8)";
                }
            }
            return code;
        }, isSpecialFeature: function(featParams) {
            // ONLY FOR KERAS (DEEP LEARNING) BACKEND

            if (!featParams || featParams.role === "REJECT") {
                return false;
            }

            var featType = featParams.type;

            if (featType === "TEXT") {

                let handling = featParams.text_handling;
                if (handling === "CUSTOM") {
                    return true;
                }
            }

            if (featType === "IMAGE") {

                let handling = featParams.image_handling;
                if (handling === "CUSTOM") {
                    return true;
                }
            }

            return false;
        }, hpPrettyName: function(hpAlgorithm, hpName) {
        const prettyNames = {
            "logit": {
                "penalty": "Penalty",
                "multi_class": "Multi class mode",
                "C": "C"
            },
            "rf": {
                "estimators": "Number of trees",
                "max_tree_depth": "Max trees depth",
                "selection_mode": "Feature sampling strategy",
                "min_samples_leaf": "Min samples per leaf",
                "max_features": "Used features",
                "max_feature_prop": "Used features",
                "allow_sparse_matrices": "Sparse matrices allowed"
            },
            "extra_trees": {
                "estimators": "Number of trees",
                "max_tree_depth": "Max trees depth",
                "selection_mode": "Feature sampling strategy",
                "min_samples_leaf": "Min samples per leaf",
                "max_features": "Used features",
                "max_feature_prop": "Used features"
            },
            "xgboost": {
                "missing": "Value treated as missing",
                "booster": "Booster",
                "silent": "Silent",
                "objective": "Objective function",
                "nthread": "Number of threads",
                "seed": "Seed",
                "impute_missing": "Custom missing value",
                "n_estimators": "Actual number of trees",
                "max_depth": "Max trees depth",
                "learning_rate": "Eta (learning rate)",
                "alpha": "Alpha (L1 regularization)",
                "lambda": "Lambda (L2 regularization)",
                "gamma": "Gamma (Min loss reduction to split a leaf)",
                "min_child_weight": "Min sum of instance weight in a child",
                "max_delta_step": "Max delta step",
                "subsample": "Subsample ratio of the training instance",
                "colsample_bytree": "Columns subsample ratio for trees",
                "colsample_bylevel": "Columns subsample ratio for splits / levels",
                "scale_pos_weight": "Balancing of positive and negative weights",
                "base_score": "Global bias (initial prediction score)",
                "enable_early_stopping": "Early stopping",
                "early_stopping_rounds": "Early stopping: max rounds before early stop",
                "allow_sparse_matrices": "Sparse matrices allowed",
                "tweedie_variance_power": "Tweedie variance power"
            },
            "svm": {
                "kernel": "Kernel",
                "C": "C",
                "coef0": "Independent kernel term",
                "tol": "Stopping tolerance",
                "max_iter": "Max iterations",
                "custom_gamma": "Kernel coef (gamma)",
                "gamma": "Kernel coef (gamma)"
            },
            "dt": {
                "max_depth": "Maximum depth",
                "criterion": "Split criterion",
                "min_samples_leaf": "Min. samples per leaf",
                "splitter": "Split strategy"
            },
            "ridge": {
                "alpha": "Alpha"
            },
            "lasso": {
                "alpha": "Alpha"
            },
            "mllib_logit": {
                "reg_param": "Lambda (regularization param)",
                "enet_param": "Alpha (Elastic net param)"
            },
            "mllib_linreg": {
                "reg_param": "Lambda (regularization param)",
                "enet_param": "Alpha (Elastic net param)"
            },
            "mllib_rf": {
                "num_trees": "Number of trees",
                "impurity": "Impurity function",
                "max_bins": "Max. bin used",
                "min_info_gain": "Min. information increment",
                "max_depth": "Maximum depth of tree",
                "step_size": "Step size",
                "subsample_rate": "Ratio data used",
                "subset_strategy": "Subset strategy",
                "min_instance_per_node": "Min. instance per node"
            },
            "mllib_gbt": {
                "num_trees": "Number of trees",
                "impurity": "Impurity function",
                "max_bins": "Max. bin used",
                "min_info_gain": "Min. information increment",
                "max_depth": "Maximum depth of tree",
                "step_size": "Step size",
                "subsample_rate": "Ratio data used",
                "min_instance_per_node": "Min. instance per node"
            },
            "mllib_dt": {
                "max_depth": "Max. depth of tree",
                "max_bins": "Max. bin used",
                "min_info_gain": "Mim. information increment",
                "min_instance_per_node": "Min. instance per node"
            },
            "mllib_naive_bayes": {
                "lambda": "Lambda"
            },
            "gbt": {
                "n_estimators": "Number of boosting stages",
                "learning_rate": "Eta (learning rate)",
                "max_depth": "Max trees depth",
                "max_features": "Number of features",
                "max_features_prop": "Number of features",
                "min_samples_leaf": "Minimum samples per leaf",
                "loss": "Loss",
                "selection_mode": "Feature sampling strategy",
                "allow_sparse_matrices": "Sparse matrices allowed"
            },
            "knn": {
                "k": "K",
                "leaf_size": "Leaf size",
                "p": "p",
                "distance_weighting": "Distance weighting",
                "algorithm": "Algorithm"
            },
            "sgd_grid": {
                "tol": "Stopping tolerance",
                "max_iter": "Max iterations",
                "alpha": "Alpha",
                "l1_ratio": "L1 mixim parameter",
                "loss": "Loss",
                "penalty": "Regularization",
                "n_iter": "Number of iterations",
                "epsilon": "Epsilon"
            },
            "sgd": {
                "alpha": "Alpha",
                "l1_ratio": "L1 mixim parameter",
                "loss": "Loss",
                "penalty": "Regularization",
                "n_iter": "Number of iterations",
                "epsilon": "Epsilon"
            },
            "mllib_logit_grid": {
                "max_iter": "Max iterations"
            },
            "mllib_linreg_grid": {
                "max_iter": "Max iterations"
            },
            "mllib_rf_grid": {
                "subset_strategy": "Feature subset strategy",
                "impurity": "Impurity",
                "max_bins": "Maximum number of bins",
                "max_memory_mb": "Maximum memory",
                "checkpoint_interval": "Check point interval",
                "cache_node_ids": "Cache node IDs",
                "min_info_gain": "Minimum information gain",
                "min_instance_per_node": "Minimum instance per node",
                "subsampling_rate": "Subsampling rate",
                "seed": "Subsampling seed"
            },
            "keras": {
                "epochs": "Number of epochs",
                "oneDimensionalOutput": "1-D output"
            },
            "mllib_gbt_grid": {
                "impurity": "Impurity",
                "max_bins": "Maximum number of bins",
                "max_memory_mb": "Maximum memory",
                "checkpoint_interval": "Check point interval",
                "cache_node_ids": "Cache node IDs",
                "min_info_gain": "Minimum information gain",
                "min_instance_per_node": "Minimum instance per node",
                "subsampling_rate": "Subsampling rate",
                "step_size": "Step size",
                "seed": "Subsampling seed"
            },
            "mllib_dt_grid": {
                "cache_node_ids": "Ids of cache nodes",
                "checkpoint_interval": "Checkpoint interval",
                "max_bins": "Maximum number of bins",
                "max_memory_mb": "Maximum memory",
                "min_info_gain": "Minimum information gain",
                "min_instance_per_node": "Minimum instance per node"
            },
            "neural_network_grid": {
                "alpha": "Alpha",
                "max_iter": "Max iterations",
                "tol": "Convergence tolerance",
                "validation_fraction": "Validation fraction",
                "learning_rate_init": "Intial Learning Rate",
                "batch_size": "Batch size",
                "beta_1": "beta_1",
                "beta_2": "beta_2",
                "epsilon": "epsilon",
                "power_t": "power_t",
                "momentum": "Momentum",
                "activation": "Activation",
                "early_stopping": "Early stopping",
                "solver": "Solver",
                "shuffle": "Shuffle data",
                "auto_batch": "Automatic batching",
                "learning_rate": "Learning rate annealing",
                "nesterovs_momentum": "Use Nesterov momentum"
            },
            "lars_grid": {
                "max_features": "Max number of features"
            },
            "lars": {
                "max_features": "Max number of features",
                "K": "K"
            },
            "vertica_linreg_grid": {
                "maxIterations": "Max number of iterations"
            },
            "lightgbm": {
                "boosting_type": "Booster",
                "n_estimators": "Actual number of trees",
                "max_depth": "Maximum Tree Depth",
                "num_leaves": "Maximum number of leaves",
                "learning_rate": "Learning rate",
                "reg_alpha": "Alpha (L1 regularization)",
                "reg_lambda": "Lambda (L2 regularization)",
                "min_split_gain": "Minimal gain to perform a split",
                "min_child_weight": "Min sum of instance child weight",
                "subsample": "Ratio of the training instance",
                "objective": "Objective function",
                "subsample_freq": "Frequency for bagging",
                "subsample_for_bin": "Ratio for each discrete feature bin",
                "colsample_bytree": "Columns subsample ratio for trees",
                "random_state": "Random State",
                "n_jobs": "Number of jobs",
                "importance_type": "Feature importance type",
                "early_stopping": "Early stopping",
                "early_stopping_rounds": "Early stopping: max rounds before early stop",
                "allow_sparse_matrices": "Sparse matrices allowed"
            },
            "auto_arima_timeseries_params": {
                "m": "Season length"
            },
            "seasonal_loess_timeseries_params": {
                "period": "Season length"
            },
            "deep_neural_network": {
                "batch_size": "Batch size",
                "device": "Device",
                "epochs": "Number of epochs",
                "hidden_layers": "Hidden layers",
                "learning_rate": "Learning rate",
                "max_epochs": "Max number of epochs",
                "units": "Units per layer",
                "reg_l2": "Lambda (L2 regularization)",
                "reg_l1": "Alpha (L1 regularization)",
                "early_stopping_enabled": "Early stopping",
                "early_stopping_patience": "Early stopping: max rounds before early stop",
                "early_stopping_threshold": "Early stopping: threshold"
            }
        }
        return (prettyNames[hpAlgorithm] || {})[hpName] || $filter('niceConst')(hpName);
    }
    };
    /**
     * Tells whether, for the given metric, a lower value means a better model.
     *
     * @param {string} metricId - ID of the metric: a built-in metric id, 'CUSTOM' (the custom evaluation
     *                            metric, resolved through customEvaluationMetricName) or a custom metric id
     * @param {Object[]} customMetrics - list of custom metrics from modeling params
     * @param {string} [customEvaluationMetricName=""] - name of the custom evaluation metric, if set
     * @returns {boolean} true when lower values are better; false for unknown custom metrics
     */
    cst.sort.lowerIsBetter = function (metricId, customMetrics, customEvaluationMetricName="") {
        const resolvedMetricId = ('CUSTOM' === metricId)
            ? CustomMetricIDService.getCustomMetricId(customEvaluationMetricName)
            : metricId;

        if (CustomMetricIDService.checkMetricIsCustom(resolvedMetricId)) {
            const customMetricName = CustomMetricIDService.getCustomMetricName(resolvedMetricId);
            // Guard against a missing list so that it behaves like an empty one instead of throwing
            const metric = (customMetrics || []).find(customMetric => customMetric.name === customMetricName);
            // Unknown custom metric: fall back to "higher is better"
            return metric ? !metric.greaterIsBetter : false;
        }
        return cst.sort.lowerBetter.includes(resolvedMetricId);
    };

    /**
     * Tells whether `inputName` is a "special" input, i.e. an input that receives at least one
     * special feature (in practice, each special feature is sent to its own input).
     * @param {string} inputName - name of the Keras input
     * @param {Object} perFeature - per-feature preprocessing params, each carrying a `sendToInput` field
     * @returns {boolean} false when either argument is missing
     */
    cst.isSpecialInput = function (inputName, perFeature) {
        if (!inputName || !perFeature) {
            return false;
        }
        // Strict comparison: input names are strings, matching how sendToInput is compared elsewhere
        return Object.values(perFeature).some(f => f.sendToInput === inputName && cst.isSpecialFeature(f));
    };

    return cst;
});

app.service("PartitionedModelsService", function(PMLSettings, Logger) {

    const cst = {
        /**
         * Counts the partitions of a model snippet that are in one of the given states.
         * @param {Object} snippetData - model snippet; `snippetData.partitions.states` maps a state to a count
         * @param {...string} states - states to count
         * @returns {number} the summed count, 0 when the snippet carries no partition states
         */
        getPartitionsSnippetStateSize: (snippetData, ...states) => {
            if (!snippetData || !snippetData.partitions || !snippetData.partitions.states) {
                return 0;
            }

            return Object.entries(snippetData.partitions.states)
                .reduce((total, [state, amount]) => (states.includes(state) ? total + amount : total), 0);
        },

        /**
         * Total number of partitions, all states included.
         * @param {Object} snippetData - model snippet
         * @returns {number} 0 when the snippet carries no partition states
         */
        getTotalAmountOfPartitions: (snippetData) => {
            if (!snippetData || !snippetData.partitions || !snippetData.partitions.states) {
                return 0;
            }

            return Object.values(snippetData.partitions.states)
                .reduce((total, amount) => total + amount, 0);
        },

        /**
         * Current training step: the number of completed top-level steps, plus one when a step is
         * in progress (non-empty progress stack).
         * @param {Object} snippetData - model snippet; reads `trainInfo.progress`
         * @returns {number}
         */
        getCurrentStep: (snippetData) => {
            if (!snippetData || !snippetData.trainInfo || !snippetData.trainInfo.progress) {
                return 0;
            }

            const progress = snippetData.trainInfo.progress;
            return (progress.top_level_done || []).length + ((progress.stack || []).length ? 1 : 0);
        },

        /**
         * Total number of training steps: the current step plus the remaining top-level steps.
         * @param {Object} snippetData - model snippet; reads `trainInfo.progress`
         * @returns {number}
         */
        getStepCount: (snippetData) => {
            if (!snippetData || !snippetData.trainInfo || !snippetData.trainInfo.progress) {
                return 0;
            }

            return cst.getCurrentStep(snippetData) + (snippetData.trainInfo.progress.top_level_todo || []).length;
        },

        /**
         * Builds a CSS linear-gradient drawing a green bar proportional to the partition's metric.
         * For normalized "higher is better" metrics the metric itself is the bar ratio; otherwise the
         * metric is rescaled between the min and max finite values across partitions, onto
         * [minRatio, maxRatio].
         * @param {Object} snippetData - partitioned model snippet; reads `partitions.summaries`
         * @param {number} sortMainMetric - sort value of the partition being drawn
         * @param {string} currentMetric - id of the metric
         * @returns {string} CSS background value, or "none" when the metric is undefined
         */
        getPartitionResultMetricGradient: (snippetData, sortMainMetric, currentMetric) => {

            // We are dealing here with sort Metric that may be infinite (i.e. the corresponding metric
            // is undefined)
            if (sortMainMetric === undefined || Math.abs(sortMainMetric) === Number.MAX_VALUE) {
                return "none";
            }

            let ratio;
            if (PMLSettings.normalizedMetrics.includes(currentMetric)
               && !PMLSettings.sort.lowerBetter.includes(currentMetric)) {
                ratio = sortMainMetric;
            } else {
                const existingMetricsList = Object.values(snippetData.partitions.summaries)
                    .map(summary => summary.snippet.sortMainMetric)
                    // Remove infinite values that may have been introduced for sorting purposes
                    .filter(m => Math.abs(m) < Number.MAX_VALUE);
                const metricsMax = Math.max(...existingMetricsList);
                const metricsMin = Math.min(...existingMetricsList);
                const minRatio = 0.05;
                const maxRatio = 1;

                // An empty list would give max = -Infinity / min = +Infinity and a NaN ratio
                if (existingMetricsList.length === 0 || metricsMax === metricsMin) {
                    ratio = maxRatio;
                } else {
                    // Linear map of [metricsMin, metricsMax] onto [minRatio, maxRatio]: the lowest value
                    // keeps a small visible bar and the highest fills exactly 100%
                    ratio = minRatio + (maxRatio - minRatio) * (sortMainMetric - metricsMin) / (metricsMax - metricsMin);
                }
            }

            const greenBaseColor = "#29AF5D";
            const percent = ratio * 100;
            return 'linear-gradient(to right, '+ greenBaseColor +' 0%, ' + greenBaseColor + ' '+ percent +'%,rgba(0, 0, 0, 0) '+ percent +'%, rgba(0, 0, 0, 0) 100%)';
        },

        /**
         * Explains how the metric of a partitioned (global) model is aggregated from its partitions.
         * @param {string} metricId - id of the metric
         * @param {string} displayName - display name of the metric, inserted in the explanation
         * @param {boolean} [isCustom=false] - when true, describe the metric as a custom score whatever metricId is
         * @param {boolean} hideTestWeightMention - for MSE/MAPE/MAE, omit the test-weighted average mention
         * @returns {string|undefined} the explanation; undefined (and an error is logged) for an unknown metric
         */
        getAggregationExplanation: (metricId, displayName, isCustom=false, hideTestWeightMention) => {
            switch (isCustom ? "CUSTOM" : metricId) {
                case "ACCURACY":
                case "PRECISION":
                case "RECALL":
                case "F1":
                case "COST_MATRIX":
                case "MCC":
                case "HAMMINGLOSS":
                    return "{0} of the global model, using optimal threshold for each partition.".format(displayName);
                // same as MSE/MAPE/MAE but here the name is displayed in lower case instead
                case "LOG_LOSS":
                    return "Log loss of the global model (equal to the average log loss per partition, weighted by test weight).";
                // same as CUMULATIVE_LIFT/CALIBRATION_LOSS but here the name is displayed in upper case instead
                case "ROC_AUC":
                    return "Average ROC AUC per partition weighted by test weight as an approximation of the true ROC AUC.";
                case "AVERAGE_PRECISION":
                case "CUMULATIVE_LIFT":
                case "CALIBRATION_LOSS": {
                    const displayNameLowerCase = displayName.toLowerCase();
                    return "Average {0} per partition weighted by test weight as an approximation of the true {1}.".format(displayNameLowerCase, displayNameLowerCase);
                }
                case "CUSTOM":
                    return "Average custom score per partition weighted by test weight.\n"
                            + "It may be an approximation of the true custom score depending on the way it has been defined.";
                case "MSE":
                case "MAPE":
                case "MAE":
                    if (!hideTestWeightMention) {
                        return "{0} of the global model (equal to the average {1} per partition, weighted by test weight).".format(displayName, displayName);
                    }
                    // falls through: without the test weight mention, use the generic explanation below
                case "RMSE":
                case "RMSLE":
                case "R2":
                case "EVS":
                case "PEARSON":
                case "MASE":
                case "SMAPE":
                case "MEAN_ABSOLUTE_QUANTILE_LOSS":
                case "MEAN_WEIGHTED_QUANTILE_LOSS":
                case "MSIS":
                case "ND":
                    return "{0} of the global model.".format(displayName);
                default:
                    Logger.error("Metric name is not valid");
            }
        },

        /**
         * Tells whether the model has partitioning enabled.
         * Returns a truthy/falsy value (not necessarily a boolean).
         * @param {Object} modelData - model details; reads `coreParams.partitionedModel.enabled`
         */
        isPartitionedModel: (modelData) => {
            return modelData
                && modelData.coreParams
                && modelData.coreParams.partitionedModel
                && modelData.coreParams.partitionedModel.enabled;
        }
    };

    return cst;

});

app.service("BinaryClassificationModelsService", function () {
    /**
     * Formats a ratio as a rounded percentage string for display.
     * Ratios below 1% read "< 1 %" and ratios above 1 are capped to "100 %".
     * @param {number} p - ratio, nominally in [0, 1]
     * @returns {string}
     */
    function getPercentString(p) {
        if (p < 0.01) {
            return "< 1 %";
        }
        if (p > 1) {
            return "100 %";
        }
        return Math.round(p * 100) + " %";
    }

    /**
     * Rounds a cut (threshold) value to 3 decimals.
     * Cut values are rounded because of a 0.025 step increase that led to some numerical
     * discrepancies and bad comparisons. See also `decisions_and_cuts.py`.
     * @param {number} v
     * @returns {number}
     */
    function roundCutValue(v) {
        return Math.round(1000 * v) / 1000;
    }

    /**
     * Finds the index of the first cut that is >= `cutToFind` (both rounded with roundCutValue).
     * If there is none, the last index is returned (0 for an empty list).
     * @param {number[]} cuts - cut values, presumably in ascending order
     * @param {number} cutToFind - requested threshold
     * @returns {number}
     */
    function findCut(cuts, cutToFind) {
        const roundedCutToFind = roundCutValue(cutToFind);
        let i;
        for (i = 0; i < cuts.length - 1; i++) {
            if (roundCutValue(cuts[i]) >= roundedCutToFind) {
                break;
            }
        }
        return i;
    }

    /**
     * Reads entry `i` of an optional per-cut standard deviation array.
     * Returns null when the array or the entry is missing or NaN. A std of 0 is a legitimate
     * value and is kept (a plain truthiness check would turn it into null).
     * @param {number[]|undefined} values
     * @param {number} i
     * @returns {number|null}
     */
    function stdAt(values, i) {
        const value = values ? values[i] : undefined;
        return (value === undefined || value === null || Number.isNaN(value)) ? null : value;
    }

    return {
        findCut,

        /**
         * Builds the confusion matrix figures and metric values at the first cut >= `cutToFind`
         * (see findCut), from the per-cut performance data of a binary classification model.
         * @param {Object} modelDataPerf - model performance; reads `perCutData`
         * @param {number} cutToFind - requested threshold
         * @returns {Object|undefined} undefined when there is no per-cut data
         */
        findCutData: (modelDataPerf, cutToFind) => {
            const pcd = modelDataPerf && modelDataPerf.perCutData;
            if (!pcd) {
                return;
            }
            const i = findCut(pcd.cut, cutToFind);
            const tp = pcd.tp[i];
            const tn = pcd.tn[i];
            const fp = pcd.fp[i];
            const fn = pcd.fn[i];
            const actPos = tp + fn;
            const actNeg = tn + fp;
            const predPos = tp + fp;
            const predNeg = tn + fn;
            // Small offset so that an empty row/column of the confusion matrix does not divide by zero
            const eps = 0.01;
            const ret = { // capitalized = will be graphed
                index: i, cut: roundCutValue(pcd.cut[i]),
                tp: {
                    records: tp,
                    actual: getPercentString(tp / (actPos + eps)),
                    predicted: getPercentString(tp / (predPos + eps))
                },
                tn: {
                    records: tn,
                    actual: getPercentString(tn / (actNeg + eps)),
                    predicted: getPercentString(tn / (predNeg + eps))
                },
                fp: {
                    records: fp,
                    actual: getPercentString(fp / (actNeg + eps)),
                    predicted: getPercentString(fp / (predPos + eps))
                },
                fn: {
                    records: fn,
                    actual: getPercentString(fn / (actPos + eps)),
                    predicted: getPercentString(fn / (predNeg + eps))
                },
                actPos: {
                    records: tp + fn,
                    actual: "100 %",
                    ratio: actPos / (actPos + actNeg) // used in subpop
                },
                actNeg: {
                    records: tn + fp,
                    actual: "100 %",
                    ratio: actNeg / (actPos + actNeg) // used in subpop
                },
                predPos: {
                    records: tp + fp,
                    predicted: "100 %",
                    ratio: predPos / (predPos + predNeg) // used in subpop
                },
                predNeg: {
                    records: tn + fn,
                    predicted: "100 %",
                    ratio: predNeg / (predPos + predNeg) // used in subpop
                },
                Accuracy: pcd.accuracy[i],
                mcc: pcd.mcc[i],
                hammingLoss: pcd.hammingLoss[i],
                Precision: pcd.precision[i],
                Recall: pcd.recall[i],
                "F1-Score": pcd.f1[i],
                customMetricsResults: (pcd.customMetricsResults ? pcd.customMetricsResults.map(item => {
                    return {
                        metric: item.metric,
                        didSucceed: item.didSucceed,
                        error: item.error,
                        value: item.didSucceed ? item.values[i] : null,
                        valuestd: stdAt(item.valuesstd, i)
                    };
                }) : null)
            };
            // Optional per-cut series, only copied when present in the data
            const additionalFields = ['cmg', 'accuracystd', 'mccstd', 'hammingLossstd', 'precisionstd', 'recallstd', 'f1std'];
            additionalFields.forEach(key => {
                if (pcd[key]) {
                    ret[key] = pcd[key][i];
                }
            });
            return ret;
        }
    };
});

app.controller("DeepLearningPMLController", function ($scope, $timeout, $interval, $controller, DataikuAPI, PMLSettings, PMLFilteringService, VisualMlCodeEnvCompatibility,
                                                      $state, $stateParams, TopNav, Dialogs, CreateModalFromTemplate, Fn, Logger, $q, CodeBasedEditorUtils) {
    // Shown/hidden toggle state per Keras input name (see showHideInput / isShown)
    const inputsShown = {};
    // Reuse the generic ML task design controller on this same scope
    $controller("_MLTaskDesignController", { $scope });

    // Fetch code env info only once the ML task design is initialized
    $scope.deferredAfterInitMlTaskDesign.then(() => $scope.retrieveCodeEnvsInfo());

    /**
     * Replaces the current selection of the code editor ($scope.cm) with `codeToInsert`
     * followed by a newline, and focuses the editor.
     * @param {string} codeToInsert - code to insert
     */
    function insertCode(codeToInsert) {
        const editor = $scope.cm;
        // Deferred through $timeout to make sure the edit runs in an Angular-safe apply
        $timeout(() => {
            editor.replaceSelection(`${codeToInsert}\n`, "around");
        });
        editor.focus();
    }

    /**
     * Seeds `modeling.keras.fitCode` (used by the advanced fit mode) with Python code that
     * reproduces the current simple-mode settings: batch size, number of epochs, steps per
     * epoch (only when not training on all data) and data shuffling.
     * Does nothing when fitCode is already defined, so user edits are never overwritten.
     */
    function fillFitCodeKeras() {
        if ($scope.mlTaskDesign.modeling.keras.fitCode === undefined) {
            // language=Python
            // steps_per_epoch is only passed to fit when not training on all the data
            const stepsPerEpochCode = $scope.mlTaskDesign.modeling.keras.trainOnAllData ? "" : "                        steps_per_epoch=" + $scope.mlTaskDesign.modeling.keras.stepsPerEpoch + ",\n";
            // The generated fit_model calls model.fit_generator when the model has it, model.fit otherwise
            const fitCode = "# A function that builds train and validation sequences.\n" +
                "# You can define your custom data augmentation based on the original train and validation sequences\n\n" +
                "#   build_train_sequence_with_batch_size        - function that returns train data sequence depending on\n" +
                "#                                                 batch size\n" +
                "#   build_validation_sequence_with_batch_size   - function that returns validation data sequence depending on\n" +
                "#                                                 batch size\n" +
                "def build_sequences(build_train_sequence_with_batch_size, build_validation_sequence_with_batch_size):\n" +
                "    \n" +
                "    batch_size = " + $scope.mlTaskDesign.modeling.keras.batchSize + "\n" +
                "    \n" +
                "    train_sequence = build_train_sequence_with_batch_size(batch_size)\n" +
                "    validation_sequence = build_validation_sequence_with_batch_size(batch_size)\n" +
                "    \n" +
                "    return train_sequence, validation_sequence\n\n\n" +
                "# A function that contains a call to fit a model.\n\n" +
                "#   model                 - compiled model\n" +
                "#   train_sequence        - train data sequence, returned in build_sequence\n" +
                "#   validation_sequence   - validation data sequence, returned in build_sequence\n" +
                "#   base_callbacks        - a list of Dataiku callbacks, that are not to be removed. User callbacks can be added to this list\n" +
                "def fit_model(model, train_sequence, validation_sequence, base_callbacks):\n" +
                "    epochs = " + $scope.mlTaskDesign.modeling.keras.epochs + "\n" +
                "    fitfunc = getattr(model, 'fit_generator', model.fit) \n" +
                "    fitfunc(train_sequence,\n" +
                "            epochs=epochs,\n" +
                stepsPerEpochCode +
                "            callbacks=base_callbacks,\n" +
                "            shuffle=" + ($scope.mlTaskDesign.modeling.keras.shuffleData ? "True" : "False") + ")\n";
            $scope.mlTaskDesign.modeling.keras.fitCode = fitCode;
        }
    }

    // The Inputs area may only animate right after a click: once its CSS transition ends,
    // transitions are disabled again until showHideInputs() re-enables them
    $scope.addEventOnTransition = function() {
        $(".keras-inputs__wrapper").on("transitionend", () => {
            $scope.uiState.canTransition = false;
        });
    };

    // Toggles the Inputs area, allowing its CSS transition to run for this toggle
    $scope.showHideInputs = function() {
        const uiState = $scope.uiState;
        uiState.canTransition = true;
        uiState.displayInput = !uiState.displayInput;
    };

    // Enters rename mode for `input`, pre-filling the new name with the current one
    $scope.startEditInput = function(input) {
        Object.assign($scope.uiState, { currentlyEditing: input, newEditInputName: input });
    };

    // True when `input` is the one currently being renamed
    $scope.isBeingEdited = function(input) {
        return input === $scope.uiState.currentlyEditing;
    };

    /**
     * Applies the pending rename of the input being edited, if the new name is valid:
     * renames it in the list of Keras inputs, re-routes the features that were sent to it,
     * then leaves rename mode.
     */
    $scope.editInputIfValid = function() {
        if ($scope.isValidEditInput()) {
            const oldName = $scope.uiState.currentlyEditing;
            const newName = $scope.uiState.newEditInputName;
            const kerasInputs = $scope.mlTaskDesign.modeling.keras.kerasInputs;

            kerasInputs[kerasInputs.indexOf(oldName)] = newName;

            // Features sent to the old name now go to the new one
            for (const featParams of Object.values($scope.mlTaskDesign.preprocessing.per_feature)) {
                if (featParams["sendToInput"] === oldName) {
                    featParams["sendToInput"] = newName;
                }
            }

            // Leave rename mode
            $scope.uiState.currentlyEditing = null;
            $scope.uiState.newEditInputName = '';
        }
    };

    /**
     * Tells whether the pending new name of the input being edited can be applied.
     * An unchanged name is always valid; otherwise the name must be non-empty and not already
     * used by another Keras input.
     * @returns {boolean}
     */
    $scope.isValidEditInput = function() {
        const newName = $scope.uiState.newEditInputName;
        // Has not changed name so far
        if (newName === $scope.uiState.currentlyEditing) {
            return true;
        }
        // Reject empty names; like isValidNewInput, also reject null/undefined and not only ""
        if (!newName) {
            return false;
        }
        return $scope.mlTaskDesign.modeling.keras.kerasInputs.indexOf(newName) === -1;
    };

    // Leaves rename mode without touching the inputs
    $scope.cancelEditInput = function() {
        Object.assign($scope.uiState, { currentlyEditing: null, newEditInputName: '' });
    };

    /**
     * Asks for confirmation, then removes `input` from the Keras inputs and sends every
     * feature that was routed to it to the 'main' input.
     * @param {string} input - name of the Keras input to delete
     */
    $scope.deleteInput = function(input) {
        Dialogs.confirm($scope, "Delete Deep Learning Input", "Do you want to delete this input ? All its features will be sent to the 'main' input").then(function() {
            const kerasInputs = $scope.mlTaskDesign.modeling.keras.kerasInputs;
            const inputIndex = kerasInputs.indexOf(input);
            // splice(-1, 1) would silently remove the LAST input, so only splice when found
            if (inputIndex !== -1) {
                kerasInputs.splice(inputIndex, 1);
            }

            // Sending all features to 'main' input
            Object.values($scope.mlTaskDesign.preprocessing.per_feature).forEach(function (featParams) {
                if (featParams["sendToInput"] === input) {
                    featParams["sendToInput"] = "main";
                }
            });
        });
    };

    // Adds the typed name as a new Keras input (when valid), then leaves creation mode
    $scope.createInputIfValid = function() {
        if ($scope.isValidNewInput()) {
            $scope.mlTaskDesign.modeling.keras.kerasInputs.push($scope.uiState.newInputName);
            $scope.uiState.creatingNewInput = false;
            $scope.uiState.newInputName = '';
        }
    };

    // A new input name must be non-empty and not already used by another input
    $scope.isValidNewInput = function() {
        const name = $scope.uiState.newInputName;
        return Boolean(name) && !$scope.mlTaskDesign.modeling.keras.kerasInputs.includes(name);
    };

    // True while the "new input" form is open
    $scope.isCreatingInput = function() {
        return $scope.uiState.creatingNewInput;
    };

    // Opens the "new input" form
    $scope.startCreatingInput = function() {
        $scope.uiState.creatingNewInput = true;
    };

    // Closes the "new input" form and clears the typed name
    $scope.cancelCreateInput = function() {
        Object.assign($scope.uiState, { creatingNewInput: false, newInputName: '' });
    };

    /**
     * Inserts the Keras `Input(...)` declaration of `input` into the build-code editor.
     * A regular input is inserted immediately, with its shape read from `input_shapes`.
     * A special input opens a modal where the user completes the shape. The code is inserted
     * once the modal resolves the deferred.
     * @param {string} input - name of the Keras input
     */
    $scope.insertInput = function(input) {
        if (!$scope.isSpecialInput(input)) {
            insertCode("input_" + input + " = Input(shape=input_shapes[\"" + input + "\"], name=\"" + input + "\")");
            return;
        }

        const deferred = $q.defer();
        const modalScope = $scope.$new();
        modalScope.input = input;
        modalScope.uiState = {processorCodeShown: true};
        modalScope.perFeature = $scope.mlTaskDesign.preprocessing.per_feature;

        // Read-only, gutter-less editor settings for the code shown in the modal.
        // NOTE(review): this mutates the object returned by codeMirrorSettingService.get(); fine only
        // if get() returns a fresh object on each call. Confirm.
        const readOnlyOptions = $scope.codeMirrorSettingService.get('text/x-python');
        readOnlyOptions["readOnly"] = "nocursor";
        readOnlyOptions["lineNumbers"] = false;
        readOnlyOptions["foldGutter"] = false;
        modalScope.insertReadOnlyOptions = readOnlyOptions;

        CreateModalFromTemplate("templates/analysis/prediction/insert-special-input-modal.html",
            modalScope,
            null,
            function(scope) {
                scope.acceptDeferred = deferred;

                const ui = scope.uiState;
                ui.insertInput = input;
                // First feature routed to this special input
                ui.insertFeature = Object.keys(scope.perFeature)
                    .find(featName => scope.perFeature[featName]["sendToInput"] === input);
                ui.insertFeatParams = scope.perFeature[ui.insertFeature];
                // The shape typed by the user goes between these two fragments
                ui.insertStartInputCode = "input_" + input + " = Input(shape=";
                ui.insertEndInputCode = ", name=\"" + input + "\")";

                scope.insertSpecialInput = function () {
                    const state = scope.uiState;
                    scope.acceptDeferred.resolve(state.insertStartInputCode + state.insertInputShape + state.insertEndInputCode);
                    scope.dismiss();
                };

                scope.showHideProcessorCode = function() {
                    scope.uiState.processorCodeShown = !scope.uiState.processorCodeShown;
                };
            });

        deferred.promise.then(inputCode => insertCode(inputCode));
    };

    // Flips the shown/hidden state of `input`
    $scope.showHideInput = function(input) {
        inputsShown[input] = !inputsShown[input];
    };

    // Shown/hidden state of `input` (undefined until toggled once)
    $scope.isShown = function(input) {
        return inputsShown[input];
    };

    // Number of INPUT-role features sent to `input`
    $scope.getNumFeatures = function(input) {
        const features = Object.values($scope.mlTaskDesign.preprocessing.per_feature);
        return features.reduce((count, feat) => (feat.sendToInput === input && feat.role === "INPUT" ? count + 1 : count), 0);
    };

    // Builds a predicate keeping the INPUT-role features sent to `input`
    $scope.filterFeatures = function(input) {
        return feat => feat.sendToInput === input && feat.role === "INPUT";
    };

    // Delegates to SettingsService, using the per-feature preprocessing of the current ML task
    $scope.isSpecialInput = function(input) {
        const perFeature = $scope.mlTaskDesign.preprocessing.per_feature;
        return $scope.SettingsService.isSpecialInput(input, perFeature);
    };

    /**
     * Returns the type (e.g. "TEXT", "IMAGE") of the first INPUT-role feature sent to `input`.
     * @param {string} input - name of the Keras input
     * @returns {string|undefined} the feature type, or undefined when no INPUT-role feature is
     *                             sent to `input` (callers compare the result, so no throw)
     */
    function getSpecialInputType(input) {
        const specialFeature = Object.values($scope.mlTaskDesign.preprocessing.per_feature)
            .find(p => p["sendToInput"] === input && p["role"] === "INPUT");
        // An input can be "special" without having an INPUT-role feature (e.g. its feature was
        // rejected); reading `.type` of undefined would throw a TypeError here
        return specialFeature ? specialFeature.type : undefined;
    }

    // Icon class displayed next to a special input, depending on its feature type
    $scope.getSpecialInputIcon = function(input) {
        switch (getSpecialInputType(input)) {
            case "TEXT":
                return "icon-italic";
            case "IMAGE":
                return "icon-picture";
            default:
                return "";
        }
    };

    // The default input, which receives every feature not routed elsewhere
    $scope.isMainInput = function(input) {
        return 'main' === input;
    };

    // Neither special inputs nor 'main' can be renamed
    $scope.isEditable = function(input) {
        return !($scope.isSpecialInput(input) || $scope.isMainInput(input));
    };

    // Tooltip of the edit button, explaining why editing is disabled when it is
    $scope.getEditTitle = function(input) {
        if ($scope.isMainInput(input)) {
            return "Main input cannot be edited";
        }
        if ($scope.isSpecialInput(input)) {
            return "Special input cannot be edited";
        }
        return "Edit input";
    };

    // Tooltip of the insert button; inputs without features cannot be inserted
    $scope.getInsertTitle = function(input) {
        return $scope.getNumFeatures(input) > 0 ? "Insert" : "Empty Input cannot be inserted";
    };

    // An input can be inserted in the code only when at least one feature is sent to it
    $scope.isInsertable = function(input) {
        const featureCount = $scope.getNumFeatures(input);
        return featureCount > 0;
    };

    /**
     * Builds the default Python layers of a special TEXT input: an `Input` of shape (32,),
     * an `Embedding` on it, and a `Reshape` flattening the embedding output.
     * The input variable name is appended to `inputInNetworks` (mutated) so that it can be
     * listed in the model's inputs.
     * @param {string} input - name of the Keras input
     * @param {string[]} inputInNetworks - accumulator of input variable names
     * @returns {string} Python code indented for the body of build_model, ending with a blank line
     */
    function getTextNetworkAndAddInput(input, inputInNetworks) {
        const perFeature = $scope.mlTaskDesign.preprocessing.per_feature;
        const feature = Object.keys(perFeature).find(featName => perFeature[featName]["sendToInput"] === input);
        const inputVarName = `input_${input}`;
        const layerVarName = `x_${input}`;
        inputInNetworks.push(inputVarName);
        return [
            `    # This input will receive preprocessed text from '${feature}' column`,
            `    ${inputVarName} = Input(shape=(32,), name="${input}")`,
            `    ${layerVarName} = Embedding(output_dim=512, input_dim=10000, input_length=32)(${inputVarName})`,
            `    ${layerVarName} = Reshape((32 * 512,))(${layerVarName})`,
        ].join("\n") + "\n\n";
    }

    /**
     * Generates a default Keras architecture (Python `build_model` and `compile_model` functions)
     * for the current prediction type and Keras inputs. The result is stored in
     * `modeling.keras.buildCode` and the settings are saved.
     *
     * Only runs when `keepAndCommentPrevious` is truthy or when no build code exists yet.
     * @param {boolean} keepAndCommentPrevious - when truthy, regenerate the code even if some exists,
     *                                           appending the previous code below it with every line
     *                                           commented out ('# ' prefix)
     */
    $scope.fillBuildCodeKeras = function(keepAndCommentPrevious) {
        if (keepAndCommentPrevious || $scope.mlTaskDesign.modeling.keras.buildCode === undefined) {

            // Output layer, loss and wording of the generated comment depend on the prediction type:
            // a single unit for regression, a softmax over n_classes otherwise
            let predictionLine;
            let lossFunction;
            let problemType;
            if ($scope.mlTaskDesign.predictionType === "REGRESSION") {
                predictionLine = "    predictions = Dense(1)(x)";
                lossFunction = "mse";
                problemType = "regression";
            } else {
                predictionLine = "    predictions = Dense(n_classes, activation='softmax')(x)";
                if ($scope.mlTaskDesign.predictionType === "BINARY_CLASSIFICATION") {
                    lossFunction = "binary_crossentropy";
                    problemType = "binary classification";
                } else {
                    lossFunction = "categorical_crossentropy";
                    problemType = "multiclass classification";
                }
            }

            // Retrieve Text special inputs that may have been guessed
            const specialTextInputNames = $scope.mlTaskDesign.modeling.keras.kerasInputs.filter(x => $scope.isSpecialInput(x) && getSpecialInputType(x) === "TEXT");
            const hasSpecialTextInputs = specialTextInputNames.length >= 1;
            const hasMain = $scope.getNumFeatures("main") >= 1;
            // Declared inputs, not counting 'main' when it receives no feature
            // (presumably assumes 'main' is always part of kerasInputs — confirm).
            // NOTE(review): this also counts non-special custom inputs and special IMAGE inputs, which are
            // not wired into the generated network. concatenate() may then be generated with a single
            // layer, and lastLayerSoFar stays undefined when 'main' is empty and there is no text input.
            // Confirm whether the user is expected to fix the code by hand in those cases.
            const numRealInputs = $scope.mlTaskDesign.modeling.keras.kerasInputs.length - (hasMain ? 0 : 1);
            // Python variable names of the Input layers, listed in Model(inputs=[...])
            let actualInputsInNetwork = [];

            // Python code of the input layers; lastLayerSoFar is the layer the Dense stack is applied to
            let startNetwork = "";
            let lastLayerSoFar;
            if (hasMain || numRealInputs === 0) {
                const mainInputVarName ="input_main";
                startNetwork += '    # This input will receive all the preprocessed features\n' +
                                '    # sent to \'main\'\n' +
                                '    ' + mainInputVarName +' = Input(shape=input_shapes["main"], name="main")\n\n';
                lastLayerSoFar = mainInputVarName;
                actualInputsInNetwork.push(mainInputVarName);
            }
            if (hasSpecialTextInputs) {
                specialTextInputNames.forEach( input => {
                    startNetwork += getTextNetworkAndAddInput(input, actualInputsInNetwork);
                    lastLayerSoFar = "x_" + input;
                });
            }
            // Several inputs: merge 'main' and the text branches into a single tensor
            if (numRealInputs > 1) {
                const concatLayers = [];
                if (hasMain) {
                    concatLayers.push("input_main");
                }
                specialTextInputNames.forEach( input => {
                    concatLayers.push("x_" + input);
                })
                startNetwork += "    x = concatenate([" + concatLayers.join(", ") + "])\n\n";
                lastLayerSoFar = "x";
            }

            // Import from tensorflow.keras when the code env has TensorFlow >= 2.2, standalone keras otherwise
            const importPrefix = VisualMlCodeEnvCompatibility.isEnvAtLeastTensorflow2_2($scope.mlTaskDesign, $scope.codeEnvsCompat) ? "from tensorflow.keras." : "from keras."
            // Only import the layers the generated code actually uses
            let layerImportLine = importPrefix + "layers import Input, Dense";
            if (hasSpecialTextInputs) {
                layerImportLine += ", Embedding, Reshape";
            }
            if (numRealInputs > 1) {
                layerImportLine += ", concatenate"
            }

            // language=Python
            let buildCode = layerImportLine + "\n" +
                              importPrefix + "models import Model\n\n" +
                              "# Define the keras architecture of your model in 'build_model' and return it. Compilation must be done in 'compile_model'.\n" +
                              "#   input_shapes  - dictionary of shapes per input as defined in features handling\n" +
                              "#   n_classes - For classification, number of target classes\n" +
                              "def build_model(input_shapes, n_classes=None):\n\n" +
                              startNetwork +
                              "    x = Dense(64, activation='relu')(" + lastLayerSoFar + ")\n" +
                              "    x = Dense(64, activation='relu')(x)\n" +
                              "\n" +
                              predictionLine + "\n" +
                              "\n" +
                              "    # The 'inputs' parameter of your model must contain the\n" +
                              "    # full list of inputs used in the architecture\n" +
                              "    model = Model(inputs=[" + actualInputsInNetwork.join(", ") + "], outputs=predictions)\n" +
                              "\n" +
                              "    return model\n" +
                              "\n" +
                              "# Compile your model and return it\n" +
                              "#   model   - model defined in 'build_model'\n" +
                              "def compile_model(model):\n" +
                              "    \n" +
                              "    # The loss function depends on the type of problem you solve.\n" +
                              "    # '" + lossFunction + "' is appropriate for a " + problemType + ".\n" +
                              "    model.compile(optimizer='rmsprop',\n" +
                              "                  loss='" + lossFunction + "')\n" +
                              "\n" +
                              "    return model";

            // Keep the user's previous code below the new one, every line prefixed with '# '
            if (keepAndCommentPrevious && $scope.mlTaskDesign.modeling.keras.buildCode) {
                 buildCode = buildCode + "\n\n### PREVIOUS CODE\n" +
                    $scope.mlTaskDesign.modeling.keras.buildCode.replace(/^/gm, '# ');
            }

            $scope.mlTaskDesign.modeling.keras.buildCode = buildCode;
            $scope.saveSettings();
        }
    };

    /**
     * Toggles the Keras "advanced fit" mode.
     * When switching it on, the fit code is generated first through fillFitCodeKeras().
     */
    $scope.switchFitMode = function() {
        const isAdvancedFitMode = () => $scope.mlTaskDesign.modeling.keras.advancedFitMode;
        if (!isAdvancedFitMode()) {
            fillFitCodeKeras();
        }
        // Looked up after the fill, exactly like the flag itself
        const kerasSettings = $scope.mlTaskDesign.modeling.keras;
        kerasSettings.advancedFitMode = !isAdvancedFitMode();
    };

    $scope.editorOptions = CodeBasedEditorUtils.editorOptions('text/x-python', $scope, true);

    /**
     * Sends the current Keras architecture code to the backend for validation.
     *
     * `$scope.runningValidation` is true while the request is in flight. On success the backend answer is stored
     * in `$scope.valCtx.validationResult` and the returned promise resolves with it. On error, the error is shown
     * in scope and the returned promise stays pending.
     *
     * @returns {Promise} resolved with the validation result
     */
    $scope.validateRecipe = function () {
        const deferred = $q.defer();
        // The request is asynchronous: the flag must be cleared from the callbacks. Clearing it in a synchronous
        // `finally` would reset it before the backend has even answered.
        const endValidation = () => {
            $scope.runningValidation = false;
        };
        $scope.runningValidation = true;
        try {
            DataikuAPI.analysis.pml.validateArchitecture(
                $scope.mlTaskDesign.modeling.keras.buildCode,
                $scope.mlTaskDesign.envSelection,
                $stateParams.projectKey
            ).success(data => {
                endValidation();
                $scope.valCtx.showPreRunValidationError = false;
                $scope.valCtx.validationResult = data;
                deferred.resolve(data);
            }).error((...errorArgs) => {
                endValidation();
                setErrorInScope.apply($scope, errorArgs);
            });
        } catch (e) {
            // The request could not even be sent: do not leave the flag stuck on
            endValidation();
            throw e;
        }
        return deferred.promise;
    };

    /**
     * Moves a CodeMirror editor to a 1-based line number.
     * It scrolls the line into view, puts the cursor at its start and focuses the editor.
     * Does nothing when there is no editor or when `line` is not > 0.
     */
    $scope.gotoLine = function(cm, line) {
        if (!cm || !(line > 0)) {
            return;
        }
        const position = {ch: 0, line: line - 1};   // CodeMirror lines are 0-based
        cm.scrollIntoView(position);
        cm.setCursor(position);
        cm.focus();
    };

    // Each time the design is (re)loaded: regenerate the Keras build code. When the architecture declares
    // more than one input, open the Inputs area by default.
    $scope.$watch('mlTaskDesign', (newDesign) => {
        if (!newDesign) {
            return;
        }
        $scope.fillBuildCodeKeras();

        const design = $scope.mlTaskDesign;
        const hasSeveralKerasInputs = design && design.modeling.keras && design.modeling.keras.kerasInputs.length > 1;
        if (hasSeveralKerasInputs) {
            $scope.uiState.displayInput = true;
        }
    });

    // Required for architecture code validation:
    $scope.valCtx = {};
    $scope.recipe = {type: 'python', params: {}};

});

app.controller("PMLTaskBaseController", function($scope,$controller, DataikuAPI, PMLSettings, PMLFilteringService, CustomMetricIDService, $stateParams, Dialogs, Fn, Logger, WT1,
               CodeMirrorSettingService, Assert, FeatureFlagsService, MLModelsUIRouterStates, GPU_SUPPORTING_CAPABILITY, GpuUsageService, MLContainerInfoService,
                                                 MLTaskInformationService, TimeseriesFeatureGenerationService) {
    // Services exposed to the templates and to the inherited _MLTaskBaseController
    $scope.MLAPI = DataikuAPI.analysis.pml;
    $scope.SettingsService = PMLSettings;
    $scope.FilteringService = PMLFilteringService;
    $scope.CustomMetricIDService = CustomMetricIDService;
    $scope.codeMirrorSettingService = CodeMirrorSettingService;

    // UI-router state prefix of the prediction ML task pages
    $scope.sRefPrefix = 'projects.project.analyses.analysis.ml.predmltask';
    $scope.getPredictionDesignTabPrefix = function() {
        return MLModelsUIRouterStates.getPredictionDesignTabPrefix($scope);
    };

    // UI-side train/test policy, kept in sync with mlTaskDesign.splitParams by fillUISplitParams / dumpUISplitParams
    $scope.uiSplitParams = {};

    /**
     * Loads the ML task settings from the backend. To be run when not guessing.
     * It stores them as the current design, derives the UI split policy, snapshots them as the saved
     * state, then refreshes the task status. Errors are routed to setErrorInScope.
     * @returns {Promise}
     */
    $scope.initMlTaskDesign = function() {
        const { projectKey, analysisId, mlTaskId } = $stateParams;
        const applyLoadedSettings = function(response) {
            $scope.setMlTaskDesign(response.data);
            $scope.fillUISplitParams($scope.mlTaskDesign.splitParams);
            $scope.savedSettings = dkuDeepCopy($scope.mlTaskDesign, PMLSettings.noDollarKey);
        };
        return DataikuAPI.analysis.pml.getUpdatedSettings(projectKey, analysisId, mlTaskId)
            .then(applyLoadedSettings)
            .then($scope.refreshStatus)
            .catch(setErrorInScope.bind($scope));
    };

    /**
     * Translates backend split params into the UI train/test policy (`$scope.uiSplitParams.policy`).
     * The two single-dataset backend policies each map to a MAIN or OTHER UI policy, depending on whether
     * an explicit dataset name is set. Any other policy is used as-is.
     * @param {Object} splitParams - backend split params of the design
     * @throws {Error} when no split params are given
     */
    $scope.fillUISplitParams = function(splitParams) {
        if (!splitParams) {
            throw new Error('No split params');
        }
        let uiPolicy;
        switch (splitParams.ttPolicy) {
            case 'SPLIT_SINGLE_DATASET':
                uiPolicy = splitParams.ssdDatasetSmartName == null
                    ? "SPLIT_MAIN_DATASET"
                    : "SPLIT_OTHER_DATASET";
                break;
            case "EXPLICIT_FILTERING_SINGLE_DATASET":
                uiPolicy = splitParams.efsdDatasetSmartName == null
                    ? "EXPLICIT_FILTERING_SINGLE_DATASET_MAIN"
                    : "EXPLICIT_FILTERING_SINGLE_DATASET_OTHER";
                break;
            default:
                uiPolicy = splitParams.ttPolicy;
        }
        $scope.uiSplitParams.policy = uiPolicy;
    };

    // Prediction-type predicates: each of them returns false while mlTaskDesign is not loaded.

    $scope.isClassification = function() {
        if (!$scope.mlTaskDesign) return false;
        return PMLSettings.task.isClassification($scope.mlTaskDesign.predictionType);
    };

    // Whether H2O is enabled in the app config (false while the app config is not loaded).
    $scope.h2oEnabled = function() {
        if (!$scope.appConfig) return false;
        return $scope.appConfig.h2oEnabled;
    };

    // Regression, causal regression included.
    $scope.isRegression = function() {
        if (!$scope.mlTaskDesign) return false;
        return ["REGRESSION", "CAUSAL_REGRESSION"].includes($scope.mlTaskDesign.predictionType);
    };

    // Binary classification, causal binary classification included.
    $scope.isBinaryClassification = function() {
        if (!$scope.mlTaskDesign) return false;
        return ["BINARY_CLASSIFICATION", "CAUSAL_BINARY_CLASSIFICATION"].includes($scope.mlTaskDesign.predictionType);
    };

    $scope.isMulticlass = function() {
        if (!$scope.mlTaskDesign) return false;
        return $scope.mlTaskDesign.predictionType === "MULTICLASS";
    };

    // Non-causal regression or classification. Always returns a boolean, like its sibling predicates.
    $scope.isClassicalPrediction = function() {
        if (!$scope.mlTaskDesign) return false;
        return ["REGRESSION", "BINARY_CLASSIFICATION", "MULTICLASS"].includes($scope.mlTaskDesign.predictionType);
    };

    $scope.isTimeseriesPrediction = function() {
        if (!$scope.mlTaskDesign) return false;
        return $scope.mlTaskDesign.predictionType === "TIMESERIES_FORECAST";
    };

    // Settings keys (under mlTaskDesign.modeling) of the GluonTS timeseries algorithms that can use a GPU.
    const GPU_COMPATIBLE_TIMESERIES_ALG_KEYS = [
        'gluonts_deepar_timeseries',
        'gluonts_mqcnn_timeseries',
        'gluonts_simple_feed_forward_timeseries',
        'gluonts_transformer_timeseries',
        'gluonts_torch_simple_feed_forward_timeseries',
        'gluonts_torch_deepar_timeseries',
    ];

    /**
     * True when the task is a timeseries forecast with at least one GPU-compatible algorithm enabled.
     * Any of those algorithm entries may be missing from the settings; a missing entry counts as disabled.
     * @returns {boolean}
     */
    $scope.isGpuCompatibleTimeseriesPrediction = function() {
        if (!$scope.mlTaskDesign || !$scope.mlTaskDesign.modeling) return false;
        if (!$scope.isTimeseriesPrediction()) return false;

        const modeling = $scope.mlTaskDesign.modeling;
        return GPU_COMPATIBLE_TIMESERIES_ALG_KEYS.some(algKey => Boolean(modeling[algKey] && modeling[algKey].enabled));
    };

    // Custom metrics are only available with Python-based backends.
    $scope.supportsCustomMetrics = function() {
        return $scope.backendIsPythonBased();
    };

    // True when partitioned models are enabled in the design. Returns false (no error) while the design is not loaded.
    $scope.isPartitioned = function() {
        const design = $scope.mlTaskDesign;
        return Boolean(design && design.partitionedModel && design.partitionedModel.enabled);
    };

    // Uncertainty is only offered for non-partitioned regressions on the PY_MEMORY backend.
    $scope.supportsUncertainty = function() {
        return $scope.isMLBackendType('PY_MEMORY')
            && $scope.isRegression()
            && !$scope.isPartitioned();
    };

    // Causal prediction: causal regression or causal binary classification.
    $scope.isCausalPrediction = function() {
        const design = $scope.mlTaskDesign;
        if (!design) {
            return false;
        }
        const { predictionType } = design;
        return predictionType === "CAUSAL_BINARY_CLASSIFICATION" || predictionType === "CAUSAL_REGRESSION";
    };

    // True when the design is loaded and uses the DEEP_HUB backend.
    $scope.isDeepHubPrediction = function() {
        return Boolean($scope.mlTaskDesign) && $scope.isMLBackendType("DEEP_HUB");
    };

    /**
     * Tells whether at least one INPUT feature of the design satisfies `checkHandlingFct`.
     * @param {function(Object): *} checkHandlingFct - predicate on one feature's preprocessing params
     * @returns {boolean} false when the design or its per-feature preprocessing is not available
     */
    $scope.isPreprocessingUsed = function(checkHandlingFct) {
        const perFeature = $scope.mlTaskDesign
            && $scope.mlTaskDesign.preprocessing
            && $scope.mlTaskDesign.preprocessing.per_feature;
        if (!perFeature) {
            return false;
        }
        for (const featPreproc of Object.values(perFeature)) {
            if (featPreproc.role === "INPUT" && checkHandlingFct(featPreproc)) {
                return true;
            }
        }
        return false;
    };

    // Text features using a sentence embedding model that is not a structured (LLM mesh) reference.
    $scope.usesCodeEnvSentenceEmbedding = function() {
        return $scope.isPreprocessingUsed(function(featPreproc) {
            const isSentenceEmbedding = featPreproc.type === "TEXT" && featPreproc.text_handling === "SENTENCE_EMBEDDING";
            return isSentenceEmbedding && featPreproc.sentenceEmbeddingModel != null && !featPreproc.isStructuredRef;
        });
    };

    // LLM-mesh based features: structured-reference text sentence embeddings, and image embedding extraction.
    $scope.usesLLMmeshPreprocessing = function() {
        const isStructuredTextEmbedding = featPreproc => featPreproc.type === "TEXT"
            && featPreproc.text_handling === "SENTENCE_EMBEDDING"
            && featPreproc.isStructuredRef;
        const isImageEmbedding = featPreproc => featPreproc.type === "IMAGE"
            && featPreproc.image_handling === "EMBEDDING_EXTRACTION";
        return $scope.isPreprocessingUsed(featPreproc => isStructuredTextEmbedding(featPreproc) || isImageEmbedding(featPreproc));
    };

    // True when a deep neural network algorithm is enabled (regression or classification flavor).
    $scope.isDeepNeuralNetworkEnabled = function() {
        const modeling = $scope.mlTaskDesign && $scope.mlTaskDesign.modeling;
        if (!modeling) {
            return false;
        }
        // Regression and classification have distinct settings entries, because classification was added
        // after regression. Only one of the two can be enabled.
        const isOn = algKey => modeling[algKey] && modeling[algKey].enabled;
        return isOn('deep_neural_network_regression') || isOn('deep_neural_network_classification');
    };

    // True when the XGBoost algorithm is enabled in the design.
    $scope.isXGBoostEnabled = function() {
        const modeling = $scope.mlTaskDesign && $scope.mlTaskDesign.modeling;
        if (!modeling) {
            return false;
        }
        return modeling.xgboost && modeling.xgboost.enabled;
    };

    // Falsy when XGBoost is disabled. Otherwise true, unless early stopping is on and the evaluation metric
    // is neither RMSE nor MAE.
    $scope.isMetricSupportedByXGBoostForEarlyStopping = function() {
        const xgboostEnabled = $scope.isXGBoostEnabled();
        if (!xgboostEnabled) {
            return xgboostEnabled;
        }
        const { modeling } = $scope.mlTaskDesign;
        return !modeling.xgboost.enable_early_stopping
            || ['RMSE', 'MAE'].includes(modeling.metrics.evaluationMetric);
    };

    // Backends whose models can use a GPU.
    $scope.canUseGpu = function() {
        const gpuBackendTypes = ['KERAS', 'DEEP_HUB', 'PY_MEMORY'];
        return gpuBackendTypes.reduce((found, backendType) => found || $scope.isMLBackendType(backendType), false);
    };

    // GPU parameters of the current session task.
    $scope.getTrainingGPUParams = function() {
        const { gpuConfig } = $scope.sessionTask;
        return gpuConfig.params;
    };

    // Predicate on a container selection, provided by MLContainerInfoService for this project.
    $scope.inContainer = MLContainerInfoService.inContainer($scope, $stateParams.projectKey);

    // Whether the session of the first selected session model is still running.
    $scope.isCurrentSessionRunning = function() {
        const selection = $scope.selection;
        const firstSessionModel = selection && selection.sessionModels && selection.sessionModels[0];
        const currentSessionId = firstSessionModel ? firstSessionModel.sessionId : undefined;
        if (!currentSessionId) {
            return false;
        }
        return MLTaskInformationService.isSessionRunning($scope, currentSessionId);
    };

    // GPU usage is only shown while the current session trains with GPU enabled and outside a container.
    $scope.shouldShowGpuUsage = function() {
        const isTrainingCurrentSession = $scope.mlTaskStatus.training && $scope.isCurrentSessionRunning();
        if (!isTrainingCurrentSession) {
            return false;
        }
        const gpuParams = $scope.getTrainingGPUParams();
        return gpuParams && gpuParams.useGpu && !$scope.inContainer($scope.sessionTask.containerSelection);
    };

    // True when hyperparameter search uses the GRID strategy. Also true while the design is not loaded.
    $scope.isGridSearchStrategy = function() {
        return !$scope.mlTaskDesign || $scope.mlTaskDesign.modeling.gridSearchParams.strategy === 'GRID';
    };


    // Catalogue of the algorithms offered by this controller, grouped by backend type. PY_MEMORY, KERAS and
    // DEEP_HUB are the values passed to isMLBackendType in this file; MLLIB and H2O presumably follow the same
    // convention (TODO confirm).
    // Entry fields as they appear below:
    //   name         - display name
    //   algKey       - algorithm identifier. For some entries (gluonts_*, deep_neural_network_*) it is also the key
    //                  read under mlTaskDesign.modeling in this file. This is NOT the case for XGBoost, whose
    //                  settings are read from mlTaskDesign.modeling.xgboost (its templateName).
    //   algEnumName  - NOTE(review): alternate identifier; its consumer is not visible here
    //   condition    - optional predicate; presumably the entry is only offered when it returns true, and always
    //                  when absent (TODO confirm in _MLTaskBaseController)
    //   supportedCausalMethod, templateName, hpSpaceName, type, displayGroups - consumed outside this chunk
    // Several conditions are the $scope predicate functions themselves (e.g. condition:$scope.isClassification),
    // captured when this object is built: those predicates must therefore be defined above this statement.
    $scope.base_algorithms = {
        PY_MEMORY: [
            {name:'Random Forest', algKey:'random_forest_classification', condition:$scope.isClassification, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},
            {name:'Random Forest', algKey:'random_forest_regression', condition: function() {return $scope.isRegression() || $scope.isTimeseriesPrediction()}, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},

            {name:'Gradient tree boosting', algKey:'gbt_classification', condition:$scope.isClassification, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},
            {name:'Gradient tree boosting', algKey:'gbt_regression', condition:$scope.isRegression, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},

            {name:'Logistic Regression', algKey:'logistic_regression',condition:$scope.isClassification, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},

            {name:'Ordinary Least Squares', algKey:'leastsquare_regression',condition:$scope.isRegression, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},
            {name:'Ridge Regression', algKey:'ridge_regression',condition: function() {return $scope.isRegression() || $scope.isTimeseriesPrediction()}, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},
            {name:'Lasso Regression', algKey:'lasso_regression',condition:$scope.isRegression, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},

            {name: 'LightGBM', algKey: 'lightgbm_classification', condition: $scope.isClassification, supportedCausalMethod: "META_LEARNER", templateName: 'lightgbm', type: 'classicalML', displayGroups: ['classicalML']},
            {name: 'LightGBM', algKey: 'lightgbm_regression', condition: function() {return $scope.isRegression() || ($scope.isTimeseriesPrediction())}, supportedCausalMethod: "META_LEARNER", templateName: 'lightgbm', type: 'classicalML', displayGroups: ['classicalML']},

            {name:'XGBoost', algKey:'xgboost_classification', condition:$scope.isClassification, supportedCausalMethod: "META_LEARNER", templateName: 'xgboost', hpSpaceName: 'xgboost', type: 'classicalML', displayGroups: ['classicalML']},
            {name:'XGBoost', algKey:'xgboost_regression', condition: function() {return $scope.isRegression() || ($scope.isTimeseriesPrediction())}, supportedCausalMethod: "META_LEARNER", templateName: 'xgboost', hpSpaceName: 'xgboost', type: 'classicalML', displayGroups: ['classicalML']},

            {name:'Decision Tree', algKey:'decision_tree_classification',condition:$scope.isClassification, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},
            {name:'Decision Tree', algKey:'decision_tree_regression',condition:$scope.isRegression, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},

            {name:'Support Vector Machine', algKey:'svc_classifier',condition:$scope.isClassification, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},
            {name:'Support Vector Machine', algKey:'svm_regression',condition:$scope.isRegression, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},

            {name:'Stochastic Gradient Descent', algKey:'sgd_classifier',condition:$scope.isClassification, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},
            {name:'Stochastic Gradient Descent', algKey:'sgd_regression',condition:$scope.isRegression, supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},

            {name:'KNN', algKey:'knn', condition:Fn.not($scope.isTimeseriesPrediction), supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},
            {name:'Extra Random Trees', algKey:'extra_trees', condition:Fn.not($scope.isTimeseriesPrediction), supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},
            {name:'Single Layer Perceptron', algKey:'neural_network', condition:Fn.not($scope.isTimeseriesPrediction), supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},
            {name:'Lasso Path', algKey:'lars_params', condition:Fn.not($scope.isTimeseriesPrediction), supportedCausalMethod: "META_LEARNER", type: 'classicalML', displayGroups: ['classicalML']},

            {name: 'Deep Neural Network', algKey: 'deep_neural_network_classification', condition: function() {return $scope.isClassification()}, supportedCausalMethod: "META_LEARNER", templateName: 'deep_neural_network', type: 'classicalML', displayGroups: ['classicalML']},
            {name: 'Deep Neural Network', algKey: 'deep_neural_network_regression', condition: function() {return $scope.isRegression()}, supportedCausalMethod: "META_LEARNER", templateName: 'deep_neural_network', type: 'classicalML', displayGroups: ['classicalML']},

            {name:'Deep Learning (H2O)', algKey:'deep_learning_h2o',condition:$scope.h2oEnabled, type: 'classicalML', displayGroups: ['classicalML']},
            {name:'GLM (H2O)',  algKey:'glm_h2o',condition:function(){return $scope.h2oEnabled()&&$scope.isRegression()}, type: 'classicalML', displayGroups: ['classicalML']},
            {name:'Gradient Boosting (H2O)',  algKey:'gbm_h2o',condition:$scope.h2oEnabled, type: 'classicalML', displayGroups: ['classicalML']},
            {name:'Random Forest (H2O)',  algKey:'distributed_rf_h2o',condition:$scope.h2oEnabled, type: 'classicalML', displayGroups: ['classicalML']},

            // Timeseries - Statistical
            {name:'AutoARIMA', algKey:'autoarima_timeseries', algEnumName:'auto_arima', condition:$scope.isTimeseriesPrediction, type:"statistical", displayGroups: ['statistical']},
            {name:'ARIMA', algKey:'arima_timeseries', algEnumName:'arima', condition:$scope.isTimeseriesPrediction, type:"statistical", displayGroups: ['statistical']},
            {name:'Croston', algKey:'croston_timeseries', algEnumName:'croston', condition:$scope.isTimeseriesPrediction, type:"statistical", displayGroups: ['statistical']},
            {name:'Seasonal trend', algKey:'seasonal_loess_timeseries', algEnumName:'seasonal_loess', condition:$scope.isTimeseriesPrediction, type:"statistical", displayGroups: ['statistical']},
            {name:'ETS', algKey:'ets_timeseries', algEnumName: 'ets', condition:$scope.isTimeseriesPrediction, type:"statistical", displayGroups: ['statistical']},
            {name:'Prophet', algKey: 'prophet_timeseries', algEnumName:'prophet', condition: $scope.isTimeseriesPrediction, type:"statistical", displayGroups: ['statistical']},
            {name:'NPTS', algKey:'gluonts_npts_timeseries', algEnumName:'gluonts_npts_forecaster', condition:$scope.isTimeseriesPrediction, type:"statistical", displayGroups: ['statistical']},

            // Timeseries - DL
            {name:'Simple Feed Forward - Torch', algKey:'gluonts_torch_simple_feed_forward_timeseries', algEnumName:'gluonts_torch_simple_feedforward', condition:$scope.isTimeseriesPrediction, type:"deep_learning", displayGroups: ['deepLearning']}, // with torch
            {name:'DeepAR - Torch', algKey:'gluonts_torch_deepar_timeseries', algEnumName:'gluonts_torch_deepar', condition:$scope.isTimeseriesPrediction, type:"deep_learning", displayGroups: ['deepLearning']}, // with torch

            // Timeseries - Baseline
            {name:'Trivial identity', algKey:'trivial_identity_timeseries', condition:$scope.isTimeseriesPrediction, type:"baseline", displayGroups: ['baseline']},
            {name:'Seasonal naive', algKey:'seasonal_naive_timeseries', algEnumName:'seasonal_naive', condition:$scope.isTimeseriesPrediction, type:"baseline", displayGroups: ['baseline']},

            // Timeseries - Legacy
            {name:'Simple Feed Forward - MXNet', algKey:'gluonts_simple_feed_forward_timeseries', algEnumName:'gluonts_simple_feedforward', condition:$scope.isTimeseriesPrediction, type:"deep_learning", displayGroups: ['legacy']},
            {name:'DeepAR - MXNet', algKey:'gluonts_deepar_timeseries', algEnumName:'gluonts_deepar', condition:$scope.isTimeseriesPrediction, type:"deep_learning", displayGroups: ['legacy']},
            {name:'Transformer', algKey:'gluonts_transformer_timeseries', algEnumName:'gluonts_transformer', condition:$scope.isTimeseriesPrediction, type:"deep_learning", displayGroups: ['legacy']},
            {name:'MQ-CNN', algKey:'gluonts_mqcnn_timeseries', algEnumName:'gluonts_mqcnn', condition:$scope.isTimeseriesPrediction, type:"deep_learning", displayGroups: ['legacy']},

            // Causal Predictions
            {name:'Causal Forest', algKey:'causal_forest', condition: $scope.isCausalPrediction, supportedCausalMethod: "CAUSAL_FOREST", type: 'classicalML', displayGroups: ['classicalML']},

        ],
        MLLIB: [
            {name:'Linear Regression', algKey:'mllib_linreg', algEnumName:'mllib_linear_regression',condition:$scope.isRegression},
            {name:'Logistic Regression', algKey:'mllib_logit', algEnumName:'mllib_logistic_regression',condition:$scope.isClassification},
            {name:'Decision Tree', algKey:'mllib_dt', algEnumName:'mllib_decision_tree'},
            {name:'Random Forest', algKey:'mllib_rf', algEnumName:'mllib_random_forest'},
            {name:'Gradient tree boosting', algKey:'mllib_gbt', algEnumName:'mllib_gbt',condition:Fn.not($scope.isMulticlass)},
            {name:'Naive Bayes', algKey:'mllib_naive_bayes', algEnumName:'mllib_naive_bayes',condition:$scope.isMulticlass},
        ],
        H2O :[
            {name:'Deep Learning', algKey:'deep_learning_sparkling', algEnumName:'sparkling_deep_learning'},
            {name:'Generalized Linear Model', algKey:'glm_sparkling', algEnumName:'sparkling_glm'},
            {name:'Gradient Boosting', algKey:'gbm_sparkling', algEnumName:'sparkling_gbm'},
            {name:'Random Forest', algKey:'rf_sparkling', algEnumName:'sparkling_rf'},
            {name:'Naive Bayes', algKey:'nb_sparkling', algEnumName:'sparkling_nb',condition:$scope.isClassification},
        ],
        KERAS: [
            {name: "Deep Learning with Keras", algKey: "keras"}
        ],
        DEEP_HUB: [
            {name: "Computer vision", algKey: "deephub-computer-vision"}
        ]
    };


    /**
     * For timeseries forecasts, stores in `$scope.sessionTask.$forecastRange` the {min, max} value range covered
     * by the snippets of the current session's DONE models. The range includes:
     *   - each snippet's forecast;
     *   - the displayed ground truth;
     *   - the lowest / highest quantile forecasts, when those quantiles differ.
     * A `final` flag is added, set once the ML task is no longer training.
     * NOTE(review): presumably used to give all snippet charts a common scale.
     *
     * Does nothing when:
     *   - snippets or the session are not loaded;
     *   - the task is not a timeseries forecast, or is partitioned;
     *   - the range is already final;
     *   - fewer than 2 models of the session are DONE.
     * Snippets without forecast data are skipped instead of raising.
     */
    $scope.setAdditionalSnippetParams = function() {
        if (!$scope.modelSnippets || !$scope.sessionTask) return;
        if (!$scope.isTimeseriesPrediction()) return;
        if ($scope.sessionTask.partitionedModel && $scope.sessionTask.partitionedModel.enabled) return;
        if ($scope.sessionTask.$forecastRange && $scope.sessionTask.$forecastRange.final) return;

        const doneSessionSnippets = Object.values($scope.modelSnippets)
            .filter(snippet => snippet.trainInfo.state === 'DONE' && snippet.sessionId === $scope.sessionTask.sessionId);

        if (doneSessionSnippets.length < 2) return;

        $scope.sessionTask.$forecastRange = {
            final: !$scope.mlTaskStatus.training,
            ...doneSessionSnippets.reduce(function(currentMinMax, snippet) {
                const perTimeseries = snippet.forecasts && snippet.forecasts.perTimeseries;
                const firstTimeseries = perTimeseries && Object.values(perTimeseries)[0];
                if (!firstTimeseries) {
                    // No forecast data for this snippet: it does not widen the range
                    return currentMinMax;
                }
                const forecast = firstTimeseries.forecast || [];
                const groundTruth = firstTimeseries.groundTruth || [];

                let { min, max } = currentMinMax;

                // Sort a copy: the quantiles belong to $scope.modelSnippets and must not be reordered as a side
                // effect of computing a range.
                const sortedQuantiles = [...(firstTimeseries.quantiles || [])].sort((a, b) => a.quantile - b.quantile);
                if (sortedQuantiles.length > 0) {
                    const lowerQuantile = sortedQuantiles[0];
                    const upperQuantile = sortedQuantiles[sortedQuantiles.length - 1];
                    if (lowerQuantile.quantile !== upperQuantile.quantile) {
                        min = Math.min(min, ...lowerQuantile.forecast);
                        max = Math.max(max, ...upperQuantile.forecast);
                    }
                }

                // Snippets only display ground truth for the timestamps that also have backtest data, plus two
                // extra timestamps in case there are too few data points. See how actualData is defined in the
                // directive timeseriesForecastingGraphs when largeContainer = false.
                const displayedGroundTruth = groundTruth.slice(-forecast.length - 2);

                return {
                    min: Math.min(...displayedGroundTruth, ...forecast, min),
                    max: Math.max(...displayedGroundTruth, ...forecast, max)
                };
            }, { min: Number.MAX_VALUE, max: - Number.MAX_VALUE })
        };
    };

    // Instantiate _MLTaskBaseController on this same $scope. It expects initMlTaskDesign(), base_algorithms
    // and the other members defined above to already be in $scope, and raises errors if it runs earlier:
    // keep this call after them.
    $controller("_MLTaskBaseController", { $scope });

    // Train button

    /**
     * Writes the UI train/test policy back into `mlTaskDesign.splitParams` (inverse of fillUISplitParams).
     * The "main dataset" UI policies clear the explicit dataset name; the "other dataset" ones keep it.
     * @throws {Error} when the design has no split params
     */
    $scope.dumpUISplitParams = function(){
        const splitParams = $scope.mlTaskDesign.splitParams;
        if (!splitParams) {
            throw new Error('No split params');
        }
        switch ($scope.uiSplitParams.policy) {
            case "SPLIT_MAIN_DATASET":
                splitParams.ttPolicy = "SPLIT_SINGLE_DATASET";
                splitParams.ssdDatasetSmartName = null;
                break;
            case "SPLIT_OTHER_DATASET":
                splitParams.ttPolicy = "SPLIT_SINGLE_DATASET";
                break;
            case "EXPLICIT_FILTERING_SINGLE_DATASET_MAIN":
                splitParams.ttPolicy = "EXPLICIT_FILTERING_SINGLE_DATASET";
                splitParams.efsdDatasetSmartName = null;
                break;
            case "EXPLICIT_FILTERING_SINGLE_DATASET_OTHER":
                splitParams.ttPolicy = "EXPLICIT_FILTERING_SINGLE_DATASET";
                break;
            default:
                splitParams.ttPolicy = $scope.uiSplitParams.policy;
        }
        Logger.info("DUMP UI SPLIT", splitParams, $scope.uiSplitParams);
    };

    /**
     * Persists the current design. It first syncs the UI split policy into it and cleans up the timeseries
     * feature generation settings. Once saved, it snapshots the design as the saved state, then refreshes
     * the status and the ML task list. Errors are routed to setErrorInScope.
     * @returns {Promise}
     */
    $scope.saveSettings = function() {
        Assert.inScope($scope, "mlTaskDesign");
        $scope.dumpUISplitParams();
        TimeseriesFeatureGenerationService.cleanUpSettings($scope.mlTaskDesign);

        const { projectKey, analysisId } = $stateParams;
        const onSaved = function() {
            resetErrorInScope($scope);
            $scope.savedSettings = dkuDeepCopy($scope.mlTaskDesign, PMLSettings.noDollarKey);
        };
        return DataikuAPI.analysis.pml.saveSettings(projectKey, analysisId, $scope.mlTaskDesign)
            .then(onSaved)
            .then($scope.refreshStatus)
            .then($scope.listMLTasks)
            .catch(setErrorInScope.bind($scope));
    };

    /**
     * Checks that the split params reference the dataset(s) they need. If not, it shows an
     * "Incorrect Train/Test settings" dialog and returns false.
     * @param {Object} splitParams - backend split params
     * @param {boolean} checkSingle - true when the params do not come from the settings screen. The
     *        EFSD_MAIN / SPLIT_MAIN policies are then expected to have filled the dataset name.
     * @returns {boolean} true when the params are complete
     * @throws {Error} when no split params are given
     */
    $scope.checkSplitParams = function(splitParams, checkSingle) {
        if (!splitParams) {
            throw new Error('No split params');
        }
        const hasDataset = extract => Boolean(extract && extract.datasetSmartName);
        let error = null;

        if (splitParams.ttPolicy === 'EXPLICIT_FILTERING_TWO_DATASETS') {
            const hasTrain = hasDataset(splitParams.eftdTrain);
            const hasTest = hasDataset(splitParams.eftdTest);
            if (!hasTrain && !hasTest) {
                error = 'No train nor test dataset specified.';
            } else if (!hasTrain) {
                error = 'No train dataset specified.';
            } else if (!hasTest) {
                error = 'No test dataset specified.';
            }
        } else if (checkSingle) {
            const missingSsd = 'ssdDatasetSmartName' in splitParams && !splitParams.ssdDatasetSmartName;
            const missingEfsd = 'efsdDatasetSmartName' in splitParams && !splitParams.efsdDatasetSmartName;
            if (missingSsd || missingEfsd) {
                error = 'No dataset specified.';
            }
        } else {
            // From the settings screen: only the "other dataset" policies need an explicit dataset
            const policy = $scope.uiSplitParams.policy;
            const missingOtherSsd = policy === "SPLIT_OTHER_DATASET" && !splitParams.ssdDatasetSmartName;
            const missingOtherEfsd = policy === "EXPLICIT_FILTERING_SINGLE_DATASET_OTHER" && !splitParams.efsdDatasetSmartName;
            if (missingOtherSsd || missingOtherEfsd) {
                error = 'No dataset specified.';
            }
        }

        if (!error) {
            return true;
        }
        Dialogs.ack($scope, 'Incorrect Train/Test settings', error);
        return false;
    };

    // True when the weighting method uses a sample weight (alone or combined with class weights).
    $scope.isSampleWeightEnabled = function() {
        const weighting = $scope.mlTaskDesign.weight;
        const weightMethod = weighting && weighting.weightMethod;
        return ['SAMPLE_WEIGHT', 'CLASS_AND_SAMPLE_WEIGHT'].includes(weightMethod);
    };

    /**
     * Flags a class-averaging method that does not fit the weighting method:
     *   - WEIGHTED averaging combined with class weights;
     *   - MACRO averaging without class weights.
     * Returns false for any other weighting method.
     */
    $scope.multiclassAveragingAndWeightingInconsistent = function() {
        const weighting = $scope.mlTaskDesign.weight;
        const weightMethod = weighting && weighting.weightMethod;
        const { classAveragingMethod } = $scope.mlTaskDesign.modeling.metrics;

        if (['CLASS_WEIGHT', 'CLASS_AND_SAMPLE_WEIGHT'].includes(weightMethod)) {
            return classAveragingMethod === "WEIGHTED";
        }
        if (['NO_WEIGHTING', 'SAMPLE_WEIGHT'].includes(weightMethod)) {
            return classAveragingMethod === "MACRO";
        }
        return false;
    };

    // GPU capabilities available for the design's backend and prediction type.
    $scope.getAvailableGpuCapabilities = function() {
        return GpuUsageService.getAvailableGpuCapabilities($scope.mlTaskDesign.backendType, $scope.mlTaskDesign.predictionType);
    };


    /**
     * Lists the GPU capabilities the current design would use: training (Keras, GPU-compatible timeseries
     * algorithms or DeepHub, then DNN and XGBoost) and code-env sentence embedding preprocessing.
     * @returns {Array} GPU_SUPPORTING_CAPABILITY values
     */
    $scope.getUsedGpuCapabilities = function() {
        const gpuCapabilities = [];

        // GPU for training
        if ($scope.isMLBackendType('KERAS')) {
            gpuCapabilities.push(GPU_SUPPORTING_CAPABILITY.KERAS);
        } else if ($scope.isGpuCompatibleTimeseriesPrediction()) {
            gpuCapabilities.push(GPU_SUPPORTING_CAPABILITY.GLUONTS);
        } else if ($scope.isMLBackendType('DEEP_HUB')) {
            gpuCapabilities.push(GPU_SUPPORTING_CAPABILITY.DEEP_HUB);
        }

        if ($scope.isDeepNeuralNetworkEnabled()) {
            gpuCapabilities.push(GPU_SUPPORTING_CAPABILITY.DEEP_NN);
        }

        if ($scope.isXGBoostEnabled()) {
            gpuCapabilities.push(GPU_SUPPORTING_CAPABILITY.XGBOOST);
        }

        // GPU for non-training-specific features
        if ($scope.usesCodeEnvSentenceEmbedding()) {
            gpuCapabilities.push(GPU_SUPPORTING_CAPABILITY.SENTENCE_EMBEDDING);
        }

        return gpuCapabilities;
    };

    /**
     * Used GPU capabilities minus those listed in `mlTaskDesign.gpuConfig.disabledCapabilities`.
     * A missing gpuConfig or disabledCapabilities list is treated as "nothing disabled" instead of raising.
     * @returns {Array} empty while the design is not loaded
     */
    $scope.getUsedGpuCapabilitiesWithGpuOn = function() {
        if (!$scope.mlTaskDesign) return [];

        const gpuConfig = $scope.mlTaskDesign.gpuConfig;
        const disabledCapabilities = (gpuConfig && gpuConfig.disabledCapabilities) || [];
        return $scope.getUsedGpuCapabilities().filter(capability => !disabledCapabilities.includes(capability));
    };
});

app.controller("_PMLTrainSessionController", function($scope, DataikuAPI, PMLSettings, CodeMirrorSettingService, $state, $stateParams,
                                                      CreateModalFromTemplate, MLContainerInfoService, WT1, ActivityIndicator) {
    // Prevents the pre-train modal from being opened twice (see newTrainSessionCallback)
    $scope.newTrainSessionModalDisplayed = false;

    /**
     * Called once a train session has been requested. It resets the modal and user-requested UI state,
     * then (re)starts the refresh. If the session is queued, it notifies the user. Otherwise it navigates
     * to the session results, unless a results page is already displayed.
     * @param {boolean} willBeQueued
     */
    $scope.refreshMLTaskSessions = function(willBeQueued) {
        $scope.newTrainSessionModalDisplayed = false;
        $scope.uiState.$userRequestedState = false;
        $scope.initialRefreshAndAutoRefresh();

        if (willBeQueued) {
            ActivityIndicator.success("Session added to queue");
            return;
        }
        const resultsState = $scope.sRefPrefix + '.list.results';
        if (!$state.current.name.startsWith(resultsState)) {
            $state.go(resultsState + '.sessions');
        }
    };

    /**
     * Picks the pre-train modal matching the prediction type.
     * @returns {string[]} [template URL, modal controller name]
     * @throws {Error} for an unknown prediction type
     */
    function getPreTrainModal() {
        if ($scope.isDeepHubPrediction()) {  // TODO @deephub: better factorize
            return ["/templates/analysis/prediction/deephub/pre-train-modal.html", "DeepHubPMLTaskPreTrainModal"];
        }
        const template = "/templates/analysis/prediction/pre-train-modal.html";
        if ($scope.isTimeseriesPrediction()) {
            return [template, "TimeseriesPMLTaskPreTrainModal"];
        }
        if ($scope.isClassicalPrediction()) {
            return [template, "ClassicalPMLTaskPreTrainModal"];
        }
        if ($scope.isCausalPrediction()) {
            return [template, "CausalPMLTaskPreTrainModal"];
        }
        throw new Error('Unknown prediction type ' + $scope.mlTaskDesign.predictionType);
    }

    /**
     * Fetches the up-to-date settings and checks their split params. If they are valid, it saves them and,
     * unless the pre-train modal is already open, opens the modal matching the prediction type.
     * Closing the modal refreshes the sessions; dismissing it allows opening it again.
     * @param {boolean} willBeQueued - exposed to the modal as $scope.willBeQueued and forwarded to refreshMLTaskSessions
     * @returns {Promise}
     */
    function newTrainSessionCallback(willBeQueued) {
        // Read once, so the settings are saved back to the same ML task they were fetched from
        const { projectKey, analysisId, mlTaskId } = $stateParams;
        return DataikuAPI.analysis.pml.getUpdatedSettings(projectKey, analysisId, mlTaskId).then(function(response) {
            if (!$scope.checkSplitParams(response.data.splitParams, true)) {
                return;
            }
            DataikuAPI.analysis.pml.saveSettings(projectKey, analysisId, response.data).success(function() {
                $scope.savedSettings = dkuDeepCopy($scope.mlTaskDesign, PMLSettings.noDollarKey);
                if ($scope.newTrainSessionModalDisplayed) {
                    return;
                }
                // Resolve the modal BEFORE raising the flag. An unknown prediction type throws here, and must
                // not leave the flag set, which would block every later attempt to open the modal.
                const [template, modalController] = getPreTrainModal();
                $scope.newTrainSessionModalDisplayed = true;
                $scope.willBeQueued = willBeQueued;
                CreateModalFromTemplate(template, $scope, modalController).then(function() {
                    $scope.refreshMLTaskSessions(willBeQueued);
                }, function() {
                    $scope.newTrainSessionModalDisplayed = false;
                });
            }).error(setErrorInScope.bind($scope));
        }, setErrorInScope.bind($scope));
    }

    $scope.newTrainSession = function(willBeQueued) {
        if ($scope.dirtySettings()) {
            $scope.saveSettings().then(() => newTrainSessionCallback(willBeQueued));
        }
        else {
            newTrainSessionCallback(willBeQueued);
        }
    };

    $scope.trainQueue = function() {
        DataikuAPI.analysis.pml.trainQueue($stateParams.projectKey, $stateParams.analysisId, $stateParams.mlTaskId)
            .success(() => {
                $scope.refreshMLTaskSessions();
            })
            .error((data, status, headers) => {
                setErrorInScope.bind($scope)(data, status, headers);
                $scope.listQueuedSessions();
            });

        WT1.event("mltask-train-queue", {
            taskType: $scope.mlTaskDesign.taskType
        });
    };

    $scope.inContainer = MLContainerInfoService.inContainer($scope, $stateParams.projectKey);
});

app.controller("ClassicalPMLTaskBaseController", function($scope, $controller, DataikuAPI, $stateParams,
                                                          CreateModalFromTemplate, Debounce, $q, $rootScope,
                                                          PMLSettings, AlgorithmsSettingsService, Dialogs, $filter,
                                                          CachedAPICalls) {
    $controller("_PMLTrainSessionController", { $scope: $scope });
    $controller("PMLTaskCrossvalController", { $scope: $scope });

    // Algorithms without sample weights support. Plugin algorithms whose descriptor does not declare
    // `supportsSampleWeights` are added once plugin algorithms are loaded (below).
    $scope.algosWithoutWeightSupport = new Set(['lasso_regression', 'knn', 'neural_network', 'lars_params', 'deep_neural_network_regression', 'deep_neural_network_classification']);

    // Appends the custom Python prediction algorithms provided by plugins to the PY_MEMORY algorithm list,
    // skipping algorithm keys already listed. The initialization chain below waits on this promise, so it is
    // assigned before that chain is built.
    $scope.enrichBaseAlgorithmsWithPlugins = DataikuAPI.analysis.pml.listCustomPythonAlgos($stateParams.projectKey).success(function(data) {
        const pluginAlgorithms = data.map(alg => ({
            algKey: alg.pyPredAlgoType,
            name: alg.desc.meta.label,
            customInfo: alg,
            pluginDesc: $rootScope.appConfig.loadedPlugins.find(plugin => plugin.id === alg.ownerPluginId),
            condition: function() {
                const regCond = alg.desc.predictionTypes.includes("REGRESSION") && $scope.isRegression();
                const binCond = alg.desc.predictionTypes.includes("BINARY_CLASSIFICATION") && $scope.isBinaryClassification();
                const multCond = alg.desc.predictionTypes.includes("MULTICLASS") && $scope.isMulticlass();
                return regCond || binCond || multCond;
            }
        }));
        if (pluginAlgorithms.length === 0) {
            return;
        }

        // Keys of the algorithms listed before any plugin algorithm is appended, computed once
        const knownAlgKeys = new Set($scope.base_algorithms["PY_MEMORY"].map(alg => alg.algKey));
        pluginAlgorithms
            .filter(alg => !knownAlgKeys.has(alg.algKey))
            .forEach(alg => {
                // Adding Custom algos to algo list
                $scope.base_algorithms["PY_MEMORY"].push(alg);

                // Adding custom algo without sample weights support to dedicated list
                if (!alg.customInfo.desc.supportsSampleWeights) {
                    $scope.algosWithoutWeightSupport.add(alg.algKey);
                }
            });
    }).error(setErrorInScope.bind($scope));

    // Initialization: once the ML task design is loaded, prepare the guess policies, wait for the plugin
    // algorithms, then select the default algorithm.
    $scope.deferredAfterInitMlTaskDesign
        .then(() => CachedAPICalls.pmlGuessPolicies)
        .then(pmlGuessPolicies => {
            $scope.guessPolicies = pmlGuessPolicies.auto.concat(pmlGuessPolicies.expert)
                .filter(policy => policy !== undefined)
                .filter(policy => ![
                    'ALGORITHMS', // useless (choose from all)
                    'DEEP' // incompatible interface
                ].includes(policy.id));
            $scope.guessPolicies = $scope.prepareGuessPolicies($scope.guessPolicies);
        })
        .then(() => $scope.enrichBaseAlgorithmsWithPlugins)
        .then(() => {
            $scope.setAlgorithms($scope.mlTaskDesign);
            $scope.setSelectedAlgorithm(AlgorithmsSettingsService.getDefaultAlgorithm(
                $scope.mlTaskDesign,
                $scope.algorithms[$scope.mlTaskDesign.backendType]
            ));
            if ($scope.mlTaskDesign.backendType === 'KERAS') {
                $controller("DeepLearningPMLController", {$scope: $scope});
            }
        })
        .catch(setErrorInScope.bind($scope));

    $scope.beforeUpdateSettingsCallback = function(settings) {
        $scope.fillUISplitParams(settings.splitParams);
    };

    // watchers & init

    // Weighting strategy: class-based weighting methods are only offered for classification tasks
    $scope.setWeightOptions = function() {
        const weightMethods = [['NO_WEIGHTING', 'No weighting'],
                               ['SAMPLE_WEIGHT', 'Sample weights']];
        if (!$scope.isRegression()) {
            weightMethods.push(['CLASS_WEIGHT', 'Class weights'],
                               ['CLASS_AND_SAMPLE_WEIGHT', 'Class and sample weights']);
        }
        $scope.uiState.weightMethods = weightMethods;
    };

    $scope.$watch('mlTaskDesign.predictionType', (nv) => {
        if (nv) {
            $scope.setWeightOptions();
        }
    });

    // Keep the sampling partition selection in sync with the partitioned model selection (deep watch, debounced)
    $scope.$watch('mlTaskDesign.partitionedModel', Debounce().withScope($scope).withDelay(300, 300).wrap((nv, ov) => {
        if (!nv) {
            return;
        }
        if (nv.enabled) {
            const sampleSelection = $scope.mlTaskDesign.splitParams.ssdSelection;
            const partitionSelection = nv.ssdSelection;

            // set sample partition method to partition model partition method
            Object.assign(sampleSelection, {
                partitionSelectionMethod: partitionSelection.partitionSelectionMethod,
                selectedPartitions: partitionSelection.selectedPartitions,
                latestPartitionsN: partitionSelection.latestPartitionsN
            });
        } else if (ov && ov.enabled) {
            // partitioned models unchecked
            $scope.mlTaskDesign.splitParams.ssdSelection.partitionSelectionMethod = 'ALL';
        }
    }), true);

    $scope.classAveragingMethods = [['WEIGHTED', 'Weighted'], ['MACRO', 'Unweighted']];

    const warningChangeAveragingMethodToUnweighted = "The <strong> unweighted</strong> averaging method will be set for computing the one-vs-all multiclass metrics. " +
        "The metrics will be computed for each class and the global metrics will be their unweighted average.";
    const warningChangeAveragingMethodToWeighted = "The <strong>weighted</strong> averaging method will be set for computing the one-vs-all multiclass metrics. ";
    // Without sample weights, the weighted average uses the class cardinalities
    const weightedByCardinality = "The metrics will be computed for each class and the global metrics will be their average weighted by their cardinality.";
    const weightedBySampleWeights = "The metrics will be computed for each class and the global metrics will be their average weighted by the sum of the sample weights.";

    // Message explaining that the current sample weight variable goes back to being a numerical input
    function formerWeightVariableMessage() {
        return `The former weight variable <strong>${$filter('escapeHtml')($scope.mlTaskDesign.weight.sampleWeightVariable)}</strong> will be set as a numerical input to the models.`;
    }

    // Restores the UI selection to the saved weight method (used when a confirmation dialog is cancelled)
    function revertWeightMethod() {
        $scope.uiState.weightMethod = $scope.mlTaskDesign.weight.weightMethod;
    }

    // Asks for confirmation, then applies the new weight method, or reverts the UI selection if cancelled
    function confirmWeightMethodChange(title, message) {
        Dialogs.confirm($scope, title, message).then(applyWeightMethodChange, revertWeightMethod);
    }

    /**
     * Called when the user picks a new weighting method in the UI ($scope.uiState.weightMethod).
     * A confirmation is requested when the change removes the current sample weight variable, or (multiclass only)
     * when the class averaging method of the metrics has to be changed to stay consistent with the new method:
     * NO_WEIGHTING / SAMPLE_WEIGHT go with WEIGHTED averaging, CLASS_WEIGHT / CLASS_AND_SAMPLE_WEIGHT with MACRO.
     */
    $scope.onChangeWeightMethod = function() {
        const newWeightMethod = $scope.uiState.weightMethod;
        const removesSampleWeights = newWeightMethod === "NO_WEIGHTING" || newWeightMethod === "CLASS_WEIGHT";

        if (!$scope.isMulticlass()) {
            if (removesSampleWeights && $scope.mlTaskDesign.weight.sampleWeightVariable) {
                confirmWeightMethodChange("Removing weight variable", formerWeightVariableMessage());
            } else {
                applyWeightMethodChange();
            }
            return;
        }

        const classAveragingMethod = $scope.mlTaskDesign.modeling.metrics.classAveragingMethod;
        const averagingMethodMustChange = removesSampleWeights
            ? (newWeightMethod === "NO_WEIGHTING" && classAveragingMethod === "MACRO") ||
              (newWeightMethod === "CLASS_WEIGHT" && classAveragingMethod === "WEIGHTED")
            : (newWeightMethod === "CLASS_AND_SAMPLE_WEIGHT" && classAveragingMethod === "WEIGHTED") ||
              (newWeightMethod === "SAMPLE_WEIGHT" && classAveragingMethod === "MACRO");
        // When switching to weighted averaging, the weights are the class cardinalities if the new method drops
        // the sample weights, and the sums of the sample weights otherwise.
        const averagingWarning = classAveragingMethod === "MACRO"
            ? warningChangeAveragingMethodToWeighted + (removesSampleWeights ? weightedByCardinality : weightedBySampleWeights)
            : warningChangeAveragingMethodToUnweighted;

        if (removesSampleWeights && $scope.mlTaskDesign.weight.sampleWeightVariable) {
            // Sample weight are specified, need to remove them and warn the user
            confirmWeightMethodChange(
                "Removing weight variable",
                formerWeightVariableMessage() + " <br> <br>" + (averagingMethodMustChange ? averagingWarning : "")
            );
        } else if (averagingMethodMustChange) {
            confirmWeightMethodChange("Changing averaging method for metrics", averagingWarning);
        } else {
            applyWeightMethodChange();
        }
    };

    // Gives the current sample weight variable (if any) back its INPUT role
    function releaseSampleWeightVariable() {
        const previousWeightVariable = $scope.mlTaskDesign.weight.sampleWeightVariable;
        if (previousWeightVariable && $scope.mlTaskDesign.preprocessing.per_feature[previousWeightVariable]) {
            $scope.mlTaskDesign.preprocessing.per_feature[previousWeightVariable].role = "INPUT";
        }
    }

    // Applies $scope.uiState.weightMethod to the ML task design, then saves the settings
    function applyWeightMethodChange() {
        if ($scope.uiState.weightMethod === "NO_WEIGHTING" || $scope.uiState.weightMethod === "CLASS_WEIGHT") {
            releaseSampleWeightVariable();

            // reinitialize the weight variable params in UI and mlTaskDesign
            $scope.uiState.sampleWeightVariable = null;
            $scope.mlTaskDesign.weight.sampleWeightVariable = null;
        }
        if ($scope.isMulticlass()) {
            // change the averaging method for one-vs-all metrics to something consistent with the weighting strategy
            if ($scope.uiState.weightMethod === "NO_WEIGHTING" || $scope.uiState.weightMethod === "SAMPLE_WEIGHT") {
                $scope.mlTaskDesign.modeling.metrics.classAveragingMethod = "WEIGHTED";
            } else {
                $scope.mlTaskDesign.modeling.metrics.classAveragingMethod = "MACRO";
            }
        }

        $scope.mlTaskDesign.weight.weightMethod = $scope.uiState.weightMethod;
        $scope.saveSettings();
    }

    /**
     * Called when the user picks a new sample weight variable in the UI. Opens a confirmation modal; on confirm the
     * previous weight variable becomes an INPUT and the new one gets the WEIGHT role. On cancel or dismissal the UI
     * selection is reset to the saved weight variable.
     */
    $scope.onChangeSampleWeightVariable = function() {
        if (!$scope.uiState.sampleWeightVariable) {
            return;
        }
        const deferred = $q.defer();
        CreateModalFromTemplate("/templates/analysis/prediction/change-weight-modal.html", $scope, null, function(newScope) {
            newScope.deferred = deferred;
            newScope.confirm = function() {
                releaseSampleWeightVariable();
                $scope.mlTaskDesign.weight.sampleWeightVariable = $scope.uiState.sampleWeightVariable;
                const featureData = $scope.mlTaskDesign.preprocessing.per_feature[$scope.uiState.sampleWeightVariable];
                featureData.role = "WEIGHT";
                // Non-numeric features (or ones with an autoReason) are switched to a default numeric handling
                if (featureData.type !== "NUMERIC" || featureData.autoReason) {
                    featureData.missing_handling = "IMPUTE";
                    featureData.missing_impute_with = "MEAN";
                    featureData.numerical_handling = "REGULAR";
                    featureData.rescaling = "AVGSTD";
                    featureData.type = "NUMERIC";
                }
                $scope.saveSettings();
                newScope.deferred.resolve("changed");
                newScope.dismiss();
            };
            newScope.cancel = function() {
                newScope.deferred.reject("cancelled");
                newScope.dismiss();
            };
            newScope.$on("$destroy", function() {
                // Rejecting an already-resolved deferred is a no-op, so this only matters for other dismissals
                if (newScope.deferred) {
                    newScope.deferred.reject("destroyed");
                }
                newScope.deferred = null;
            });
        });
        deferred.promise.catch(function() {
            // reset the UI weight variable to the saved weight variable
            $scope.uiState.sampleWeightVariable = $scope.mlTaskDesign.weight.sampleWeightVariable;
        });
    };

    // Features that can be used as sample weights: every feature except the target
    $scope.potentialWeightFeatures = function() {
        const perFeature = $scope.mlTaskDesign.preprocessing.per_feature;
        return Object.keys(perFeature).filter(name => perFeature[name].role !== "TARGET");
    };

    $scope.uiState.calibrationMethods = PMLSettings.task.calibrationMethods;

    $scope.isCalibrationEnabled = function() {
        return $scope.mlTaskDesign.calibration.calibrationMethod !== 'NO_CALIBRATION';
    };

    //Time-based Ordering
    $scope.uiState.gsModes = [['TIME_SERIES_SINGLE_SPLIT', 'Time-based train/validation split'], ['TIME_SERIES_KFOLD', 'Time-based K-fold (with overlap)']];

    $scope.isTimeOrderingEnabled = function() {
        return !!$scope.mlTaskDesign.time && $scope.mlTaskDesign.time.enabled;
    };

    $scope.isTimeVariable = function(feature) {
        return !!$scope.mlTaskDesign.time && $scope.mlTaskDesign.time.enabled && $scope.mlTaskDesign.time.timeVariable === feature._name;
    };

    // Truthy when the design uses a K-fold split
    function isKFoldSplit() {
        return $scope.mlTaskDesign
               && $scope.mlTaskDesign.splitParams
               && $scope.mlTaskDesign.splitParams.kfold;
    }

    $scope.isCompatibleWithStratifiedSplitting = function() {
        return $scope.isMLBackendType('PY_MEMORY') && isKFoldSplit() && $scope.isClassification() && !$scope.isTimeOrderingEnabled();
    };

    $scope.isCompatibleWithGroupKFold = function() {
        return $scope.isMLBackendType('PY_MEMORY') && isKFoldSplit();
    };

    // Width (in %, at least 1) of the calibration set in the split illustration
    $scope.getCalibrationSetWidth = function() {
        if ($scope.mlTaskDesign.splitParams.kfold) {
            return Math.max($scope.mlTaskDesign.calibration.calibrationDataRatio * 100, 1);
        } else {
            return Math.max($scope.mlTaskDesign.calibration.calibrationDataRatio * $scope.mlTaskDesign.splitParams.ssdTrainingRatio * 100, 1);
        }
    };

    // Width (in %) of the calibration set inside the training folds of a K-fold split; 0 when it is not carved from them
    $scope.getKFoldEvalCalibrationSetWidth = function() {
        if (!$scope.isCalibrationEnabled() || $scope.mlTaskDesign.calibration.calibrateOnTestSet) {
            return 0;
        } else {
            return $scope.mlTaskDesign.calibration.calibrationDataRatio * ($scope.mlTaskDesign.splitParams.nFolds - 1) / $scope.mlTaskDesign.splitParams.nFolds * 100;
        }
    };

    $scope.isKFoldEvalCalibrationSetOverflowing = function(fold) {
        return $scope.getKFoldEvalCalibrationSetWidth() >= (100 * (1 - fold / $scope.mlTaskDesign.splitParams.nFolds));
    };

    $scope.getKFoldEvalCalibrationSetOverflowingWidth = function(fold) {
        return $scope.isKFoldEvalCalibrationSetOverflowing(fold) ? ($scope.getKFoldEvalCalibrationSetWidth() - (100 * (1 - fold / $scope.mlTaskDesign.splitParams.nFolds))) : 0;
    };
});

app.controller("_PMLTaskResultController", function($scope, DataikuAPI, $stateParams, WT1) {
    /**
     * Starts the retraining of some models of a session, then refreshes the results.
     * @param {string} sessionId - session the models belong to
     * @param {string[]} fullModelIds - models to retrain
     * @param {boolean} [setUiState=false] - also reset the page-level $userRequestedState flag
     * @returns the retrain API call promise
     */
    $scope.retrainModel = function (sessionId, fullModelIds, setUiState = false) {
        WT1.event("start-retrain-model", {});

        const onRetrainStarted = () => {
            fullModelIds.forEach(function(fullModelId) {
                $scope.modelSnippets[fullModelId].trainInfo.$userRequestedState = false;
            });
            if (setUiState) {
                $scope.uiState.$userRequestedState = false;
            }
            $scope.initialRefreshAndAutoRefresh();
        };

        const coreParams = $scope.analysisCoreParams;
        return DataikuAPI.analysis.pml
            .retrainStart(coreParams.projectKey, coreParams.id, $stateParams.mlTaskId, sessionId, fullModelIds)
            .success(onRetrainStarted)
            .error(setErrorInScope.bind($scope));
    };
});

app.controller("_TabularPMLTaskResultController", function($scope, $controller, PMLSettings, PMLFilteringService, Fn) {
    $controller("_MLTaskResultsController", {$scope: $scope});
    $controller("_PMLTaskResultController", {$scope: $scope});
    angular.extend($scope, PMLSettings.taskF($scope.mlTasksContext.activeMLTask.backendType));
    angular.extend($scope, PMLSettings.task);
    $scope.algorithmCategories = PMLSettings.algorithmCategories($scope.mlTasksContext.activeMLTask.predictionType);
    $scope.metricMap = PMLFilteringService.metricMap;

    // Models of the current session selection (empty list when none)
    const selectedSessionModels = () => $scope.selection.sessionModels || [];

    // True if a non-partitioned model of the selection runs a hyperparameter search
    // (grid of more than one point, or a plugin algorithm with its own grid search)
    $scope.anySessionModelNeedsHyperparameterSearch = function () {
        return selectedSessionModels().some(model =>
            (model.gridLength != 1 || model.pluginAlgoCustomGridSearch) && !model.partitionedModelEnabled
        );
    };

    // True if a model of the selection has optimization results: Keras training info, or grid search points otherwise
    $scope.anySessionModelHasOptimizationResults = function () {
        const isKeras = $scope.isMLBackendType("KERAS");
        const models = selectedSessionModels();
        if (isKeras) {
            return models.some(model => model.modelTrainingInfo);
        }
        return models.some(model =>
            model.gridsearchData && model.gridsearchData.gridPoints && model.gridsearchData.gridPoints.length > 0
        );
    };

    // Retrains every model of the given session whose optimization can be resumed or which can be retrained
    $scope.retrainSession = function(sessionId) {
        const isRetrainCandidate = model => model.sessionId === sessionId
            && ($scope.isModelOptimizationResumable(model) || $scope.isModelRetrainable(model));
        const fullModelIds = $scope.selection.allObjects.filter(isRetrainCandidate).map(Fn.prop('fullModelId'));

        $scope.retrainModel(sessionId, fullModelIds, true);
    };

});

app.controller("ClassicalPMLTaskResultController", function($scope, $timeout, $controller, DataikuAPI, $stateParams,
                                                            FutureWatcher) {
    $controller("_TabularPMLTaskResultController",{$scope:$scope});
    $controller("_DeepLearningPMLTaskResultController", {$scope: $scope});  // used for Keras

    // Tensorboard webapp backend URLs, keyed by training session id
    $scope.tensorboardUrls = {};

    /**
     * Looks up the backend URL of the Tensorboard webapp of a training session and stores it in
     * $scope.tensorboardUrls[sessionId].
     * @param {string} sessionId - id of the training session
     */
    $scope.initializeTensorboardUrl = function(sessionId) {
        const projectKey = $scope.analysisCoreParams.projectKey;
        const webAppId = `TENSORBOARD_${projectKey}-${$stateParams.analysisId}-${$stateParams.mlTaskId}-${sessionId}`;
        DataikuAPI.webapps.getBackendUrl(projectKey, webAppId, null).success(function(data) {
            $timeout(function() {
                $scope.tensorboardUrls[sessionId] = data.location;
            });
        }).error(setErrorInScope.bind($scope));
    };

    // Tensorboard is only available for Keras sessions; false while no session task is loaded
    $scope.canShowTensorboard = function () {
        return !!$scope.sessionTask && $scope.sessionTask.backendType === 'KERAS';
    };

    // Initialize the Tensorboard display status of a newly displayed session task, or mark the frontend of an
    // already-initialized one as not ready yet
    $scope.$watch('sessionTask', (nv) => {
        if (!nv) {
            return;
        }
        if (nv.tensorboardStatus === undefined) {
            nv.tensorboardStatus = {
                isShown: false,
                isBackendReady: false,
                isFrontendReady: false,
                showIfFrontIsNotReady: false,
                fullScreen: false
            };
        } else {
            nv.tensorboardStatus.isFrontendReady = false;
            nv.tensorboardStatus.showIfFrontIsNotReady = true;
        }
    });

    /**
     * Toggles the Tensorboard panel of the current session task. When showing it, starts the Tensorboard backend
     * (waiting for the start job if needed) and marks the backend as ready once started.
     */
    $scope.showHideTensorboard = function () {
        const tensorboardStatus = $scope.sessionTask.tensorboardStatus;
        tensorboardStatus.showIfFrontIsNotReady = false;
        tensorboardStatus.fullScreen = false;
        tensorboardStatus.isBackendReady = false;
        tensorboardStatus.isFrontendReady = false;
        tensorboardStatus.isShown = !tensorboardStatus.isShown;
        if (!tensorboardStatus.isShown) {
            return;
        }

        // Any failure to start the backend hides Tensorboard again, so the panel does not stay "shown" but never ready
        const onStartFailure = function(...errorArgs) {
            $scope.sessionTask.tensorboardStatus.isShown = false;
            setErrorInScope.apply($scope, errorArgs);
        };

        const sessionId = $scope.selection.sessionModels[0].sessionId;
        DataikuAPI.webapps.startTensorboard($scope.analysisCoreParams.projectKey, $stateParams.analysisId, $stateParams.mlTaskId, sessionId).success(function (result) {
            if (result && result.jobId && !result.hasResult) { // There is a backend still starting, wait for it
                FutureWatcher.watchJobId(result.jobId)
                    .success(function () {
                        $scope.sessionTask.tensorboardStatus.isBackendReady = true;
                    })
                    .error(onStartFailure);
            } else {
                $scope.sessionTask.tensorboardStatus.isBackendReady = true;
            }
        }).error(onStartFailure);
    };

});

app.controller("_TabularPMLTaskDesignController", function($scope, $controller, $stateParams,
        Dialogs, DataikuAPI, Assert, WT1, Collections, CreateModalFromTemplate, PMLSettings, VisualMlCodeEnvCompatibility, StringUtils, CustomMetricIDService, AlgorithmsSettingsService) {
    angular.extend($scope, PMLSettings.taskF($scope.mlTasksContext.activeMLTask.backendType));
    angular.extend($scope, PMLSettings.task);

    $scope.AlgorithmsSettingsService = AlgorithmsSettingsService;

    $controller("_MLTaskDesignController", { $scope: $scope });


    $scope.reguessAll = function() {
        Dialogs.confirm($scope, "Re-detect settings", "Are you sure you want to re-detect all settings? Your changes will be lost.").then(function(){
            DataikuAPI.analysis.pml.reguess($stateParams.projectKey, $stateParams.analysisId, $stateParams.mlTaskId)
                .then(function({data}) {
                    $scope.setMlTaskDesign(data);
                    $scope.savedSettings = dkuDeepCopy($scope.mlTaskDesign, PMLSettings.noDollarKey);
                    $scope.fillUISplitParams($scope.mlTaskDesign.splitParams);
                })
                .then($scope.refreshStatus)
                .catch(setErrorInScope.bind($scope));
        });
    };

    $scope.onChangeTargetFeature = function() {
        if ($scope.dirtySettings()) {
            $scope.saveSettings();
        }
        CreateModalFromTemplate("/templates/analysis/prediction/change-core-params-modal.html", $scope, "PMLChangeBasicParamsModal", function(newScope) {
            newScope.paramKey = "targetVariable";
        });
    };

    $scope.onChangePredictionType = function(){
        if ($scope.dirtySettings()) {
            $scope.saveSettings();
        }
        CreateModalFromTemplate("/templates/analysis/prediction/change-core-params-modal.html", $scope, "PMLChangeBasicParamsModal", function(newScope) {
            newScope.paramKey = "predictionType";
        });
    };

    $scope.displayAlgosAsBaseLearners = function() {
        return $scope.isCausalPrediction() && $scope.uiState.selectedCausalMethod === 'META_LEARNER';
    };

    $scope.getAlgorithmTemplate = function() {
        Assert.inScope($scope, 'algorithms');
        const algorithm = $scope.uiState && $scope.uiState.algorithm;
        if (!algorithm) return null;

        const templatePathPrefix = '/templates/analysis/prediction/settings/algorithms/' + $scope.mlTaskDesign.backendType.toLowerCase();
        if (algorithm.startsWith("CustomPyPredAlgo_")) {
            return templatePathPrefix + '/plugin-model.html';
        }
        if (algorithm.startsWith("custom")) {
            return templatePathPrefix + '/custom.html';
        }

        const availableAlgorithms = $scope.algorithms[$scope.mlTaskDesign.backendType];
        const { templateName } = Collections.indexByField(availableAlgorithms, 'algKey')[algorithm];
        return templatePathPrefix + "/" + (templateName || algorithm) + ".html";
    };

    $scope.addCustomPython = function() {
        $scope.mlTaskDesign.modeling.custom_python = $scope.mlTaskDesign.modeling.custom_python || [];

        let code = PMLSettings.defaultCustomCode($scope.mlTaskDesign.backendType, $scope.isRegression(), $scope.mlTaskDesign.targetColumn);
        $scope.mlTaskDesign.modeling.custom_python.push({
            enabled: true,
            name: "Custom Python model",
            code
        });

        $scope.setAlgorithms($scope.mlTaskDesign);
        $scope.uiState.algorithm = 'custom_python_' + ($scope.mlTaskDesign.modeling.custom_python.length - 1);
        $scope.uiState.scrollToMeAlgorithm = $scope.uiState.algorithm;
    };

    $scope.canAddCustomPython = function() {
            if ($scope.isCausalPrediction() && $scope.uiState.selectedCausalMethod !== 'META_LEARNER') return false;
        return $scope.mlTaskDesign.backendType === 'PY_MEMORY' && !$scope.isTimeseriesPrediction();
    };

    // For now only used for classical, but more convenient to keep next to addCustomPython
    $scope.addCustomMLLib = function() {
        $scope.mlTaskDesign.custom_mllib = $scope.mlTaskDesign.custom_mllib || [];

        let initializationCode = PMLSettings.defaultCustomCode($scope.mlTaskDesign.backendType, $scope.isRegression(), $scope.mlTaskDesign.targetVariable);

        $scope.mlTaskDesign.modeling.custom_mllib.push({
            enabled: true,
            name: "Custom MLlib model",
            initializationCode
        });
        $scope.setAlgorithms($scope.mlTaskDesign);
        $scope.uiState.algorithm = 'custom_mllib_' + ($scope.mlTaskDesign.modeling.custom_mllib.length - 1);
        $scope.uiState.scrollToMeAlgorithm = $scope.uiState.algorithm;
    };

    $scope.getCrossvalModes = function() {
        if($scope.mlTaskDesign.time && $scope.mlTaskDesign.time.enabled) {
            return $scope.crossvalModesWithTime;
        } else {
            return $scope.crossvalModesRandom;
        }
    };

    // TODO: extract in its own controller
    $scope.copyAlgorithmSettings = function(exportSettings) {
        if ($scope.dirtySettings()) {
                $scope.saveSettings();
        }
        DataikuAPI.projects.listHeads(exportSettings ? 'WRITE_CONF' : null).success(function(projectData) {
             CreateModalFromTemplate("/templates/analysis/mlcommon/settings/copy-settings.html", $scope, null, function(newScope) {
                 newScope.projects = projectData;
                 newScope.title = "Copy "
                    + ($scope.mlTaskDesign.backendType === "KERAS" ? "architecture " : " algorithms ")
                    + (exportSettings ? "to" : "from");
                 newScope.totem = "icon-" + (exportSettings ? "copy" : "paste");
                 newScope.infoMessages = [
                    "You can only choose a " + $scope.displayTypes[$scope.mlTaskDesign.predictionType]
                    + " model using a " + ($scope.backendTypeNames[$scope.mlTaskDesign.backendType] || $scope.mlTaskDesign.backendType)
                    + " engine"
                ];
                 newScope.selectProject = function() {
                     DataikuAPI.analysis.listHeads(newScope.selectedProjectKey).success(function(analysisData) {
                         newScope.analyses = analysisData;
                         newScope.selectedAnalysisId = undefined;
                         newScope.selectedTask = undefined;
                     }).error(setErrorInScope.bind($scope));
                 };
                 newScope.selectAnalysis = function () {
                    DataikuAPI.analysis.listMLTasks(newScope.selectedProjectKey, newScope.selectedAnalysisId)
                    .success(function(taskData) {
                        newScope.descriptions = [];
                        newScope.tasks = taskData;
                        newScope.tasks.forEach(task => {
                            // task can be selected if not current one + same backend + same pred type (or both classif)
                            task.isNotSelectable = task.mlTaskId === $stateParams.mlTaskId
                                && newScope.selectedAnalysisId === $stateParams.analysisId
                                && newScope.selectedProjectKey === $stateParams.projectKey
                                || task.backendType !== $scope.mlTaskDesign.backendType
                                || task.taskType !== "PREDICTION"
                                || $scope.displayTypes[$scope.mlTaskDesign.predictionType] !== $scope.displayTypes[task.predictionType];

                            newScope.descriptions.push($scope.displayTypes[task.predictionType || task.taskType] + " ("
                            + ($scope.backendTypeNames[task.backendType] || task.backendType) + ")");
                        });
                        newScope.selectedTask = undefined;
                    }).error(setErrorInScope.bind($scope));
                 };
                 if (newScope.projects.some(_ => _.projectKey === $stateParams.projectKey)) {
                      newScope.selectedProjectKey = $stateParams.projectKey;
                      newScope.analyses = $scope.analyses;
                      newScope.selectedAnalysisId = $stateParams.analysisId;
                      newScope.selectAnalysis();
                  }
                 newScope.confirm = function() {
                    const currentIds = [
                       $stateParams.projectKey, $stateParams.analysisId, $stateParams.mlTaskId
                    ];

                    const selectedIds = [
                        newScope.selectedProjectKey, newScope.selectedAnalysisId, newScope.selectedTask.mlTaskId
                    ];

                    const originIds = exportSettings ? currentIds : selectedIds;
                    const destinationIds = exportSettings ? selectedIds : currentIds;
                    DataikuAPI.analysis.pml.copyAlgorithmSettings(...originIds, ...destinationIds)
                        .success(function(data) {
                            if (!exportSettings) {
                                $scope.setMlTaskDesign(data);
                                $scope.setAlgorithms($scope.mlTaskDesign);
                            }
                            newScope.dismiss();
                        }).error(setErrorInScope.bind($scope));

                    WT1.event("mltask-copy-algorithms", {
                        export: exportSettings,
                        sameProject: $stateParams.projectKey === newScope.selectedProjectKey,
                        sameAnalysis: $stateParams.analysisId === newScope.selectedAnalysisId,
                        typeDest: newScope.selectedTask.predictionType,
                        typeSrc: $scope.mlTaskDesign.predictionType
                    });
                 };
                 newScope.cancel = function() {
                     newScope.dismiss();
                 };
             });
         }).error(setErrorInScope.bind($scope));
    };

    /**
     * Whether to warn that the selected code environment cannot run a deep neural network algorithm.
     * Only enabled deep neural network algorithms of classical or causal prediction tasks are concerned.
     * @param {Object} algorithm - algorithm descriptor; its `algKey` identifies the algorithm
     * @returns {boolean} true when the code env is not reported as deep-neural-network compatible
     */
    $scope.displayDeepNeuralNetworkCodeEnvWarning = function(algorithm) {
        const dnnAlgorithmKeys = ["deep_neural_network_regression", "deep_neural_network_classification"];

        const taskCanUseDnn = $scope.isClassicalPrediction() || $scope.isCausalPrediction();
        if (!taskCanUseDnn) {
            return false;
        }
        if (!AlgorithmsSettingsService.getAlgorithmSettings($scope.mlTaskDesign, algorithm).enabled) {
            return false;
        }
        if (!dnnAlgorithmKeys.includes(algorithm.algKey)) {
            return false;
        }

        const compat = VisualMlCodeEnvCompatibility.getCodeEnvCompat($scope.mlTaskDesign.envSelection, $scope.codeEnvsCompat);
        // A missing compatibility report counts as "not compatible".
        const dnnCompat = compat && compat.deepNeuralNetwork;
        return !(dnnCompat && dnnCompat.compatible);
    };

    /**
     * Whether to warn that an enabled algorithm ignores sample weights in a classical prediction task
     * that uses them.
     * @param {Object} algorithm - algorithm descriptor; its `algKey` is looked up in algosWithoutWeightSupport
     * @returns the first falsy check result, or the result of the last check (true when warning)
     */
    $scope.displayWeightWarning = function(algorithm) {
        // Each step runs only while the previous ones are truthy, mirroring an `&&` chain.
        let shouldWarn = $scope.isClassicalPrediction();
        shouldWarn = shouldWarn && AlgorithmsSettingsService.getAlgorithmSettings($scope.mlTaskDesign, algorithm).enabled;
        shouldWarn = shouldWarn && $scope.isSampleWeightEnabled();
        shouldWarn = shouldWarn && $scope.algosWithoutWeightSupport.has(algorithm.algKey);
        return shouldWarn;
    };

    // Default Python source used as `metricCode` for new custom metrics (see getNewMetricTemplate).
    // Time series tasks get a docstring describing per-series scoring; other tasks get a docstring
    // describing the possible y_pred shapes, plus the optional X_valid / sample_weight arguments unless
    // the backend is KERAS.
    // NOTE: the template literals below are emitted verbatim, indentation and trailing whitespace
    // included, so any edit to them changes the generated code.
    let customMetricDefaultCode;

    if ($scope.isTimeseriesPrediction()) {
        customMetricDefaultCode = `def score(y_valid, y_pred):
    """
    Custom scoring function.
    Must return a float quantifying the estimator prediction quality.
    - y_valid is a pandas Series
    - y_pred is a numpy ndarray with shape (nb_records,)

    This function is applied to each time series and then averaged over all time series.
    """`
    } else {
        // Built in up to three parts: shared header, optional non-KERAS arguments, closing docstring quotes.
        customMetricDefaultCode = `def score(y_valid, y_pred):
    """
    Custom scoring function.
    Must return a float quantifying the estimator prediction quality.
    - y_valid is a pandas Series
    - y_pred is a numpy ndarray with shape:
        - (nb_records,) for regression problems and classification problems
            where 'needs probas' (see below) is false
            (for classification, the values are the numeric class indexes)
        - (nb_records, nb_classes) for classification problems where
            'needs probas' is true`;

        // X_valid and sample_weight are only documented for non-KERAS backends.
        if ($scope.mlTaskDesign.backendType !== "KERAS") {
            customMetricDefaultCode += `
    - [optional] X_valid is a dataframe with shape (nb_records, nb_input_features)
    - [optional] sample_weight is a numpy ndarray with shape (nb_records,)
        NB: this option requires a variable set as "Sample weights"`;
        }
        customMetricDefaultCode += `
    """
        `;
    }

    /**
     * Builds a new (not yet added) custom metric: default name "Custom Metric #N", the default
     * scoring-function code for this task, and default flags.
     * Works even before the customMetrics list exists (addNewCustomMetric only creates it lazily).
     * @returns {Object} the custom metric template
     */
    $scope.getNewMetricTemplate = function() {
        const existingMetrics = $scope.mlTaskDesign.modeling.metrics.customMetrics || [];
        // StringUtils.transmogrify presumably de-duplicates the name against the existing ones,
        // using the callback to generate alternative candidates.
        const name = StringUtils.transmogrify(
            `Custom Metric #${existingMetrics.length + 1}`,
            existingMetrics.map(metric => metric.name),
            i => `Custom Metric #${i + 1}`
        );

        return {
            name,
            metricCode: customMetricDefaultCode,
            description: "",
            type: 'MODEL',
            $foldableOpen: true,
            greaterIsBetter: true,
            needsProbability: false
        };
    };

    // Snippet category of the custom metric code editor (presumably selects Python scoring-function snippets).
    $scope.snippetCategory = "py-scoringfunc";

    /**
     * Appends a freshly templated custom metric to the task's metrics, then reports it to WT1.
     * The customMetrics list is created on demand.
     */
    $scope.addNewCustomMetric = function() {
        const metrics = $scope.mlTaskDesign.modeling.metrics;
        // A null/undefined list counts as missing too, so push() below always has an array.
        if (metrics.customMetrics == null) {
            metrics.customMetrics = [];
        }

        metrics.customMetrics.push($scope.getNewMetricTemplate());
        $scope.fireCustomMetricAddedWT1Event();
    };

    // Sends a WT1 "clicked-item" event for the given item id; shared by the custom metric handlers below.
    function fireClickedItemWT1Event(itemId) {
        const payload = { "item-id": itemId };
        WT1.event("clicked-item", payload);
    }

    /** Reports to WT1 that a custom metric was added. */
    $scope.fireCustomMetricAddedWT1Event = function() {
        fireClickedItemWT1Event('mltask-add-custom-metric');
    };

    /** Reports to WT1 that a custom metric was removed. */
    $scope.fireCustomMetricRemovedWT1Event = function() {
        fireClickedItemWT1Event('mltask-remove-custom-metric');
    };

    /**
     * Flips the open/closed UI state of the custom metric at `index`; does nothing when there is no
     * metric at that position.
     * @param {number} index - position in mlTaskDesign.modeling.metrics.customMetrics
     */
    $scope.toggleFoldable = function(index) {
        const metric = $scope.mlTaskDesign.modeling.metrics.customMetrics[index];
        if (!metric) {
            return;
        }
        metric.$foldableOpen = !metric.$foldableOpen;
    };

    /**
     * Copies the evaluation metric selected in the UI (uiState.evaluationMetricId) into the task's
     * metrics settings, first defaulting the selection to the prediction type's default metric.
     * Custom metrics are stored as evaluationMetric "CUSTOM" plus customEvaluationMetricName;
     * built-in metrics are stored as-is with customEvaluationMetricName reset to undefined.
     */
    $scope.setEvaluationMetric = function() {
        if (!$scope.uiState.evaluationMetricId) {
            $scope.uiState.evaluationMetricId = PMLSettings.task.getDefaultEvaluationMetric($scope.mlTaskDesign.predictionType);
        }
        const metricId = $scope.uiState.evaluationMetricId;
        const metricsSettings = $scope.mlTaskDesign.modeling.metrics;
        const isCustom = CustomMetricIDService.checkMetricIsCustom(metricId);

        metricsSettings.customEvaluationMetricName = isCustom ? CustomMetricIDService.getCustomMetricName(metricId) : undefined;
        metricsSettings.evaluationMetric = isCustom ? "CUSTOM" : metricId;
    };
})

// Result-screen helpers for deep learning tasks. They answer questions about the session models of the
// current selection, using each model's modelTrainingInfo (epochs) and trainInfo (state).
// Models lacking modelTrainingInfo, epochs or trainInfo are skipped instead of raising a TypeError.
app.controller("_DeepLearningPMLTaskResultController", function($scope) {
    // Session models of the current selection ([] when none are loaded).
    function getSessionModels() {
        return $scope.selection.sessionModels || [];
    }

    // Whether some model's training state equals `state`; models without trainInfo never match.
    function anyModelInState(state) {
        return getSessionModels().some(model => !!model.trainInfo && model.trainInfo.state === state);
    }

    // Whether some model has at least one recorded epoch whose `scoreName` value is defined.
    function modelEpochHasValue(scoreName) {
        return getSessionModels().some(model => {
            const epochs = model.modelTrainingInfo && model.modelTrainingInfo.epochs;
            return !!epochs && epochs.some(epoch => epoch[scoreName] !== undefined);
        });
    }

    $scope.anyModelHasOneEpochFinished = function() {
        return getSessionModels().some(model => {
            const epochs = model.modelTrainingInfo && model.modelTrainingInfo.epochs;
            return !!epochs && epochs.length > 0;
        });
    };

    // Whether some model kept an epoch (keptModelEpoch >= 0) and its training neither aborted nor failed.
    $scope.modelEpochHasSavedModel = function() {
        return getSessionModels().some(model => {
            const info = model.modelTrainingInfo;
            const hasKeptEpoch = !!info && info.keptModelEpoch !== undefined && info.keptModelEpoch >= 0;
            const trainingUsable = !!model.trainInfo && model.trainInfo.state !== 'ABORTED' && model.trainInfo.state !== 'FAILED';
            return hasKeptEpoch && trainingUsable;
        });
    };

    $scope.modelEpochHasTrainSet = function() {
        return modelEpochHasValue('trainScore') || modelEpochHasValue('trainLoss');
    };

    $scope.modelEpochHasValidationSet = function() {
        return modelEpochHasValue('testScore') || modelEpochHasValue('testLoss');
    };

    $scope.modelEpochHasLossScore = function() {
        return modelEpochHasValue('trainLoss') || modelEpochHasValue('testLoss');
    };

    $scope.anyModelHasAllEpochsFinished = function() {
        return getSessionModels().some(model => {
            const info = model.modelTrainingInfo;
            // epochs may be absent (the other helpers guard it too); without this check .length would throw.
            const allEpochsRecorded = !!info && !!info.epochs && info.nbEpochs == info.epochs.length;
            // model can be done earlier with early stopping
            const isDone = !!model.trainInfo && model.trainInfo.state === "DONE";
            return allEpochsRecorded || isDone;
        });
    };

    $scope.anyModelHasFailedOrAborted = function() {
        return $scope.anyModelHasFailed() || anyModelInState('ABORTED');
    };

    $scope.anyModelHasFailed = function() {
        return anyModelInState('FAILED');
    };
});

// Design controller for classical (tabular) prediction tasks. It instantiates
// _TabularPMLTaskDesignController on the same scope, then adds classification detection,
// manual feature-interaction helpers, and watchers that keep uiState, split and grid-search
// settings consistent with mlTaskDesign.
app.controller("ClassicalPMLTaskDesignController", function($scope, Fn, $controller, $state, PMLSettings) {
    $scope.$state = $state;
    $controller("_TabularPMLTaskDesignController", {$scope: $scope});

    // Only offer the prediction types flagged as classical.
    $scope.predictionTypes = $scope.predictionTypes.filter(type => type.classical);

    // Whether the task's prediction type is a classification type; false while mlTaskDesign is not loaded.
    $scope.isClassification = function(){
        if (!$scope.mlTaskDesign) return false;
        return PMLSettings.task.isClassification($scope.mlTaskDesign.predictionType);
    }

    // Initially selected feature-generation sub-page (presumably; confirm against the template).
    $scope.uiState.generatorPage = "manual_interactions";

    // Number of unordered pairs of NUMERIC features (n choose 2; 0 when fewer than 2).
    // The Fn chain presumably maps each feature name to its `type` before keeping NUMERIC ones.
    // NOTE(review): features of every role are counted, not only INPUT ones — confirm this is intended.
    $scope.countNumericCombinations = function() {
        var n = Object.keys($scope.mlTaskDesign.preprocessing.per_feature)
                    .map(Fn(Fn.dict($scope.mlTaskDesign.preprocessing.per_feature), Fn.prop('type')))
                    .filter(Fn.eq('NUMERIC')).length;
        return n < 2 ? 0 : (n * (n-1) / 2); // n take 2
    };

    // Appends a manual interaction between the first two INPUT features (the first one with itself
    // when there is only one), creating the interactions list if it does not exist yet.
    // NOTE(review): with no INPUT feature at all, both columns end up undefined.
    $scope.addInteraction = function(){
        var prep = $scope.mlTaskDesign.preprocessing;
        var fs = Object.keys(prep.per_feature).filter(function(f){ return prep.per_feature[f].role == "INPUT"; });
        var interaction = {
            column_1: fs[0],
            column_2: fs.length > 1 ? fs[1] : fs[0],
            rescale: true,
            max_features: 100
        }
        var ints = prep.feature_generation.manual_interactions.interactions;
        if(!ints){
            ints = [];
            prep.feature_generation.manual_interactions.interactions = ints;
        }
        ints.push(interaction);
    };

    // Sorted names of the INPUT features.
    $scope.activeFeatures = function(){
        var feats = [];
        for(var f in $scope.mlTaskDesign.preprocessing.per_feature){
            if($scope.mlTaskDesign.preprocessing.per_feature[f].role == 'INPUT'){
                feats.push(f);
            }
        }
        feats.sort();
        return feats;
    };

    // Names of INPUT NUMERIC features whose missing-value handling keeps NaN
    // (KEEP_NAN_OR_IMPUTE or KEEP_NAN_OR_DROP).
    $scope.keepNanFeatures = function(){
        return Object.keys($scope.mlTaskDesign.preprocessing.per_feature).filter(x => (($scope.mlTaskDesign.preprocessing.per_feature[x].role == 'INPUT') &&
                                                                                ($scope.mlTaskDesign.preprocessing.per_feature[x].type == "NUMERIC") &&
                                                                                (["KEEP_NAN_OR_IMPUTE", "KEEP_NAN_OR_DROP"].includes($scope.mlTaskDesign.preprocessing.per_feature[x].missing_handling))));
    }

    // Removes the i-th manual interaction.
    $scope.removeInteraction = function(i){
        $scope.mlTaskDesign.preprocessing.feature_generation.manual_interactions.interactions.splice(i, 1);
    };

    // True unless both columns of the interaction are NUMERIC features.
    $scope.willDummify = function(interaction){
        var isNumeric = function(f){
            return $scope.mlTaskDesign.preprocessing.per_feature[f].type == "NUMERIC";
        };
        return ! (isNumeric(interaction.column_1) && isNumeric(interaction.column_2));
    };

    // Reference watch (no deep flag): whenever mlTaskDesign is (re)assigned, mirror its prediction type,
    // weighting settings and managed folder into uiState, and list the hyperparameter search strategies:
    // random and Bayesian search are only offered for the PY_MEMORY backend, grid search for all.
    $scope.$watch('mlTaskDesign', function(nv){
        if (nv) {
            $scope.uiState.predictionType = nv.predictionType;
            $scope.uiState.sampleWeightVariable = nv.weight.sampleWeightVariable ? nv.weight.sampleWeightVariable : null;
            $scope.uiState.weightMethod = nv.weight.weightMethod ? nv.weight.weightMethod : null;
            $scope.uiState.managedFolderSmartId = nv.managedFolderSmartId;
            if (nv.backendType === "PY_MEMORY") {
                $scope.uiState.hyperparamSearchStrategies =  [["GRID", "Grid search"],
                                                              ["RANDOM", "Random search"],
                                                              ["BAYESIAN", "Bayesian search"]];
            } else {
                $scope.uiState.hyperparamSearchStrategies =  [["GRID", "Grid search"]];
            }
        }
    });

    // Deep watch (third argument `true`) on the time-ordering settings.
    $scope.$watch('mlTaskDesign.time', function(nv, ov){
        // Propagate changes of `mlTaskDesign.time` object to:
        //   - split params
        //   - per feature
        //   - grid search params
        // Be careful to propagate only if actual change in order not to dirtify the mlTaskDesign object for nothing
        if(nv && ov && nv !== ov) {
            // Split column / direction are only propagated for single-dataset splits.
            const splitSingleDataset = $scope.mlTaskDesign.splitParams.ttPolicy === "SPLIT_SINGLE_DATASET";
            if (nv.timeVariable && nv.timeVariable !== ov.timeVariable) {
                // New time variable: rows where it is missing get dropped, and it becomes the split column.
                let featureData = $scope.mlTaskDesign.preprocessing.per_feature[nv.timeVariable];
                featureData.missing_handling = "DROP_ROW";
                featureData.autoReason = null;
                if (splitSingleDataset) {
                    $scope.mlTaskDesign.splitParams.ssdColumn = nv.timeVariable;
                }
            }

            // Mirror the ordering direction into the split's testOnLargerValues flag.
            if (nv.ascending !== ov.ascending && splitSingleDataset) {
                $scope.mlTaskDesign.splitParams.testOnLargerValues = nv.ascending;
            }

            if (nv.enabled !== ov.enabled) {
                if (nv.enabled) {
                    // Time ordering turned on: sorted split, time-series variants of the
                    // cross-validation modes, and no k-fold split.
                    $scope.mlTaskDesign.splitParams.ssdSplitMode = "SORTED";
                    $scope.uiState.splitMethodDesc = "Based on time variable";
                    switch ($scope.mlTaskDesign.modeling.gridSearchParams.mode) {
                        case "KFOLD":
                            $scope.mlTaskDesign.modeling.gridSearchParams.mode = "TIME_SERIES_KFOLD";
                            break;
                        case "SHUFFLE":
                            $scope.mlTaskDesign.modeling.gridSearchParams.mode = "TIME_SERIES_SINGLE_SPLIT";
                            break;
                        default:
                            break;
                    }
                    $scope.mlTaskDesign.splitParams.kfold = false;
                } else {
                    // Time ordering turned off: random split, plain cross-validation modes, and the
                    // time variable (and the split column, if any) is cleared.
                    $scope.mlTaskDesign.splitParams.ssdSplitMode = "RANDOM";
                    $scope.uiState.splitMethodDesc = "Randomly";
                    switch ($scope.mlTaskDesign.modeling.gridSearchParams.mode) {
                        case "TIME_SERIES_KFOLD":
                            $scope.mlTaskDesign.modeling.gridSearchParams.mode = "KFOLD";
                            break;
                        case "TIME_SERIES_SINGLE_SPLIT":
                            $scope.mlTaskDesign.modeling.gridSearchParams.mode = "SHUFFLE";
                            break;
                        default:
                            break;
                    }
                    $scope.mlTaskDesign.time.timeVariable = null;
                    if ($scope.mlTaskDesign.splitParams.ssdColumn) {
                        $scope.mlTaskDesign.splitParams.ssdColumn = null;
                    }
                }
            }
        }
    }, true);
});

// Modal confirming a change of a basic parameter of a prediction task. $scope.paramKey names the parameter:
// prediction type, target, time series identifiers, time variable, time steps, forecast horizon or
// treatment variable. The modal computes which settings the change would lose (the lose* flags) and
// defines confirm(redetect). confirm calls the matching "reguess" endpoint, applies the returned design
// and saves it.
app.controller("PMLChangeBasicParamsModal", function($scope, $stateParams, DataikuAPI, StringUtils, Logger, AlgorithmsSettingsService) {
    // Set-up runs once, as soon as paramKey is available: the callback deregisters the watch at its end.
    // An unknown paramKey logs an error and returns before deregistering.
    const deregister = $scope.$watch("paramKey", function(nv) {
        if (!nv) return;

        // Defaults: keeping the current settings is recommended; time-step changes override this below.
        $scope.recommendRedetect = false;
        $scope.recommendKeepSettings = true;
        $scope.redetectedSettings = "all your design settings";

        // User-facing name of the changed parameter and the backend reguess endpoint for it.
        let reguessMethod;
        switch($scope.paramKey) {
        case "predictionType":
            $scope.changedParamName = "prediction type";
            reguessMethod = DataikuAPI.analysis.pml.reguessWithType;
            break;
        case "targetVariable":
            $scope.changedParamName = $scope.isCausalPrediction() ? "outcome" : "target";
            reguessMethod = DataikuAPI.analysis.pml.reguessWithTarget;
            break;
        case "timeseriesIdentifiers":
            $scope.changedParamName = "time series identifiers";
            reguessMethod = DataikuAPI.analysis.pml.reguessWithTimeseriesIdentifiers;
            break;
        case "timeVariable":
            $scope.changedParamName = "time variable";
            reguessMethod = DataikuAPI.analysis.pml.reguessWithTimeVariable;
            break;
        case "timestepParams":
            $scope.changedParamName = "time steps";
            $scope.recommendRedetect = true;
            $scope.recommendKeepSettings = false;
            $scope.redetectedSettings = "the season lengths (for algorithms Seasonal naive, AutoARIMA, Seasonal trend)";
            // Time steps and forecast horizon share one endpoint; the argument not being changed is null.
            reguessMethod = function(projectKey, analysisId, mlTaskId, changedParam, redetect) {
                return DataikuAPI.analysis.pml.reguessWithTimestepParams(projectKey, analysisId, mlTaskId, changedParam, null, redetect);
            }
            break;
        case "predictionLength":
            $scope.changedParamName = "forecast horizon";
            $scope.recommendRedetect = true;
            $scope.recommendKeepSettings = false;
            $scope.redetectedSettings = "the context lengths (for algorithms NPTS, Simple Feed Forward, DeepAR, Transformer, MQ-CNN)";
            // Same endpoint as time steps, with the horizon passed as the second parameter slot.
            reguessMethod = function(projectKey, analysisId, mlTaskId, changedParam, redetect) {
                return DataikuAPI.analysis.pml.reguessWithTimestepParams(projectKey, analysisId, mlTaskId, null, changedParam, redetect);
            }
            break;
        case "treatmentVariable":
            $scope.changedParamName = "treatment variable";
            reguessMethod = DataikuAPI.analysis.pml.reguessWithTreatmentVariable;
            break;
        default:
            Logger.error("Wrong prediction parameter key: " + $scope.paramKey);
            return;
        }

        // New value selected in the UI.
        const changedParam = $scope.uiState[$scope.paramKey];

        // Changing the target or the treatment suggests renaming the task after the new column.
        // StringUtils.transmogrify presumably makes the name unique among this analysis' ML task names.
        $scope.shouldRenameMLTask = $scope.paramKey === "targetVariable" || $scope.paramKey === "treatmentVariable";
        if ($scope.shouldRenameMLTask) {
            let suggestedName;
            if ($scope.isTimeseriesPrediction()) {
                suggestedName = "Forecast " + changedParam;
            } else if ($scope.isCausalPrediction()) {
                if ($scope.paramKey === "targetVariable") {
                    suggestedName = `Predict effect of ${$scope.mlTaskDesign.treatmentVariable} on ` + changedParam;
                } else if ($scope.paramKey === "treatmentVariable") {
                    suggestedName = `Predict effect of ${changedParam} on ${$scope.mlTaskDesign.targetVariable}`;
                }
            } else {
                suggestedName = "Predict " + changedParam;
            }
            $scope.newName = StringUtils.transmogrify(suggestedName, $scope.mlTasksContext.analysisMLTasks.map(_ => _.name));
        }
        // Flags for the settings the modal warns about losing.
        const classicalPredictionTypes = ['BINARY_CLASSIFICATION', 'MULTICLASS', 'REGRESSION'];
        $scope.loseMetrics = classicalPredictionTypes.includes($scope.mlTaskDesign.predictionType)
            && $scope.paramKey === "predictionType";

        $scope.loseAssertions = classicalPredictionTypes.includes($scope.mlTaskDesign.predictionType)
            && ($scope.paramKey === "predictionType" || $scope.paramKey === "targetVariable");

        $scope.loseOverrides = classicalPredictionTypes.includes($scope.mlTaskDesign.predictionType)
            && ($scope.paramKey === "predictionType" || $scope.paramKey === "targetVariable");

        // Algorithm settings are flagged when switching between regression and non-regression types
        // on non-KERAS backends; a KERAS architecture is always flagged.
        $scope.loseAlgo = $scope.paramKey === "predictionType"
            && $scope.mlTaskDesign.backendType !== "KERAS"
            && ["CAUSAL_REGRESSION", "REGRESSION"].includes(changedParam) !== $scope.isRegression();
        $scope.loseArchitecture = $scope.mlTaskDesign.backendType === "KERAS";

        // Weighting is flagged when the new target is the sample weight column, or, on a prediction type
        // change, when class weights are used and the algorithm settings are flagged too.
        if ($scope.paramKey === "targetVariable" && $scope.mlTaskDesign.weight) {
            $scope.loseWeight = $scope.mlTaskDesign.weight.sampleWeightVariable === changedParam;
        }
        if ($scope.paramKey === "predictionType") {
            $scope.loseWeight = $scope.mlTaskDesign.weight && ["CLASS_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"].includes($scope.mlTaskDesign.weight.weightMethod)
                && $scope.loseAlgo;
        }
        // The gap size is flagged when it exceeds the new forecast horizon minus one.
        $scope.loseGapSize = $scope.paramKey === "predictionLength" && ($scope.mlTaskDesign.evaluationParams.gapSize > changedParam - 1);

        // Flags choosing the current treatment variable as the new target.
        $scope.swapTreatmentAndOutcome = $scope.paramKey === "targetVariable" && $scope.mlTaskDesign.treatmentVariable === changedParam;

        $scope.loseSomeSettings = $scope.loseMetrics || $scope.loseAlgo || $scope.loseWeight || $scope.loseArchitecture || $scope.loseAssertions
            || $scope.loseOverrides || $scope.loseGapSize || $scope.swapTreatmentAndOutcome;

        // Applies the change: reguesses the design server-side (redetect is forwarded to the endpoint), then
        // applies the returned design, renames the task if suggested, refreshes the Keras build code (KERAS only),
        // algorithms, selected algorithm and split UI, and finally saves and dismisses the modal.
        // NOTE(review): a promise rejection handler receives a single argument, so status/headers are likely
        // undefined below — confirm setErrorInScope handles a response object.
        $scope.confirm = function(redetect) {
            reguessMethod($stateParams.projectKey, $stateParams.analysisId,
                $stateParams.mlTaskId, changedParam, redetect).then(function(response){
                    $scope.setMlTaskDesign(response.data);
                    if ($scope.shouldRenameMLTask) {
                        $scope.mlTaskDesign.name = $scope.newName;
                    }

                    if ($scope.mlTaskDesign.backendType === "KERAS") {
                        $scope.fillBuildCodeKeras(true);
                    }
                    $scope.setAlgorithms($scope.mlTaskDesign);
                    $scope.setSelectedAlgorithm(AlgorithmsSettingsService.getDefaultAlgorithm(
                        $scope.mlTaskDesign,
                        $scope.algorithms[$scope.mlTaskDesign.backendType]
                    ));
                    $scope.fillUISplitParams($scope.mlTaskDesign.splitParams);
                    $scope.saveSettings().then($scope.dismiss);

            }, function(data, status, headers) {
                setErrorInScope.bind($scope)(data, status, headers);
            });
        };

        // On close, reset the UI value from the design's value (mapped through getUIStateParam when provided),
        // so the UI shows the design's actual value whether or not the change was confirmed; then notify
        // onCloseCallback if one is set.
        $scope.$on("$destroy", function() {
            const getUIStateParam = $scope.getUIStateParam || (param => param);
            $scope.uiState[$scope.paramKey] = getUIStateParam($scope.mlTaskDesign[$scope.paramKey]);
            $scope.onCloseCallback && $scope.onCloseCallback();
        });

        deregister();
    });
});


// Modal shown before training. It loads the pre-train status (split status and messages) and exposes
// the GPU-related state and helpers.
app.controller("PMLTaskPreTrainModal", function ($scope, $stateParams, $controller, $filter, DataikuAPI, GPU_SUPPORTING_CAPABILITY, GpuUsageService) {
    $controller('_PMLTaskWithK8sContainerInformationController', { $scope });
    $scope.GPU_SUPPORTING_CAPABILITY = GPU_SUPPORTING_CAPABILITY;

    $scope.uiState = {
        confirmRun: false,
        gpu: {
            name: getGpuTitle(), // function declaration further below (hoisted)
        }
    };

    /** Fetches the pre-train status and derives the anyError / anyWarning flags from its messages. */
    $scope.getPreTrainStatus = function() {
        DataikuAPI.analysis.pml.getPreTrainStatus($stateParams.projectKey, $stateParams.analysisId, $stateParams.mlTaskId).success(function(data) {
            $scope.preTrainStatus = data;
            $scope.splitStatus = data.splitStatus;
            $scope.uiState.anyError = data.messages.some(message => message.severity === 'ERROR');
            $scope.uiState.anyWarning = data.messages.some(message => message.severity === 'WARNING');
        }).error(setErrorInScope.bind($scope));
    };

    /** Saves the design settings, then refreshes the pre-train status. */
    $scope.updatePreTrainStatus = function() {
        $scope.saveSettings().then(() => {
            $scope.getPreTrainStatus();
        });
    };

    $scope.getPreTrainStatus();

    /** Starts the training, then resolves the modal once the start call succeeds. */
    $scope._doTrainThenResolveModal = function() {
        $scope._doTrain().then($scope.resolveModal);
    };

    /**
     * True for k-fold splits and for time series prediction.
     * splitStatus only exists once getPreTrainStatus has answered. Until then only the time series
     * check applies, instead of throwing on the undefined splitStatus.
     */
    $scope.useExtracts = function () {
        return ($scope.splitStatus && $scope.splitStatus.isKFold) || $scope.isTimeseriesPrediction();
    };

    /** Training is disabled for a GPU-capable task when GPU use is requested but no GPU is selected. */
    $scope.shouldDisableTrain = function(gpuCapability) {
        return gpuCapability && $scope.mlTaskDesign.gpuConfig.params.useGpu && !$scope.mlTaskDesign.gpuConfig.params.gpuList.length;
    };


    $scope.availableGpuCapabilities = $scope.getAvailableGpuCapabilities(); // everything possible given the task
    $scope.usedGpuCapabilities = $scope.getUsedGpuCapabilities(); // everything actually present on the task (e.g. is text embedding in use, is xgboost on)
    $scope.usedGpuCapabilitiesWithGpuOn = $scope.getUsedGpuCapabilitiesWithGpuOn(); // everything present that isn't explicitly disabled in runtime config

    $scope.showGpuCapabilityWarning = function() {
        // we warn when gpu is enabled, and there are capabilities that can use a gpu present on the task, but everything has actually been disabled
        return $scope.mlTaskDesign.gpuConfig.params.useGpu && $scope.usedGpuCapabilitiesWithGpuOn.length === 0 && $scope.usedGpuCapabilities.length > 0;
    };

    // Label of the GPU option: 'Activate GPU', followed by ' for ' and the capability names joined by the
    // andList filter when some capabilities would use the GPU.
    function getGpuTitle() {
        const capabilities = $scope.getUsedGpuCapabilitiesWithGpuOn();
        if (capabilities.length > 0) {
            return 'Activate GPU for ' + $filter('andList')(getGpuCapabilityActivityNames(capabilities));
        }
        return 'Activate GPU';
    }

    // Display names of the given GPU capabilities, as declared in GpuUsageService.CAPABILITIES.
    function getGpuCapabilityActivityNames(gpuCapabilities) {
        return Array.from(gpuCapabilities, capability => GpuUsageService.CAPABILITIES[capability].name);
    }
});

app.controller("_ClassicalPMLTaskPreTrainBase", function($scope, $stateParams, $controller, WT1, Logger, DataikuAPI) {
    $controller("_TabularPMLTaskPreTrainBase", { $scope });

    $scope._doTrain = function () {
        try {
            const algorithms = {};
            $.each($scope.mlTaskDesign.modeling, function (alg, params) {
                if (params.enabled) {
                    algorithms[alg] = params;
                }
            });

            // Adding custom py algorithms
            $.each($scope.mlTaskDesign.modeling.custom_python, function(algNum, params) {
                if (params.enabled) {
                    algorithms["CUSTOM_PYTHON_" + algNum] = params;
                }
            });

            // Adding custom mllib algorithms
            $.each($scope.mlTaskDesign.modeling.custom_mllib, function(algNum, params) {
                if (params.enabled) {
                    algorithms["CUSTOM_MLLIB_" + algNum] = params;
                }
            });

            // Adding plugin algorithms
            $.each($scope.mlTaskDesign.modeling.plugin_python, function(alg, params) {
                if (params.enabled) {
                    algorithms[alg] = params;
                }
            });

            function redactSensitiveInformation(eventContent) {
                const redacted = dkuDeepCopy(eventContent, $scope.SettingsService.noDollarKey); // don't want to delete actual values in scope
                if (redacted.metrics) {
                    if (redacted.metrics.customMetrics) {
                        redacted.metrics.customMetrics.forEach(item => {
                            delete item.name;
                            delete item.description;
                            delete item.metricCode;
                        });
                    }
                }

                if (redacted.algorithms) {
                    if (redacted.algorithms.keras) {
                        delete redacted.algorithms.keras.buildCode;
                        delete redacted.algorithms.keras.fitCode;
                        delete redacted.algorithms.keras.kerasInputs;
                    }

                    for (const algorithm in redacted.algorithms) {
                        if (algorithm.includes('CUSTOM_PYTHON_')) {
                            delete redacted.algorithms[algorithm].name;
                            delete redacted.algorithms[algorithm].code;
                        } else if (algorithm.includes('CUSTOM_MLLIB_')) {
                            delete redacted.algorithms[algorithm].initializationCode;
                        }
                    }
                }

                if (redacted.feature_selection_params && redacted.feature_selection_params.custom_params) {
                    delete redacted.feature_selection_params.custom_params.code;
                }

                if (redacted.feature_generation &&
                    redacted.feature_generation.manual_interactions &&
                    redacted.feature_generation.manual_interactions.interactions) {
                    for (const interaction of redacted.feature_generation.manual_interactions.interactions) {
                        delete interaction.column_1;
                        delete interaction.column_2;
                    }
                }
                return redacted;
            }

            let wt1Content = redactSensitiveInformation({
                 backendType: $scope.mlTaskDesign.backendType,
                 taskType: $scope.mlTaskDesign.taskType,
                 predictionType: $scope.mlTaskDesign.predictionType,
                 guessPolicy: $scope.mlTaskDesign.guessPolicy,
                 feature_generation: $scope.mlTaskDesign.preprocessing.feature_generation,
                 feature_selection_params: $scope.mlTaskDesign.preprocessing.feature_selection_params,
                 algorithms: algorithms,
                 metrics: $scope.mlTaskDesign.modeling.metrics,
                 weightMethod: $scope.mlTaskDesign.weight.weightMethod,
                 hasSessionName: !!$scope.uiState.userSessionName,
                 hasSessionDescription: !!$scope.uiState.userSessionDescription,
                 calibrationMethod: $scope.mlTaskDesign.calibration.calibrationMethod,
                 hasTimeOrdering: $scope.mlTaskDesign.time.enabled,
                 gridSearchParams: $scope.mlTaskDesign.modeling.gridSearchParams,
                 runsOnKubernetes: $scope.hasSelectedK8sContainer(),
                 assertionsParams: aggregateAssertionsParams(),
                 overridesParams: aggregateOverridesParams(),
                 monotonicConstraintParams: aggregateMonotonicConstraintParams()
            });
            WT1.event("prediction-train", wt1Content);
        } catch (e) {
            Logger.error('Failed to report mltask info', e);
        }
        return DataikuAPI.analysis.pml.trainStart($stateParams.projectKey, $stateParams.analysisId, $stateParams.mlTaskId,
            $scope.uiState.userSessionName, $scope.uiState.userSessionDescription, $scope.uiState.forceRefresh, true).error(setErrorInScope.bind($scope));
    };

    /**
     * Summarizes filter-based params (assertions or overrides) for telemetry.
     *
     * @param {Array} params - entries exposing `filter.uiData.mode`; null/undefined is treated as an empty list
     *                         so that a missing list does not abort the whole telemetry event
     * @param {string} modesKey - key under which the per-mode occurrence counts are stored
     * @returns {Object} `{count: <number of entries>, [modesKey]: {<mode>: <occurrences>}}`
     */
    function aggregateConditionsParams(params, modesKey) {
        const conditions = params || [];
        const modes = {};
        for (const condition of conditions) {
            const mode = condition.filter.uiData.mode;
            modes[mode] = (modes[mode] || 0) + 1;
        }
        return {
            "count": conditions.length,
            [modesKey]: modes
        };
    }

    // Telemetry summary of the ML task's assertions. assertionsParams may be absent on the design,
    // in which case zero assertions are reported instead of throwing.
    function aggregateAssertionsParams() {
        const assertionsParams = $scope.mlTaskDesign.assertionsParams;
        const assertions = (assertionsParams && assertionsParams.assertions) || [];
        return aggregateConditionsParams(assertions, "assertionsConditionsModes");
    }

    // Telemetry summary of the ML task's overrides. A design without overridesParams is reported
    // as having zero overrides instead of throwing.
    function aggregateOverridesParams() {
        const overridesParams = $scope.mlTaskDesign.overridesParams;
        const overrides = (overridesParams && overridesParams.overrides) || [];
        return aggregateConditionsParams(overrides, "overridesConditionsModes");
    }

    // Counts INPUT features carrying an INCREASE or DECREASE monotonic constraint (for telemetry)
    function aggregateMonotonicConstraintParams() {
        const perFeature = $scope.mlTaskDesign.preprocessing.per_feature;
        let increaseCount = 0;
        let decreaseCount = 0;
        for (const columnName of Object.keys(perFeature)) {
            const featureParams = perFeature[columnName];
            if (featureParams.role !== "INPUT") {
                continue;
            }
            if (featureParams.monotonic === "INCREASE") {
                increaseCount++;
            } else if (featureParams.monotonic === "DECREASE") {
                decreaseCount++;
            }
        }
        return {
            "total": increaseCount + decreaseCount,
            "increase": increaseCount,
            "decrease": decreaseCount
        };
    }

    // Truthy when the pre-train modal has something to show: backend messages, CUDA-enabled XGBoost,
    // or selected algorithms that are incompatible with sample weights
    $scope.displayMessages = function() {
        const messageCount = $scope.preTrainStatus.messages.length;
        if (messageCount) {
            return messageCount;
        }
        const xgboostCuda = $scope.mlTaskDesign.modeling.xgboost.enable_cuda;
        if (xgboostCuda) {
            return xgboostCuda;
        }
        return $scope.uiState.selectedAlgorithmsWithWeightIncompatibility.length;
    };
});

app.controller("_TabularPMLTaskPreTrainBase", function($scope, $stateParams, DataikuAPI) {
    // Queues a training session for this ML task, then closes the modal and refreshes the queued sessions list
    $scope.enqueue = function() {
        const { projectKey, analysisId, mlTaskId } = $stateParams;
        const { userSessionName, userSessionDescription, forceRefresh } = $scope.uiState;
        const notCurrentlyTraining = !$scope.mlTaskStatus.training;

        DataikuAPI.analysis.pml.enqueueSession(projectKey, analysisId, mlTaskId, notCurrentlyTraining,
                userSessionName, userSessionDescription, forceRefresh)
            .success(() => {
                $scope.resolveModal();
                $scope.listQueuedSessions();
            })
            .error(setErrorInScope.bind($scope));
    };

    // Label for the kind of model being trained ("model" / "partitioned model"), optionally pluralized
    $scope.getModelStr = (pluralize) => {
        const isPartitioned = $scope.preTrainStatus && $scope.preTrainStatus.partitionedModelEnabled;
        const singular = isPartitioned ? "partitioned model" : "model";
        return pluralize ? `${singular}s` : singular;
    };
});

app.controller("ClassicalPMLTaskPreTrainModal", function($scope, $state, $controller) {
    $controller("PMLTaskPreTrainModal", { $scope });
    $controller("_ClassicalPMLTaskPreTrainBase", { $scope });

    /**
     * Starts training from the pre-train modal.
     * - Non-Keras backends: save settings first only if they are dirty (e.g. GPU config changed), then train.
     * - Keras: always save first, then validate the recipe unless `skipPrerunValidate` is set. If validation
     *   reports anything above OK, navigate to the Keras build screen with the pre-run validation error
     *   displayed and dismiss the modal instead of training.
     */
    $scope.train = function () {
        if (!$scope.isMLBackendType("KERAS")) {
            // check whether gpu config has changed
            if ($scope.dirtySettings()) {
                $scope.saveSettings().then(() => $scope._doTrainThenResolveModal());
            } else {
                $scope._doTrainThenResolveModal();
            }
            return;
        }

        // For Keras the settings can be updated from inside the modal (if you activate the GPU training), so we save
        // first before training
        $scope.saveSettings().then(function() {
            if ($scope.recipe.params && $scope.recipe.params.skipPrerunValidate) {
                $scope._doTrainThenResolveModal();
            } else {
                $scope.validateRecipe().then(function(validationResult) {
                    if (!validationResult.topLevelMessages || !validationResult.topLevelMessages.maxSeverity || validationResult.topLevelMessages.maxSeverity === 'OK') {
                        $scope._doTrainThenResolveModal();
                    } else {
                        $state.go('projects.project.analyses.analysis.ml.predmltask.list.design.classical-keras-build',
                            {
                                "projectKey": $scope.projectSummary.projectKey,
                                "analysisId": $scope.analysisCoreParams.id,
                                "mlTaskId": $scope.mlTaskDesign.id
                            }).then(() => {
                                $scope.valCtx.showPreRunValidationError = true;
                                $scope.dismiss();
                            });
                    }
                });
            }
        });
    };

    // Names of enabled algorithms (built-in or plugin) that do not support sample weights.
    // Each algorithm is listed at most once, even if it is enabled both as built-in and as plugin.
    $scope.uiState.selectedAlgorithmsWithWeightIncompatibility = [];
    if ($scope.isSampleWeightEnabled()) {
        $scope.base_algorithms[$scope.mlTaskDesign.backendType].forEach(function(x) {
            const unsupportedAlgo = $scope.algosWithoutWeightSupport && $scope.algosWithoutWeightSupport.has(x.algKey);
            if (!unsupportedAlgo) {
                return;
            }

            const modeling = $scope.mlTaskDesign.modeling;
            const enabledAsBuiltin = modeling[x.algKey] && modeling[x.algKey].enabled;
            // Looking at plugin algorithms as well
            const enabledAsPlugin = modeling["plugin_python"]
                && modeling["plugin_python"][x.algKey]
                && modeling["plugin_python"][x.algKey].enabled;

            if (enabledAsBuiltin || enabledAsPlugin) {
                $scope.uiState.selectedAlgorithmsWithWeightIncompatibility.push(x.name);
            }
        });
    }
});

app.controller("KerasPMLTaskPreTrain", function ($scope, $controller) {
    $controller("_ClassicalPMLTaskPreTrainBase", { $scope });

    // Trains directly via _doTrain, then refreshes the ML task sessions once it resolves
    $scope.forceKerasTrain = function () {
        const refreshSessions = () => {
            $scope.refreshMLTaskSessions();
        };
        $scope._doTrain().then(refreshSessions);
    };
});

app.directive("mlParamWithFilterCard", function () {
    // Isolated-scope card for one ML parameter entry that carries a filter; inner content is transcluded
    const isolatedScopeBindings = {
        it: '=',
        delete: '&',
        recipeVariables: '<',
        postScriptFeaturesSchema: '<',
        sortable: '<'
    };

    return {
        scope: isolatedScopeBindings,
        transclude: true,
        templateUrl: '/templates/analysis/mlcommon/settings/ml-param-with-filter-card.html'
    };
});

app.component("mlRangeCondition", {
    bindings: {
        minValue: '=',
        maxValue: '=',
    },
    templateUrl: '/templates/analysis/mlcommon/settings/ml-range-condition.html',
    controller: function() {
        const ctrl = this;

        // The range is only reported invalid once both bounds are defined and min is not strictly below max
        ctrl.isInvalidMinMax = function() {
            const bothBoundsDefined = !angular.isUndefined(ctrl.minValue) && !angular.isUndefined(ctrl.maxValue);
            return bothBoundsDefined && ctrl.minValue >= ctrl.maxValue;
        };
    }
});

// Template-only component for editing an assertion's condition; the controller holds no logic
app.component("mlAssertionCondition", {
    bindings: {
        condition: '=',
        isClassification: '&',
        classes: '<',
        percentage: '<'
    },
    templateUrl: '/templates/analysis/mlcommon/settings/ml-assertion-condition.html',
    controller: function() {}
});

app.component("mlOverrideCondition", {
    bindings: {
        outcome: '=',
        targetVariable: '<',
        isClassification: '<',
        classes: '<'
    },
    templateUrl: '/templates/analysis/mlcommon/settings/ml-override-condition.html',
    controller: function() {
        const ctrl = this;

        // Outcome policies offered to the user: "Declined" is always available, plus an enforced
        // category (classification) or an enforced interval (otherwise). Computed once at init.
        ctrl.$onInit = function() {
            const availablePolicies = { "DECLINED": "Declined" };
            if (ctrl.isClassification) {
                availablePolicies["CATEGORY"] = "Enforced Category";
            } else {
                availablePolicies["INTERVAL"] = "Enforced Interval";
            }
            ctrl.policies = availablePolicies;
        };
    }
});

app.service("MlParamsWithFilterService", function (DataikuAPI, StringUtils, Logger) {
    const svc = this;

    /**
     * Prepares `scope` for editing filter-based ML params:
     *  - scope.isClassification(): truthy unless the ML task is a regression,
     *  - scope.postScriptFeaturesSchema (set asynchronously): post-script schema without the target column,
     *    with `extraColumns` prepended (see enrichPostScriptSchemaWithExtraColumns),
     *  - scope.classes: the `sourceValue`s of the target remapping.
     */
    svc.prepare = function (scope, projectKey, analysisId, extraColumns=[]) {
        scope.isClassification = () => {
            return scope.mlTaskDesign && scope.mlTaskDesign.predictionType !== "REGRESSION";
        };

        DataikuAPI.analysis.getPostScriptSchema(projectKey, analysisId).success(function (data) {
            scope.postScriptFeaturesSchema = data;
            const columns = scope.postScriptFeaturesSchema.columns;
            const targetVariable = scope.mlTaskDesign.targetVariable;
            const targetIndex = columns.findIndex(column => column.name === targetVariable);
            // Only remove the target when it is actually present: splice(-1, 1) would drop the last feature
            if (targetIndex !== -1) {
                columns.splice(targetIndex, 1);
            }
            svc.enrichPostScriptSchemaWithExtraColumns(scope.postScriptFeaturesSchema, extraColumns);
        });
        scope.classes = scope.mlTaskDesign.preprocessing.target_remapping.map(clss => clss.sourceValue);
    };

    /**
     * Replaces `postScriptSchema.columns` with `extraColumns` followed by the original columns, each copied
     * with a `$$groupKey` ('Computed' for extra columns, 'Features' for the original ones) used for grouping.
     * No-op when `extraColumns` is empty or missing.
     */
    svc.enrichPostScriptSchemaWithExtraColumns = function(postScriptSchema, extraColumns) {
        if (!extraColumns || !extraColumns.length) {
            return;
        }
        // If we have extra columns:
        //  * We put them at the beginning of the list to highlight them
        //  * We create a distinction between them and the original features by adding a group field used by group by
        const group_computed = 'Computed';
        const group_features = 'Features';
        const extraColumnsWithGroup = extraColumns.map(col => ({...col, $$groupKey: group_computed}));
        const featuresWithGroup = postScriptSchema.columns.map(col => ({...col, $$groupKey: group_features}));
        // The concatenation order already puts computed columns first, so no further sorting is needed
        postScriptSchema.columns = extraColumnsWithGroup.concat(featuresWithGroup);
    };

    // Requests a detailed analysis of the target column (at most 50 categories) on the analysis input dataset
    svc.columnAnalysis = function (scope, projectKey) {
        return DataikuAPI.shakers.detailedColumnAnalysis(
            projectKey,
            scope.analysisCoreParams.projectKey,
            scope.analysisCoreParams.inputDatasetSmartName,
            scope.analysisCoreParams.script,
            null,
            scope.mlTaskDesign.targetVariable,
            50
        );
    };

    /**
     * Builds a callback that appends a new, uniquely named param (`<paramName> <n>`) with an enabled filter
     * to `paramsWithFilter`, initialised by the classification or regression setup callback.
     * For regression, the target min/max are fetched once up front (falling back to 0/100 on failure);
     * params added before that request completes use the fallback values.
     */
    svc.createAddNewParamWithFilterCallback = function(paramName, paramsWithFilter, isClassification, targetColumnAnalysisCallback, setupDefaultClassifParamCallback, setupDefaultRegressionParamCallback) {

        // Put arbitrary values by default if callback fails to execute
        let minValueTarget = 0;
        let maxValueTarget = 100;
        if (!isClassification) {
            targetColumnAnalysisCallback().success(function(data) {
                minValueTarget = data.numericalAnalysis.min;
                maxValueTarget = data.numericalAnalysis.max;
            }).error(function(error) {
                Logger.error("Failed to compute min and max of target variable. Putting arbitrary values for override", error);
            });
        }
        return function() {
            const newParamWithFilter = {
                filter: {"enabled": true},
                name: StringUtils.transmogrify(`${paramName} ${paramsWithFilter.length + 1}`,
                    paramsWithFilter.map(a => a.name),
                    function (i) {
                        return `${paramName} ${i + 1}`;
                    }),
                outcome: {}
            };

            if (isClassification) {
                setupDefaultClassifParamCallback(newParamWithFilter);
            } else {
                setupDefaultRegressionParamCallback(newParamWithFilter, minValueTarget, maxValueTarget);
            }
            paramsWithFilter.push(newParamWithFilter);
        };
    };

    // Specialisation for overrides: enforced first class (classification) or rounded target [min, max] interval (regression)
    svc.createAddNewOverrideParamCallback = function(overrides, isClassification, targetColumnAnalysisCallback, classes) {
        return svc.createAddNewParamWithFilterCallback(
            "Override",
            overrides,
            isClassification,
            targetColumnAnalysisCallback,
            (newOverride) => {
                newOverride.outcome = {
                    type: "CATEGORY",
                    category: classes[0]
                };
            },
            (newOverride, minValueTarget, maxValueTarget) => {
                newOverride.outcome = {
                    type: "INTERVAL",
                    minValue: Math.round(minValueTarget),
                    maxValue: Math.round(maxValueTarget)
                };
            }
        );
    };
});


app.controller("PMLTaskAssertionsController", function ($scope, StringUtils, $stateParams, MlParamsWithFilterService) {
    MlParamsWithFilterService.prepare($scope, $stateParams.projectKey, $stateParams.analysisId);

    // expectedValidRatio given to every newly created assertion
    const defaultExpectedValidRatio = 0.9;

    // New classification assertions expect the first known class
    function setupDefaultClassificationAssertion(newAssertion) {
        newAssertion.assertionCondition = {
            expectedValidRatio: defaultExpectedValidRatio,
            expectedClass: $scope.classes[0]
        };
    }

    // New regression assertions expect the (rounded) observed target range
    function setupDefaultRegressionAssertion(newAssertion, minValueTarget, maxValueTarget) {
        newAssertion.assertionCondition = {
            expectedValidRatio: defaultExpectedValidRatio,
            expectedMinValue: Math.round(minValueTarget),
            expectedMaxValue: Math.round(maxValueTarget)
        };
    }

    const existingAssertions = $scope.mlTaskDesign.assertionsParams && $scope.mlTaskDesign.assertionsParams.assertions;

    $scope.addNewMlAssertion = MlParamsWithFilterService.createAddNewParamWithFilterCallback(
        "Assertion",
        existingAssertions,
        $scope.isClassification(),
        () => MlParamsWithFilterService.columnAnalysis($scope, $stateParams.projectKey),
        setupDefaultClassificationAssertion,
        setupDefaultRegressionAssertion
    );
});

// Static info box about model overrides; receives whether the task is a classification as a one-way binding
app.component("modelOverridesInfoBox", {
    templateUrl: "/templates/analysis/prediction/settings/model-overrides-info-box.html",
    bindings: {
        isClassification: "<"
    }
});

app.controller("PMLModelOverridesController", function ($scope, $stateParams, MlParamsWithFilterService, OverridesExtraColumnsService) {

    // In binary classification the prediction cannot be used in an override rule: a rule on the prediction can
    // almost always be rewritten without it (reverting the prediction being the only exception, and a pointless
    // one), while allowing it would make the results considerably harder to interpret.

    const overridesExtraColumns = OverridesExtraColumnsService.getExtraColumnsFromMLTaskDesign($scope.mlTaskDesign);
    MlParamsWithFilterService.prepare($scope, $stateParams.projectKey, $stateParams.analysisId, overridesExtraColumns);

    // User-facing reason why overrides cannot be used on this ML task, or undefined when they can
    $scope.getOverridesNotAvailableReason = function () {
        const design = $scope.mlTaskDesign;
        const isPartitioned = typeof design.partitionedModel !== 'undefined' && design.partitionedModel.enabled;
        if (isPartitioned) {
            return "Overrides are not available for partitioned models";
        }
        if (design.taskType === 'CLUSTERING') {
            return "Overrides are not available for clustering";
        }
        if ($scope.isMLBackendType('KERAS')) {
            return "Overrides are not available for KERAS models";
        }
        if ($scope.isMLBackendType('DEEP_HUB')) {
            return "Overrides are not available for computer vision models";
        }
        return undefined;
    };

    $scope.addNewMlOverride = MlParamsWithFilterService.createAddNewOverrideParamCallback(
        $scope.mlTaskDesign.overridesParams.overrides,
        $scope.isClassification(),
        () => MlParamsWithFilterService.columnAnalysis($scope, $stateParams.projectKey),
        $scope.classes
    );
});

app.controller('_K8sConfigurationCheckerController', ($scope, $stateParams, DataikuAPI) => {
    // Filled asynchronously; until the response arrives, no container is considered a Kubernetes one
    let k8sContainerNames = [];
    let defaultContainerName = null;

    const onContainerNamesLoaded = (data) => {
        k8sContainerNames = data.containerNames;
        defaultContainerName = data.resolvedInheritValue;
    };

    DataikuAPI.containers.listNamesWithDefault($stateParams.projectKey, 'KUBERNETES', "USER_CODE")
        .success(onContainerNamesLoaded)
        .error(setErrorInScope.bind($scope));

    // Only these backend types are checked for a Kubernetes container; any other backend yields false
    const eligibleBackendTypes = ['PY_MEMORY', 'KERAS'];

    // True when the container selection (explicit, or inherited default) is one of the Kubernetes containers
    $scope.isK8sContainer = (backendType, containerSelection) => {
        if (!eligibleBackendTypes.includes(backendType)) {
            return false;
        }
        const mode = containerSelection.containerMode;
        if (mode === 'EXPLICIT_CONTAINER') {
            return k8sContainerNames.includes(containerSelection.containerConf);
        }
        if (mode === 'INHERIT') {
            return k8sContainerNames.includes(defaultContainerName);
        }
        return false;
    };
});

app.controller("_PMLTaskWithK8sContainerInformationController", ($scope, $controller) => {
    $controller("_K8sConfigurationCheckerController", { $scope });

    // Whether the ML task design's backend and container selection point to a Kubernetes container
    $scope.hasSelectedK8sContainer = () => {
        const design = $scope.mlTaskDesign;
        return $scope.isK8sContainer(design.backendType, design.containerSelection);
    };
});

app.controller("PMLTaskHyperparametersController", ($scope, $controller, DataikuAPI, $stateParams,
    VisualMlCodeEnvCompatibility, Debounce) => {
    $controller("PMLTaskCrossvalController", { $scope });
    $controller("_PMLTaskWithK8sContainerInformationController", { $scope });

    // Starter code inserted when custom cross-validation is selected and no code has been written yet
    const customCvCodeTemplate = [
        "# Define an object named cv that follows the scikit-learn splitter protocol",
        "# This example uses the 'repeated K-fold' splitter of scikit-learn",
        "from sklearn.model_selection import RepeatedKFold",
        "",
        "cv = RepeatedKFold(n_splits=3, n_repeats=5)"
    ].join("\n");

    $scope.$watch("mlTaskDesign.modeling.gridSearchParams.mode", (newMode) => {
        if (newMode !== "CUSTOM") {
            return;
        }
        const gridSearchParams = $scope.mlTaskDesign.modeling.gridSearchParams;
        if (!gridSearchParams.code) {
            gridSearchParams.code = customCvCodeTemplate;
        }
    });

    // Debounced re-check of whether HP search is needed, ignoring null/undefined constraint values
    $scope.onChangeSearchConstraints = Debounce().withDelay(0, 400).wrap(function(searchConstraint) {
        if (searchConstraint == null) {
            return;
        }
        $scope.checkIfHPSearchIsNeeded();
    });
});


app.component("searchStrategyWarnings", {
    bindings: {
        "mlTaskDesign": "<",
        "codeEnvsCompat": "<",
        "runtimeEnvironmentUiSref": "@"
    },
    templateUrl: "/templates/analysis/prediction/settings/search-strategy-warnings.html",
    controller: function($scope, $rootScope, $stateParams, DataikuAPI, VisualMlCodeEnvCompatibility) {
        const ctrl = this;
        ctrl.showBayesianSearchWarning = showBayesianSearchWarning;

        let previousStrategy = null;
        // Incremented for every backend request; only the response to the latest request is applied,
        // so that quick successive strategy changes cannot leave a stale list on screen.
        let latestRequestId = 0;

        ctrl.$onInit = () => {
            previousStrategy = getCurrentStrategy();
            ctrl.algosIncompatibleWithSearchStrategy = [];
            retrieveAlgosIncompatibleWithSearchStrategy();
        };

        // Re-fetch the incompatible algorithms whenever the search strategy changes
        ctrl.$doCheck = () => {
            const currentStrategy = getCurrentStrategy();
            if (!angular.equals(currentStrategy, previousStrategy)) {
                retrieveAlgosIncompatibleWithSearchStrategy();
                previousStrategy = currentStrategy;
            }
        };

        ///////////////////////////

        // Current HP search strategy, or undefined while the design (or its grid search params) is not available yet
        function getCurrentStrategy() {
            const design = ctrl.mlTaskDesign;
            if (!design || !design.modeling || !design.modeling.gridSearchParams) {
                return undefined;
            }
            return design.modeling.gridSearchParams.strategy;
        }

        // True on non-automation nodes when Bayesian search is selected but the selected code env does not support it
        function showBayesianSearchWarning() {
            //This is safe because the check is redone frequently, so we can tolerate skipping while initializing ctrl
            if (!ctrl.mlTaskDesign || !ctrl.mlTaskDesign.envSelection || !ctrl.codeEnvsCompat) return false;

            const envCompat = VisualMlCodeEnvCompatibility.getCodeEnvCompat(ctrl.mlTaskDesign.envSelection, ctrl.codeEnvsCompat);
            const isCodeEnvCompatibleWithBayesian = envCompat && envCompat.bayesianSearch && envCompat.bayesianSearch.compatible;

            return !$rootScope.appConfig.isAutomation && getCurrentStrategy() === 'BAYESIAN' && !isCodeEnvCompatibleWithBayesian;
        }

        function retrieveAlgosIncompatibleWithSearchStrategy() {
            // No need to check the backend type, because in this controller, we assume that hyperparameter search is available for the current backend
            if (!ctrl.mlTaskDesign) {
                return;
            }
            const requestId = ++latestRequestId;
            DataikuAPI.analysis.pml.listAlgosIncompatibleWithSearchStrategy($stateParams.projectKey, ctrl.mlTaskDesign)
                .then((response) => {
                    if (requestId !== latestRequestId) {
                        return; // superseded by a more recent request
                    }
                    ctrl.algosIncompatibleWithSearchStrategy = Object.entries(response.data)
                        .map(([name, strategy]) => ({ 'name': name, 'strategy': strategy }));
                })
                .catch(setErrorInScope.bind($scope));
        }
    }
});

// Deep-hub prediction type -> DSS-internal code env type (passed to codeenvs.checkDSSInternalCodeEnv)
app.constant("DEEPHUB_PREDICTION_TYPE_CODE_ENV_TYPE", {
    DEEP_HUB_IMAGE_OBJECT_DETECTION: "DEEP_HUB_IMAGE_OBJECT_DETECTION_CODE_ENV",
    DEEP_HUB_IMAGE_CLASSIFICATION: "DEEP_HUB_IMAGE_CLASSIFICATION_CODE_ENV",
});

app.controller("PMLTaskRuntimeController", ($scope, $controller, DataikuAPI, DEEPHUB_PREDICTION_TYPE_CODE_ENV_TYPE) => {
    $controller("_PMLTaskWithK8sContainerInformationController", { $scope });

    // When the container selection changes, keep distributed HP search on only if the container is a Kubernetes one
    function updateHpSearchDistribution(newSelection, oldSelection) {
        if (angular.equals(newSelection, oldSelection)) {
            return;
        }
        const gridSearchParams = $scope.mlTaskDesign.modeling.gridSearchParams;
        if (!gridSearchParams) {
            return; // deephub doesn't support HP search
        }
        gridSearchParams.distributed = gridSearchParams.distributed && $scope.hasSelectedK8sContainer();
    }

    $scope.$watch('mlTaskDesign.containerSelection', updateHpSearchDistribution, true);

    // Installs on `scope` the helpers used to warn when a deep-hub task does not use the DSS-internal code env
    // matching its prediction type; for deep-hub tasks the code env lookup is started immediately.
    $scope.checkDeepHubCodeEnvCallback = function (scope) {
        scope.deepHubCodeEnv = undefined;
        scope.isDeepHubCodeEnvAvailable = () => scope.deepHubCodeEnv !== undefined;

        const readablePredictionTypes = new Map([
            ["DEEP_HUB_IMAGE_OBJECT_DETECTION", "object detection"],
            ["DEEP_HUB_IMAGE_CLASSIFICATION", "image classification"]
        ]);
        scope.getHumanReadablePredictionType = function () {
            const predictionType = $scope.mlTaskDesign.predictionType;
            return readablePredictionTypes.has(predictionType) ? readablePredictionTypes.get(predictionType) : predictionType;
        };

        function checkDeepHubCodeEnv(codeEnvVersion) {
            const codeEnvType = DEEPHUB_PREDICTION_TYPE_CODE_ENV_TYPE[$scope.mlTaskDesign.predictionType];
            DataikuAPI.codeenvs.checkDSSInternalCodeEnv(codeEnvType, codeEnvVersion)
                .then(({ data }) => {
                    if (Object.keys(data).length === 0) {
                        return; // no internal code env found
                    }
                    scope.deepHubCodeEnv = data.value;
                })
                .catch(setErrorInScope.bind(scope));
        }

        const isDeepHubTask = $scope.mlTaskDesign.predictionType.startsWith('DEEP_HUB_');

        // Warn unless the task explicitly selects the available internal deep-hub code env
        scope.showDeepHubCodeEnvWarning = function () {
            if (!isDeepHubTask) {
                return false;
            }
            if (scope.envSelection.envMode !== "EXPLICIT_ENV") {
                return true;
            }
            if (!scope.isDeepHubCodeEnvAvailable()) {
                return true;
            }
            return scope.envSelection.envName !== scope.deepHubCodeEnv.envName;
        };

        if (isDeepHubTask) {
            checkDeepHubCodeEnv();
        }
    };

    $scope.availableGpuCapabilities = $scope.getAvailableGpuCapabilities(); // everything possible given the task
    $scope.usedGpuCapabilities = $scope.getUsedGpuCapabilities(); // everything actually present on the task (e.g. is text embedding in use, is xgboost on)
    $scope.usedGpuCapabilitiesWithGpuOn = $scope.getUsedGpuCapabilitiesWithGpuOn(); // everything present that isn't explicitly disabled in runtime config

    $scope.refreshEnabledGpuCapabilities = function() {
        $scope.usedGpuCapabilitiesWithGpuOn = $scope.getUsedGpuCapabilitiesWithGpuOn();
    };

    // Whether the Prophet time series algorithm is enabled in the design (false when the design is not loaded)
    $scope.isProphetEnabled = function() {
        const modeling = $scope.mlTaskDesign && $scope.mlTaskDesign.modeling;
        if (!modeling) {
            return false;
        }
        return modeling.prophet_timeseries && modeling.prophet_timeseries.enabled;
    };
});

app.controller("PMLTaskCrossvalController", function($scope, $stateParams, DataikuAPI, DatasetUtils, Dialogs, SamplingData, TimeseriesForecastingUtils, $timeout, TimeseriesForecastingCustomTrainTestFoldsUtils){
    // Location of the analysis input dataset, used to load it (for its schema) and to list its partitions
    const datasetLoc = DatasetUtils.getLocFromSmart($stateParams.projectKey, $scope.analysisCoreParams.inputDatasetSmartName);
    DataikuAPI.datasets.get(datasetLoc.projectKey, datasetLoc.name, $stateParams.projectKey).success(function (data) {
        $scope.analysisDataset = data;
    }).error(setErrorInScope.bind($scope));

    const resamplingParams = $scope.mlTaskDesign.preprocessing.timeseriesSampling;

    // Copy the UI custom start/end dates into the design, formatted without timezone (deferred via $timeout)
    $scope.propagateCustomStartDateUpdate = function() {
        $timeout(function() {
            $scope.mlTaskDesign.preprocessing.timeseriesSampling.customStartDate = $scope.customStartDate ? $scope.formatDateWithoutTimezone($scope.customStartDate) : undefined;
        });
    };
    $scope.propagateCustomEndDateUpdate = function() {
        $timeout(function () {
            $scope.mlTaskDesign.preprocessing.timeseriesSampling.customEndDate = $scope.customEndDate ? $scope.formatDateWithoutTimezone($scope.customEndDate) : undefined;
        });
    };

    // Initialise the UI dates from the design (as UTC dates), defaulting to now
    if (resamplingParams != null) {
        $scope.customStartDate = resamplingParams.customStartDate && TimeseriesForecastingCustomTrainTestFoldsUtils.forceConvertToUTCTimezoneDate(resamplingParams.customStartDate) || new Date();
        $scope.customEndDate = resamplingParams.customEndDate && TimeseriesForecastingCustomTrainTestFoldsUtils.forceConvertToUTCTimezoneDate(resamplingParams.customEndDate) || new Date();
    }

    $scope.getPartitionsList = function () {
        return DataikuAPI.datasets.listPartitionsWithName(datasetLoc.projectKey, datasetLoc.name)
            .error(setErrorInScope.bind($scope))
            .then(function (ret) {
                return ret.data;
            });
    };

    $scope.checkIfHPSearchIsNeeded = function() {
        DataikuAPI.analysis.pml.isHPSearchNeeded($stateParams.projectKey, $scope.mlTaskDesign).then(function(response) {
            $scope.isHPSearchNeeded = response.data.isHPSearchNeeded;
        }).catch(setErrorInScope.bind($scope));
    };

    $scope.checkIfHPSearchIsNeeded();

    // Time series forecasting where both the train/test split and the HP search use K-fold
    $scope.isTimeseriesForecastWithBothKFold = function () {
        if (!$scope.isTimeseriesPrediction()) return false;
        return $scope.isHPSearchNeeded
            && !$scope.mlTaskDesign.customTrainTestSplit
            && $scope.mlTaskDesign.splitParams.kfold
            && $scope.mlTaskDesign.modeling.gridSearchParams.mode === "TIME_SERIES_KFOLD";
    };

    $scope.prettyTimeSteps = TimeseriesForecastingUtils.prettyTimeSteps;

    // Prefill
    $scope.$watch("uiSplitParams.policy", function(nv, ov) {
        if (!nv) return;
        if (nv === "EXPLICIT_FILTERING_TWO_DATASETS") {
            if (!$scope.mlTaskDesign.splitParams.eftdTrain) {
                $scope.mlTaskDesign.splitParams.eftdTrain = {
                    datasetSmartName : $scope.analysisCoreParams.inputDatasetSmartName,
                    selection : DatasetUtils.makeHeadSelection(100000)
                };
            }
            if (!$scope.mlTaskDesign.splitParams.eftdTest) {
                $scope.mlTaskDesign.splitParams.eftdTest = {
                    selection : DatasetUtils.makeHeadSelection(100000)
                };
            }
        } else if (nv.startsWith("EXPLICIT_FILTERING_SINGLE_DATASET")) {
            if (!$scope.mlTaskDesign.splitParams.efsdTrain) {
                $scope.mlTaskDesign.splitParams.efsdTrain = {
                    selection : DatasetUtils.makeHeadSelection(100000)
                };
            }
            if (!$scope.mlTaskDesign.splitParams.efsdTest) {
                $scope.mlTaskDesign.splitParams.efsdTest = {
                    selection : DatasetUtils.makeHeadSelection(100000)
                };
            }
        } else if (nv === "SPLIT_OTHER_DATASET") {
            if (!$scope.mlTaskDesign.splitParams.ssdDatasetSmartName) {
                $scope.mlTaskDesign.splitParams.ssdDatasetSmartName = $scope.analysisCoreParams.inputDatasetSmartName;
            }
        }
        // Partitioned models only support SPLIT_MAIN_DATASET: ask the user which of the two settings to give up
        if (nv !== "SPLIT_MAIN_DATASET" && $scope.mlTaskDesign.partitionedModel && $scope.mlTaskDesign.partitionedModel.enabled) {
            const selectedPolicy = $scope.trainTestPolicies.find(policy => policy[0] === nv);
            const choices = [
                { revert: false, title: "Disable partitioning & keep this policy",
                    // Fall back to the raw policy code when it has no entry in trainTestPolicies
                    desc: selectedPolicy ? selectedPolicy[1] : nv },
                { revert: true, title: "Keep partitioning & revert policy",
                    desc: $scope.trainTestPolicies[0][1] }
            ];
            const act = (choice) => {
                if (choice.revert) {
                    $scope.uiSplitParams.policy = 'SPLIT_MAIN_DATASET';
                } else {
                    $scope.mlTaskDesign.partitionedModel.enabled = false;
                }
            };
            Dialogs.select($scope, "Change train/test policy",
                "Model partitioning is enabled, but not compatible with this policy.",
                choices, choices[0]
            ).then(act, () => act(choices[1])); // dismiss => revert policy
        }
    });

    $scope.canCustomizeResamplingDates = function() {
        // keep in sync with Resampler._can_customize_resampling_dates()
        return $scope.mlTaskDesign.preprocessing.timeseriesSampling.numericalExtrapolateMethod !== 'NO_EXTRAPOLATION';
    };

    DatasetUtils.listDatasetsUsabilityForAny($stateParams.projectKey).success(function (data) {
        $scope.availableDatasets = data;
        $scope.availableDatasetsExceptForInputDataset = $scope.availableDatasets.filter(function(d) {
            return d.smartName !== $scope.analysisCoreParams.inputDatasetSmartName;
        });
        data.forEach(function (ds) {
            ds.usable = true;
        });
    });

    // Candidate time-ordering columns: input dataset columns that have per-feature params and are not the target.
    // Returns undefined until the input dataset has been loaded.
    $scope.potentialTimeFeatures = function() {
        const per_feature = $scope.mlTaskDesign.preprocessing.per_feature;
        if ($scope.analysisDataset) {
            // Sort and split are done before Script is applied so only input columns can be time features
            const inputColumns = $scope.analysisDataset.schema.columns.map(col => col.name);
            return inputColumns.filter(col => (per_feature[col] && per_feature[col].role !== "TARGET"));
        }
    };

    $scope.getSamplingMethodLabel = function() {
        return SamplingData.getSamplingMethodForDocumentation($scope.mlTaskDesign.splitParams.ssdSelection);
    };

    $scope.getCrossValidationLabel = function() {
        const gridSearchMode = $scope.mlTaskDesign.modeling.gridSearchParams.mode;
        if ($scope.mlTaskDesign.time && $scope.mlTaskDesign.time.enabled) {
            return $scope.getCrossvalModesWithTimeForDocumentation(gridSearchMode, $scope.mlTaskDesign);
        }
        return $scope.getCrossvalModesRandomForDocumentation(gridSearchMode, $scope.mlTaskDesign);
    };

    $scope.getHyperparametersBarsMaxWidth = function() {
        return $scope.mlTaskDesign.splitParams.kfold ? 1 : $scope.mlTaskDesign.splitParams.ssdTrainingRatio;
    };

    // Group K-fold candidates: every feature except the target (and the treatment, for causal prediction)
    $scope.deferredAfterInitMlTaskDesign.then(() => {
        $scope.groupKFoldColumnNames = Object.keys($scope.mlTaskDesign.preprocessing.per_feature).filter(function(name) {
            if ($scope.isCausalPrediction() && name === $scope.mlTaskDesign.treatmentVariable) {
                return false;
            }
            return $scope.mlTaskDesign.targetVariable !== name;
        });
    });
});

app.controller("PMLTaskFeatureSelectionController", function($scope, $controller, $timeout, $stateParams, DataikuAPI, Dialogs){
    // [code, label] pairs for the feature reduction methods offered in the UI
    $scope.featureSelectionKinds = [
        ["NONE", "No reduction"],
        ["CORRELATION", "Correlation with target"],
        ["RANDOM_FOREST", "Tree-based"],
        ["PCA", "Principal Component Analysis"],
    ];

    // LASSO is only offered for non-multiclass tasks on the PY_MEMORY backend
    const isLassoOffered = !$scope.isMulticlass()
        && $scope.mlTasksContext
        && $scope.mlTasksContext.activeMLTask
        && $scope.mlTasksContext.activeMLTask.backendType === "PY_MEMORY";
    if (isLassoOffered) {
        $scope.featureSelectionKinds.push(["LASSO", "LASSO regression"]);
    }

    // Signal to Puppeteer that the content of the element has been loaded and is thus available for content extraction
    $scope.puppeteerHook_elementContentLoaded = true;
});

app.controller("_PMLTargetRemappingController", function($scope, Assert, Fn) {
    Assert.inScope($scope, 'mlTaskDesign');

    /**
     * Recomputes $scope.graphData as [sourceValue, sampleFreq / total sampleFreq] pairs from the target remapping.
     * The sum starts at 0 so that an empty remapping produces an empty graph: reduce() without an initial
     * value throws on an empty array, which previously left a stale graph in place.
     */
    function updateGraph() {
        try {
            const remapping = $scope.mlTaskDesign.preprocessing.target_remapping;
            const totalCount = remapping.map(Fn.prop("sampleFreq")).reduce(Fn.SUM, 0);

            $scope.graphData = remapping.map(function(x){
                return [x.sourceValue, x.sampleFreq / totalCount];
            });
        } catch (e) { /* e.g. no target remapping yet: keep the current graph */ }
    }

    $scope.$watch('mlTaskDesign.preprocessing.target_remapping', updateGraph, false); // shallow, for the "re-detect settings" case

    // Edit session for the mapping: `value` is a working copy that is only written back on OK
    $scope.editMapping = {
        active: false,
        value : null
    };
    $scope.startEditMapping = function(){
        $scope.editMapping.active = true;
        // Work on a copy so that cancelling really discards in-place edits
        $scope.editMapping.value = angular.copy($scope.mlTaskDesign.preprocessing.target_remapping);
    };
    $scope.$watch("editMapping.value", function(nv, ov) {
        $scope.editMapping.error = !nv;
    }, true);
    $scope.okEditMapping = function(){
        $scope.editMapping.active = false;
        $scope.mlTaskDesign.preprocessing.target_remapping = $scope.editMapping.value;
        updateGraph();
    };
    $scope.cancelEditMapping = function(){
        $scope.editMapping.active = false;
    };

    updateGraph();

    $scope.hasManyCategories = function(){
        return $scope.mlTaskDesign.preprocessing.target_remapping.length >= 50;
    };

});

// Tabular flavour of the target remapping view: picks the graph colors, then delegates
// everything else to _PMLTargetRemappingController on the same scope.
app.controller("TabularPMLTargetRemappingController", function($scope, $controller) {
    // Adjacent colors of the first discrete palette are too similar to tell apart,
    // so only the entries at even indices (0, 2, 4, ...) are kept.
    const paletteColors = window.dkuColorPalettes.discrete[0].colors;
    $scope.colors = paletteColors.filter((color, index) => index % 2 === 0);

    $controller("_PMLTargetRemappingController", { $scope });
});

// Exposes $scope.sparkConfs, the names of the Spark execution configurations defined in the
// general settings. Starts as ['default'] and stays so if the request never succeeds,
// since no error handler is attached.
app.controller("PMLSparkConfigController", function($scope, Assert, DataikuAPI, Fn) {
    Assert.inScope($scope, 'mlTaskDesign');

    $scope.sparkConfs = ['default'];

    function onGeneralSettingsLoaded(settings) {
        const executionConfigs = settings.sparkSettings.executionConfigs;
        $scope.sparkConfs = executionConfigs.map(Fn.prop('name'));
    }

    DataikuAPI.admin.getGeneralSettings().success(onGeneralSettingsLoaded);
});

// Attribute directive: when the host element is destroyed, flags the session task's
// TensorBoard frontend as not ready and asks for it to be shown again once it is.
app.directive('tensorboardDestroyHandler', function () {
    function link($scope, elem) {
        elem.on('$destroy', function() {
            const tensorboardStatus = $scope.sessionTask.tensorboardStatus;
            tensorboardStatus.isFrontendReady = false;
            tensorboardStatus.showIfFrontIsNotReady = true;
        });
    }

    return { link: link };
});

// Selects the LightGBM hyperparameter space to edit: the regression space for regression and
// time series predictions, the classification space for everything else.
app.controller('LightGBMHyperparametersController', function($scope) {
    const modeling = $scope.mlTaskDesign.modeling;
    const usesRegressionSpace = $scope.isRegression() || $scope.isTimeseriesPrediction();
    $scope.hpSpace = usesRegressionSpace ? modeling.lightgbm_regression : modeling.lightgbm_classification;
});

// Selects the deep neural network hyperparameter space to edit: regression space for
// regression tasks, classification space otherwise.
app.controller('DeepNeuralNetworkHyperparametersController', function($scope) {
    const modeling = $scope.mlTaskDesign.modeling;
    if ($scope.isRegression()) {
        $scope.hpSpace = modeling.deep_neural_network_regression;
    } else {
        $scope.hpSpace = modeling.deep_neural_network_classification;
    }
});


// Picks the list of evaluation metrics offered for the current prediction type. The
// predicates are checked in order and the first match wins; with no match the list stays empty.
app.controller('DesignEvaluationMetricController', function($scope) {
    $scope.baseEvaluationMetrics = [];

    const metricsKeyByPredictionType = [
        [() => $scope.isBinaryClassification(), 'bcEvaluationMetrics'],
        [() => $scope.isMulticlass(), 'mcEvaluationMetrics'],
        [() => $scope.isRegression(), 'regressionEvaluationMetrics'],
        [() => $scope.isTimeseriesPrediction(), 'timeseriesEvaluationMetrics'],
    ];
    const matchingEntry = metricsKeyByPredictionType.find(([isCurrentType]) => isCurrentType());

    if (matchingEntry) {
        $scope.baseEvaluationMetrics = $scope[matchingEntry[1]];
    }
});

})();
