""" Commands available from the doctor main kernel server.

To add a command, simply add a function to this module.
Functions whose name starts with an _ are not exposed.

Arguments with default values are supported.
*args and **kwargs are not supported.

If one of your JSON parameters has the same name as a Python global, you
can suffix the parameter name with an _ (e.g. input_)
"""

import inspect
import json
import logging
import sys
import time

import numpy as np
from sklearn.model_selection import train_test_split

from dataiku.base import remoterun
from dataiku.base.folder_context import build_folder_context
from dataiku.base.utils import safe_unicode_str
from dataiku.core import dkujson
from dataiku.core import doctor_constants
from dataiku.core import intercom
from dataiku.core.doctor_constants import TARGET_VARIABLE, TIME_VARIABLE
from dataiku.core.percentage_progress import PercentageProgress
from dataiku.core.saved_model import get_source_dss_version
from dataiku.doctor import step_constants
from dataiku.doctor import utils
from dataiku.doctor.clustering.clustering_scorer import ClusteringModelScorer
from dataiku.doctor.clustering_entrypoints import clustering_train_score_save
from dataiku.doctor.deep_learning.keras_utils import tag_special_features
from dataiku.doctor.diagnostics import default_diagnostics
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.learning_curve import minimal_train_get_train_test_score
from dataiku.doctor.notebook_builder import ClusteringNotebookBuilder
from dataiku.doctor.notebook_builder import PredictionNotebookBuilder
from dataiku.doctor.posttraining import global_explanations
from dataiku.doctor.posttraining import individual_explanations
from dataiku.doctor.posttraining import partial_depency as pdp
from dataiku.doctor.posttraining import subpopulation as subpopulation
from dataiku.doctor.prediction.classification_fit import classification_fit
from dataiku.doctor.prediction.classification_scoring import BinaryClassificationModelScorer
from dataiku.doctor.prediction.classification_scoring import ClassificationModelIntrinsicScorer
from dataiku.doctor.prediction.classification_scoring import MulticlassModelScorer
from dataiku.doctor.prediction.classification_scoring import \
    compute_assertions_and_overrides_for_classification_from_clf
from dataiku.doctor.prediction.common import check_classical_prediction_type
from dataiku.doctor.prediction.common import get_monotonic_cst
from dataiku.doctor.prediction.common import needs_hyperparameter_search
from dataiku.doctor.prediction.common import prepare_multiframe
from dataiku.doctor.prediction.common import regridify_optimized_params
from dataiku.doctor.prediction.common import PredictionAlgorithmNaNSupport
from dataiku.doctor.prediction.metric import MAPE
from dataiku.doctor.utils.gpu_execution import log_nvidia_smi_if_use_gpu
from dataiku.doctor.utils.model_io import to_pkl, from_pkl
from dataiku.doctor.prediction.decisions_and_cuts import DecisionsAndCuts
from dataiku.doctor.prediction.overrides.ml_overrides_params import MlOverridesParams
from dataiku.doctor.prediction.overrides.ml_overrides_params import ml_overrides_params_from_model_folder
from dataiku.doctor.prediction.prediction_model_serialization import ModelSerializer
from dataiku.doctor.prediction.regression_fit import regression_fit_single
from dataiku.doctor.prediction.regression_scoring import RegressionModelIntrinsicScorer
from dataiku.doctor.prediction.regression_scoring import RegressionModelScorer
from dataiku.doctor.prediction.regression_scoring import compute_assertions_and_overrides_for_regression_from_clf
from dataiku.doctor.prediction.scorable_model import ScorableModel
from dataiku.doctor.prediction_entrypoints import prediction_train_model_keras
from dataiku.doctor.prediction_entrypoints import prediction_train_model_kfold
from dataiku.doctor.prediction_entrypoints import prediction_train_score_save
from dataiku.doctor.preprocessing_collector import ClusteringPreprocessingDataCollector
from dataiku.doctor.preprocessing_collector import PredictionPreprocessingDataCollector
from dataiku.doctor.preprocessing_handler import ClusteringPreprocessingHandler
from dataiku.doctor.preprocessing_handler import PredictionPreprocessingHandler
from dataiku.doctor.utils import unix_time_millis
from dataiku.doctor.utils.listener import ModelStatusContext
from dataiku.doctor.utils.listener import ProgressListener
from dataiku.doctor.utils.model_io import load_model_from_folder
from dataiku.doctor.utils.split import sort_dataframe, check_train_test_order
from dataiku.doctor.utils.split import df_from_split_desc
from dataiku.doctor.utils.split import df_from_split_desc_no_normalization
from dataiku.doctor.utils.split import load_test_set
from dataiku.doctor.utils.split import load_train_set
from dataikuscoring.utils.prediction_result import ClassificationPredictionResult
from dataikuscoring.utils.prediction_result import PredictionResult

# Module-level progress listener.
# NOTE(review): the training commands visible in this part of the file create their own
# local `preprocessing_listener`, which shadows this one; presumably this instance is
# used by commands defined further down -- confirm.
preprocessing_listener = ProgressListener()
# NOTE(review): not referenced in the reviewed part of the file; presumably shared state
# for commands defined further down -- confirm.
global_modeling_sets = []

# Logger shared by every command of this module
logger = logging.getLogger(__name__)

if sys.version_info >= (3,):
    # timeseries package does not support 2.7
    from dataiku.doctor.timeseries.train.training_handler import resample_for_training
    from dataiku.doctor.timeseries.models import TimeseriesForecastingAlgorithm
    from dataiku.doctor.timeseries.preparation.preprocessing import TimeseriesPreprocessing, get_windows_list, \
        add_rolling_windows_for_training
    from dataiku.doctor.timeseries.posttraining.permutation_importance import compute_post_train_permutation_importance

def _list_commands():
    """Return the ``(name, function)`` pairs that this module exposes as commands.

    A module attribute qualifies when its name has no leading underscore, it is
    a plain Python function, and ``inspect.getmodule`` resolves it to this very
    module (functions that were merely imported here are therefore left out).
    """
    this_module = sys.modules[__name__]
    commands = []
    for name, candidate in this_module.__dict__.items():
        if name.startswith("_"):
            continue
        if not inspect.isfunction(candidate):
            continue
        if inspect.getmodule(candidate) == this_module:
            commands.append((name, candidate))
    return commands


def create_prediction_notebook(model_name, model_date, dataset_smartname,
                               script, preparation_output_schema,
                               split_stuff,
                               core_params,
                               preprocessing_params,
                               pre_train, post_train, kernel_name, kernel_display_name):
    """Build a prediction notebook and return it serialized as a JSON string.

    The prediction type found in ``core_params`` is first validated with
    ``check_classical_prediction_type``. Only the ``"steps"`` entry of ``script``
    is handed over to the notebook builder; every other argument is forwarded
    unchanged.
    """
    check_classical_prediction_type(core_params["prediction_type"])
    builder = PredictionNotebookBuilder(
        model_name, model_date, dataset_smartname,
        script["steps"], preparation_output_schema,
        split_stuff,
        core_params,
        preprocessing_params,
        pre_train, post_train, kernel_name, kernel_display_name
    )
    notebook = builder.create_notebook()
    return json.dumps(notebook)


def create_clustering_notebook(model_name, model_date, dataset_smartname,
                               script, preparation_output_schema,
                               split_stuff,
                               preprocessing_params,
                               pre_train, post_train, kernel_name, kernel_display_name):
    """Build a clustering notebook and return it serialized as a JSON string.

    Only the ``"steps"`` entry of ``script`` is handed over to the notebook
    builder; every other argument is forwarded unchanged.
    """
    builder = ClusteringNotebookBuilder(
        model_name, model_date, dataset_smartname,
        script["steps"], preparation_output_schema,
        split_stuff,
        preprocessing_params,
        pre_train, post_train, kernel_name, kernel_display_name
    )
    notebook = builder.create_notebook()
    return json.dumps(notebook)


def train_prediction_kfold(core_params, preprocessing_set, split_desc, ml_overrides_params):
    """Train the modeling sets of a preprocessing set in k-fold cross-test mode.

    The preprocessing pipeline is fitted once on the "full" split. Then, for each
    modeling set:
      - a global model is fitted on the whole preprocessed data
      - the model and its ``actual_params.json`` are saved in the model folder
      - the intrinsic scoring is run (plus assertions / overrides metrics when relevant)
      - the optimized params are regridified to a unary grid by the backend and
        ``prediction_train_model_kfold`` is run with them
      - the "done" train info is written

    :param dict core_params: core params of the ML task. Reads ``taskType``, ``prediction_type`` and the
                             optional ``time``, ``weight`` and ``calibration`` sections
    :param dict preprocessing_set: reads ``preprocessing_params``, ``modelingSets``, the optional
                                   ``assertionsParams`` and the folders used by ``build_folder_contexts``.
                                   It is modified in place (folder contexts and listeners are added to it)
    :param split_desc: description of the split, forwarded to the dataset loading helpers
    :param dict ml_overrides_params: its ``value`` entry is parsed with ``MlOverridesParams.from_dict``
    :return: the string ``"ok"`` once all the modeling sets have been trained
    :raises ValueError: if time ordering is enabled in ``core_params``
    """
    log_nvidia_smi_if_use_gpu(core_params=core_params)
    # A missing "time" section simply means that time ordering is disabled
    if core_params.get("time", {}).get("enabled", False):
        raise ValueError("Training with k-fold cross-test is not compatible with time ordering of data")
    ml_overrides_params = MlOverridesParams.from_dict(ml_overrides_params['value'])
    model_type = core_params["taskType"]
    prediction_type = core_params["prediction_type"]
    check_classical_prediction_type(prediction_type)
    build_folder_contexts(preprocessing_set)
    default_diagnostics.register_prediction_callbacks(core_params)
    start = unix_time_millis()
    preprocessing_params = preprocessing_set['preprocessing_params']
    modeling_sets = preprocessing_set["modelingSets"]
    assertions_params_list = preprocessing_set.get("assertionsParams", {}).get("assertions", None)

    logger.info("PPS is %s" % preprocessing_params)
    split_folder_context = preprocessing_set["split_folder_context"]
    preprocessing_folder_context = preprocessing_set["preprocessing_folder_context"]
    # Local listener: it shadows the module-level `preprocessing_listener`
    preprocessing_listener = ProgressListener()
    preprocessing_listener.add_future_steps(step_constants.PRED_KFOLD_PREPROCESSING_STEPS)
    nan_support = PredictionAlgorithmNaNSupport(modeling_sets[0]["modelingParams"], preprocessing_params)

    # Register the future steps of every model before doing any actual work
    for modeling_set in modeling_sets:
        listener = preprocessing_listener.new_child(ModelStatusContext(modeling_set["model_folder_context"], start))
        listener.add_future_steps(step_constants.PRED_KFOLD_TRAIN_STEPS)
        modeling_set["listener"] = listener

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_LOADING_SRC):
        full_df = load_train_set(core_params, preprocessing_params, split_desc, "full", split_folder_context,
                                 assertions=assertions_params_list)

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING):
        collector = PredictionPreprocessingDataCollector(full_df, preprocessing_params)
        collector_data = collector.build()

        pipeline, preproc_handler = build_pipeline_and_handler(collector_data, core_params, preprocessing_folder_context,
                                                               preprocessing_params, assertions=assertions_params_list,
                                                               nan_support=nan_support)

    with preprocessing_listener.push_step(step_constants.ProcessingStep.KFOLD_STEP_PREPROCESS_GLOBAL):
        transformed_full = pipeline.fit_and_process(full_df)
        diagnostics.on_kfold_step_preprocess_global_end(multiframe=transformed_full['TRAIN'])
        preproc_handler.save_data()
        preproc_handler.report(pipeline)

    preprocessing_listener.save_status()
    preprocessing_end = unix_time_millis()

    train_X = transformed_full["TRAIN"]
    train_y = transformed_full["target"]

    weight_method = core_params.get("weight", {}).get("weightMethod", None)
    with_sample_weight = weight_method in {"SAMPLE_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"}
    with_class_weight = weight_method in {"CLASS_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"}
    if with_sample_weight:
        assert transformed_full["weight"].values.min() > 0, "Sample weights must be positive"

    calibration_method = core_params.get("calibration", {}).get("calibrationMethod")
    calibration_ratio = core_params.get("calibration", {}).get("calibrationDataRatio", doctor_constants.DEFAULT_CALIBRATION_DATA_RATIO)
    calibrate_proba = calibration_method in [doctor_constants.SIGMOID, doctor_constants.ISOTONIC]

    monotonic_cst = get_monotonic_cst(preprocessing_params, train_X)

    for modeling_set in modeling_sets:
        model_start = unix_time_millis()
        listener = modeling_set["listener"]
        previous_search_time = utils.get_hyperparams_search_time_traininfo(modeling_set["model_folder_context"])
        # Reset for each model: only computed when there are assertions or overrides
        assertions_metrics = None
        overrides_metrics = None
        if core_params["prediction_type"] in (doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS):
            with listener.push_step(step_constants.ProcessingStep.KFOLD_STEP_FITTING_GLOBAL, previous_duration=previous_search_time):
                # no out-fold available, so calibrate through classification_fit on a random split
                (clf, actual_params, prepared_X, iipd) = classification_fit(modeling_set['modelingParams'],
                                                                            core_params,
                                                                            transformed_full,
                                                                            model_folder_context=modeling_set['model_folder_context'],
                                                                            target_map=preproc_handler.target_map,
                                                                            with_sample_weight=with_sample_weight,
                                                                            with_class_weight=with_class_weight,
                                                                            calibration_method=calibration_method,
                                                                            calibration_ratio=calibration_ratio,
                                                                            calibrate_on_test=False,  # Training is done on transformed_full, so no test set is left
                                                                            monotonic_cst=monotonic_cst)
                diagnostics.on_fitting_end(prediction_type=prediction_type, clf=clf, train_target=train_y, features=train_X.columns())
                model = ScorableModel.build(clf, model_type, prediction_type, modeling_set["modelingParams"]['algorithm'],
                                            preprocessing_params, ml_overrides_params)
            with listener.push_step(step_constants.ProcessingStep.STEP_SAVING):
                ModelSerializer.build(model, modeling_set["model_folder_context"], train_X.columns(), calibrate_proba).serialize()
                modeling_set['model_folder_context'].write_json("actual_params.json", actual_params)

            with listener.push_step(step_constants.ProcessingStep.KFOLD_STEP_SCORING_GLOBAL):
                logger.info("Running intrinsic scoring")
                ClassificationModelIntrinsicScorer(modeling_set['modelingParams'], clf, train_X, train_y, preproc_handler.target_map,
                                                   pipeline, modeling_set['model_folder_context'], prepared_X,
                                                   iipd, with_sample_weight, calibrate_proba).score()

                if "assertions" in transformed_full or ml_overrides_params is not None:
                    assertions_metrics, overrides_metrics = compute_assertions_and_overrides_for_classification_from_clf(clf,
                                                                                                                         model_type,
                                                                                                                         modeling_set["modelingParams"],
                                                                                                                         prediction_type,
                                                                                                                         preprocessing_params,
                                                                                                                         preproc_handler.target_map,
                                                                                                                         transformed_full,
                                                                                                                         ml_overrides_params)

        else:
            with listener.push_step(step_constants.ProcessingStep.KFOLD_STEP_FITTING_GLOBAL, previous_duration=previous_search_time):
                (clf, actual_params, prepared_X, iipd) = regression_fit_single(modeling_set['modelingParams'],
                                                                               core_params,
                                                                               transformed_full,
                                                                               model_folder_context=modeling_set["model_folder_context"],
                                                                               with_sample_weight=with_sample_weight,
                                                                               monotonic_cst=monotonic_cst)
                diagnostics.on_fitting_end(features=train_X.columns(), clf=clf, prediction_type=prediction_type, train_target=transformed_full["target"])
                model = ScorableModel.build(clf, model_type, prediction_type, modeling_set["modelingParams"]['algorithm'],
                                            preprocessing_params, ml_overrides_params)
            with listener.push_step(step_constants.ProcessingStep.STEP_SAVING):
                ModelSerializer.build(model, modeling_set["model_folder_context"], train_X.columns()).serialize()
                modeling_set['model_folder_context'].write_json("actual_params.json", actual_params)

            with listener.push_step(step_constants.ProcessingStep.KFOLD_STEP_SCORING_GLOBAL):
                logger.info("Running intrinsic scoring")
                RegressionModelIntrinsicScorer(modeling_set['modelingParams'], clf, train_X, train_y, pipeline,
                                               modeling_set['model_folder_context'], prepared_X, iipd, with_sample_weight).score()

                if "assertions" in transformed_full or ml_overrides_params is not None:
                    assertions_metrics, overrides_metrics = compute_assertions_and_overrides_for_regression_from_clf(clf, model_type, modeling_set["modelingParams"],
                                                                                                                     transformed_full, ml_overrides_params)

        logger.info("Loading full dataframe")
        full_df_clean = df_from_split_desc(split_desc, "full", split_folder_context,
                                           preprocessing_params["per_feature"], prediction_type)
        optimized_params = actual_params["resolved"]

        logger.info("Regridifying post-train params: %s" % dkujson.dumps(optimized_params))

        # Regridify to a unary grid the optimized params
        optimized_params_grid = intercom.backend_json_call("ml/prediction/regridify-to-pretrain", {
            "preTrain": dkujson.dumps(modeling_set["modelingParams"]),
            "postTrain": dkujson.dumps(optimized_params)
        })
        logger.info("Using unary grid params: %s" % dkujson.dumps(optimized_params_grid))
        prediction_train_model_kfold(full_df_clean, core_params, split_desc, preprocessing_params, modeling_set["modelingParams"],
                                     optimized_params_grid, preprocessing_folder_context,
                                     modeling_set['model_folder_context'], split_folder_context,
                                     listener, with_sample_weight, with_class_weight,
                                     transformed_full,
                                     assertions_metrics=assertions_metrics,
                                     ml_overrides_params=ml_overrides_params,
                                     overrides_metrics=overrides_metrics,
                                     monotonic_cst=monotonic_cst)

        end = unix_time_millis()
        listeners_json = preprocessing_listener.merge(listener)
        utils.write_done_traininfo(modeling_set['model_folder_context'], start, model_start, end,
                                   listeners_json,
                                   end_preprocessing_time=preprocessing_end)

    # Report success only once every modeling set has been trained: returning from inside
    # the loop would silently skip all the modeling sets after the first one
    return "ok"


def train_prediction_models_nosave(core_params, preprocessing_set, split_desc, ml_overrides_params):
    """Regular (mode 1) training.

    - A single, non-streamed split: the preprocessing is fitted on the train set
      and then applied to the test set
    - The N models are then handled one after the other:
         - Fit
         - Save clf
         - Compute and save clf performance
         - Score, save scored test set + scored performance

    Returns the string "ok" once all the modeling sets have been processed.
    """
    run_start = unix_time_millis()
    log_nvidia_smi_if_use_gpu(core_params=core_params)
    check_classical_prediction_type(core_params["prediction_type"])
    build_folder_contexts(preprocessing_set)
    ml_overrides_params = MlOverridesParams.from_dict(ml_overrides_params["value"])
    default_diagnostics.register_prediction_callbacks(core_params)
    pp_params = preprocessing_set["preprocessing_params"]
    modeling_sets = preprocessing_set["modelingSets"]
    assertions = preprocessing_set.get("assertionsParams", {}).get("assertions", None)
    nan_support = PredictionAlgorithmNaNSupport(modeling_sets[0]["modelingParams"], pp_params)

    logger.info("PPS is %s" % pp_params)
    split_context = preprocessing_set["split_folder_context"]
    pp_context = preprocessing_set["preprocessing_folder_context"]
    root_listener = ProgressListener()
    # Declare every upcoming step right away so that progress data is correct from the start
    root_listener.add_future_steps(step_constants.PRED_REGULAR_PREPROCESSING_STEPS)
    any_hyperparameter_search = False
    for mset in modeling_sets:
        child_listener = root_listener.new_child(ModelStatusContext(mset["model_folder_context"], run_start))
        if needs_hyperparameter_search(mset.get('modelingParams', {})):
            child_listener.add_future_step(step_constants.ProcessingStep.STEP_HYPERPARAMETER_SEARCHING)
            any_hyperparameter_search = True
        child_listener.add_future_steps(step_constants.PRED_REGULAR_TRAIN_STEPS)
        mset["listener"] = child_listener

    # Time ordering settings (all optional)
    time_settings = core_params.get("time", {})
    time_ordering = time_settings.get("enabled", False)
    time_variable = time_settings.get("timeVariable")
    ascending = time_settings.get("ascending", True)
    if time_ordering and time_variable is None:
        raise ValueError("Time ordering is enabled but no time variable is specified")

    with root_listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TRAIN):
        train_df = load_train_set(core_params, pp_params, split_desc, "train", split_context)

        for column in train_df:
            logger.info("Train col : %s (%s)" % (column, train_df[column].dtype))

    with root_listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TEST):
        test_df = load_test_set(core_params, pp_params, split_desc, split_context,
                                assertions=assertions)

    if time_ordering:
        check_train_test_order(train_df, test_df, time_variable, ascending)
        if any_hyperparameter_search:
            logger.info(u"Hyperparameter search enabled: checking that the train set is sorted by '{}'".format(safe_unicode_str(time_variable)))
            sort_dataframe(train_df, time_variable, ascending)

    with root_listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING):
        collector_data = PredictionPreprocessingDataCollector(train_df, pp_params).build()

        pipeline, preproc_handler = build_pipeline_and_handler(collector_data, core_params, pp_context,
                                                               pp_params,
                                                               assertions=assertions,
                                                               nan_support=nan_support)

    with root_listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_TRAIN):
        # TODO: fit_and_process should take an update_fn argument
        transformed_train = pipeline.fit_and_process(train_df)
        diagnostics.on_preprocess_train_dataset_end(multiframe=transformed_train['TRAIN'])
        preproc_handler.save_data()
        preproc_handler.report(pipeline)

    with root_listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_TEST):
        # The index is copied before the test dataframe goes through the pipeline
        test_df_index = test_df.index.copy()
        transformed_test = pipeline.process(test_df)
        # even though this is test data, field is still called 'TRAIN'
        diagnostics.on_preprocess_test_dataset_end(multiframe=transformed_test['TRAIN'])

    root_listener.save_status()
    preprocessing_end = unix_time_millis()

    for mset in modeling_sets:
        model_start = unix_time_millis()
        model_context = mset["model_folder_context"]
        model_listener = mset["listener"]

        # since ensembles are never fitted through the doctor, no need to distinguish here
        prediction_train_score_save(transformed_train,
                                    transformed_test, test_df_index,
                                    core_params, split_desc,
                                    mset["modelingParams"],
                                    model_context,
                                    pp_context,
                                    split_context,
                                    model_listener,
                                    preproc_handler.target_map,
                                    pipeline,
                                    pp_params,
                                    ml_overrides_params)

        model_end = unix_time_millis()

        utils.write_done_traininfo(mset["model_folder_context"], run_start, model_start, model_end,
                                   root_listener.merge(mset["listener"]),
                                   end_preprocessing_time=preprocessing_end)

    return "ok"


def compute_learning_curve(model_folder, preprocessing_folder, split_desc, split_folder, fmi, postcompute_folder, core_params, computation_parameters, job_id=None, modellike_folder=None):
    """Compute the learning curve of a trained model and write it to ``learning_curve.json``.

    The model is retrained on growing samples of the train set (evenly spaced sizes, the
    last one being the whole train set). For each sample, the training time and the
    train / test metrics are recorded. A point whose training fails is logged and skipped,
    it does not abort the computation.

    :param model_folder: folder of the model; ``actual_params.json`` and ``rmodeling_params.json`` are read from it
    :param preprocessing_folder: folder from which ``rpreprocessing_params.json`` is read
    :param dict split_desc: description of the split. When it holds a ``fullPath`` entry, the full
                            dataset is loaded and randomly split here, otherwise the train and test sets are loaded
    :param split_folder: folder of the split
    :param fmi: not used by this function
    :param postcompute_folder: folder in which ``learning_curve.json`` is written
    :param dict core_params: core params of the ML task
    :param dict computation_parameters: must hold ``learning_curve_number_of_points``, an integer >= 1
    :param job_id: forwarded to ``PercentageProgress`` to report progress
    :param modellike_folder: not used by this function
    :return: the string ``"ok"``
    :raises Exception: if ``learning_curve_number_of_points`` is missing or is not an integer >= 1
    """
    logger.info("Learning Curves: starting compute")

    if computation_parameters is None or (
            "learning_curve_number_of_points" not in computation_parameters
    ):
        raise Exception("'computation_parameters' should contains a key 'learning_curve_number_of_points'")
    number_of_points = computation_parameters["learning_curve_number_of_points"]
    if (not isinstance(number_of_points, int)) or number_of_points < 1:
        raise Exception("'learning_curve_number_of_points' should be an integer greater or equal to 1. got " + str(number_of_points))

    progress = PercentageProgress(job_id)
    model_folder_context = build_folder_context(model_folder)
    split_folder_context = build_folder_context(split_folder)
    preprocessing_folder_context = build_folder_context(preprocessing_folder)
    postcompute_folder_context = build_folder_context(postcompute_folder)

    resolved_params = model_folder_context.read_json("actual_params.json")["resolved"]
    modeling_params = model_folder_context.read_json("rmodeling_params.json")
    regridified_modeling_params = regridify_optimized_params(resolved_params, modeling_params)
    preprocessing_params = preprocessing_folder_context.read_json("rpreprocessing_params.json")  # within pp1/m1 ... what ?
    ml_overrides_params = ml_overrides_params_from_model_folder(model_folder_context)
    nan_support = PredictionAlgorithmNaNSupport(modeling_params, preprocessing_params, source_dss_version=get_source_dss_version(model_folder_context))

    split_params = split_desc["params"]
    n_folds = None

    # may be a better way to check for this
    if "fullPath" in split_desc:
        # sc-151406: I do consider a kfold model as a simple train test model here!! Will properly be handled
        full_df = df_from_split_desc_no_normalization(split_desc=split_desc, split="full", split_folder_context=split_folder_context, feature_params=preprocessing_params["per_feature"], prediction_type=core_params["prediction_type"])
        collector = PredictionPreprocessingDataCollector(full_df, preprocessing_params)

        seed = 1337
        if split_params is not None:
            # These params are optional: treat an absent key like an explicit None instead of raising a KeyError
            if split_params.get("ssdSeed") is not None:
                seed = int(split_params["ssdSeed"])
            if split_params.get("kfold") is True and split_params.get("nFolds") is not None:
                n_folds = int(split_params["nFolds"])
        # With k-fold, mimic the size of one fold for the test set; otherwise keep 20% for test
        split_ratio = 1.0/n_folds if n_folds is not None else 0.2
        train_df, test_df = train_test_split(full_df, test_size=split_ratio, random_state=seed)
    else:
        train_df = load_train_set(core_params, preprocessing_params, split_desc, "train", split_folder_context, use_diagnostics=False)
        test_df = load_test_set(core_params, preprocessing_params, split_desc, split_folder_context, use_diagnostics=False)
        collector = PredictionPreprocessingDataCollector(train_df, preprocessing_params)
    test_df_index = test_df.index.copy()
    collector_data = collector.build()
    pipeline, preproc_handler = build_pipeline_and_handler(collector_data, core_params, preprocessing_folder_context,
                                                           preprocessing_params, nan_support=nan_support)
    target_map = preproc_handler.target_map
    transformed_test = pipeline.process(test_df)

    # Dropping the first value removes the leading 0 ...
    train_sizes = np.linspace(0, len(train_df), number_of_points+1, dtype=int)[1:]
    # ... but the truncation to integers can still produce 0 when there are more points
    # than rows in the train set: nothing can be trained on an empty sample, so drop them
    train_sizes = train_sizes[train_sizes > 0]
    learning_curve_performances = []
    for count, train_size in enumerate(train_sizes, start=1):
        train_ratio = train_size / len(train_df)
        logger.info("Learning Curve : computing point with " + str(train_ratio * 100) + "% of initial train set")
        if train_ratio > 1:
            logger.warning("Learning Curve : train size bigger than initial train size. skipping point")
            continue
        train_df_frac = train_df.sample(n=train_size, random_state=42)
        try:
            train_df_frac_index = train_df_frac.index.copy()
            transformed_train_frac = pipeline.process(train_df_frac)
            start_time = time.time()
            perf = minimal_train_get_train_test_score(transformed_train_frac,
                                                      transformed_test,
                                                      train_df_frac_index,
                                                      test_df_index,
                                                      core_params,
                                                      regridified_modeling_params,
                                                      model_folder_context,
                                                      target_map,
                                                      preprocessing_params,
                                                      ml_overrides_params,
                                                      )
            end_time = time.time()
            learning_curve_performances.append({
                "train_size": len(train_df_frac),
                "test_size": len(test_df),
                "train_time": end_time - start_time,
                "train_metrics": perf["train_metrics"],
                "test_metrics": perf["test_metrics"],
            })
        except Exception as e:
            # Best effort: a failing point must not abort the whole curve, but keep the
            # traceback in the logs so that the failure can be investigated
            logger.exception("Learning Curve : error training on " + str(train_ratio * 100) + "% of initial train set:\n" + str(e))

        current_percentage = 100 * count / len(train_sizes)  # train sizes is > 0 because we are in the loop
        progress.set_percentage(current_percentage)
    ret = {
        "perf_points": learning_curve_performances,
    }
    postcompute_folder_context.write_json("learning_curve.json", ret)
    logger.info("Learning Curves : computation finished")
    return "ok"


def build_pipeline_and_handler(collector_data, core_params, preprocessing_folder_context, preprocessing_params, assertions=None,
                               allow_empty_mf=False, nan_support=None):
    """Build the prediction preprocessing handler together with its preprocessing pipeline.

    The handler is created by ``build_preprocessing_handler`` and the pipeline is always built
    with the target (``with_target=True``).

    :return: a ``(pipeline, preproc_handler)`` tuple
    """
    handler = build_preprocessing_handler(
        collector_data,
        core_params,
        preprocessing_folder_context,
        preprocessing_params,
        assertions,
        nan_support=nan_support,
    )
    built_pipeline = handler.build_preprocessing_pipeline(with_target=True, allow_empty_mf=allow_empty_mf)
    return built_pipeline, handler


def build_preprocessing_handler(collector_data, core_params, preprocessing_folder_context, preprocessing_params,
                                assertions=None, nan_support=None):
    """Create a ``PredictionPreprocessingHandler`` and attach the given collector data to it.

    :return: the handler, with its ``collector_data`` attribute set to ``collector_data``
    """
    handler = PredictionPreprocessingHandler.build(
        core_params, preprocessing_params, preprocessing_folder_context, assertions, nan_support=nan_support
    )
    handler.collector_data = collector_data
    return handler



def build_folder_contexts(preprocessing_set):
    """Add folder contexts to ``preprocessing_set`` (in place) and to each of its modeling sets.

    Sets "preprocessing_folder_context" and "split_folder_context" on the preprocessing set, and
    "model_folder_context" on every entry of ``preprocessing_set["modelingSets"]``.
    """
    context_key_to_folder_key = (
        ("preprocessing_folder_context", "run_folder"),
        ("split_folder_context", "split_folder"),
    )
    for context_key, folder_key in context_key_to_folder_key:
        preprocessing_set[context_key] = build_folder_context(preprocessing_set[folder_key])

    for m_set in preprocessing_set["modelingSets"]:
        m_set["model_folder_context"] = build_folder_context(m_set["run_folder"])


def train_prediction_keras(core_params, preprocessing_set, split_desc, ml_overrides_params):
    """Train the Keras prediction models of one preprocessing set.

    The train and test sets are loaded, the preprocessing pipeline is built and fitted once,
    then every entry of ``preprocessing_set["modelingSets"]`` is trained sequentially through
    ``prediction_train_model_keras`` and its training info is written with
    ``utils.write_done_traininfo``.

    :param core_params: core parameters of the task; ``core_params["prediction_type"]`` is read and checked here
    :param preprocessing_set: dict holding at least "run_folder", "split_folder", "preprocessing_params" and
                              "modelingSets"; it is mutated (folder contexts and listeners are added to it)
    :param split_desc: split description forwarded to the train / test set loaders
    :param ml_overrides_params: not used in this function
    :return: the string "ok"
    """
    start = unix_time_millis()
    log_nvidia_smi_if_use_gpu(core_params=core_params)

    default_diagnostics.register_keras_callbacks(core_params)

    prediction_type = core_params["prediction_type"]
    check_classical_prediction_type(prediction_type)
    # Adds the "*_folder_context" entries read just below
    build_folder_contexts(preprocessing_set)

    preprocessing_params = preprocessing_set["preprocessing_params"]
    modeling_sets = preprocessing_set["modelingSets"]
    split_folder_context = preprocessing_set["split_folder_context"]
    preprocessing_folder_context = preprocessing_set["preprocessing_folder_context"]

    logger.info("PPS is %s" % preprocessing_params)
    preprocessing_listener = ProgressListener()
    # Fill all the listeners ASAP to have correct progress data
    preprocessing_listener.add_future_steps(step_constants.PRED_KERAS_PREPROCESSING_STEPS)
    for modeling_set in modeling_sets:
        # One child listener per model, stored on the modeling set for the training loop below
        listener = preprocessing_listener.new_child(ModelStatusContext(modeling_set["model_folder_context"], start))
        listener.add_future_steps(step_constants.PRED_KERAS_TRAIN_STEPS)
        modeling_set["listener"] = listener

    # Called by the preprocessing pipeline to update the state
    # of each model and dump it to disk
    # NOTE(review): the comment above is not attached to any statement here - looks like a leftover, confirm

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TRAIN):
        train_df = load_train_set(core_params, preprocessing_params, split_desc, "train", split_folder_context)
        # Separate copy of the loaded train set; this is what is handed to prediction_train_model_keras below
        train_df_orig = train_df.copy()

        # Not implemented in the UI so far, so preprocessor_fit_df will always be train_df
        preprocessor_fit_df = train_df
        need_subsampling = preprocessing_params["preprocessingFitSampleRatio"] < 1
        if need_subsampling:
            preprocessor_fit_df = preprocessor_fit_df.sample(frac=preprocessing_params["preprocessingFitSampleRatio"],
                                                             random_state=preprocessing_params["preprocessingFitSampleSeed"])

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TEST):
        test_df = load_test_set(core_params, preprocessing_params, split_desc, split_folder_context)

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING):
        collector = PredictionPreprocessingDataCollector(preprocessor_fit_df, preprocessing_params)
        collector_data = collector.build()

        # Tagging special features to take them into account only in special_preproc_handler/special_pipeline
        per_feature = preprocessing_params["per_feature"]
        tag_special_features(per_feature)

        pipeline, preproc_handler = build_pipeline_and_handler(collector_data, core_params, preprocessing_folder_context,
                                                               preprocessing_params, allow_empty_mf=True)

    with preprocessing_listener.push_step(step_constants.ProcessingStep.KERAS_STEP_FIT_NORMAL_PREPROCESSING):
        # Retrieving transformed values to get the shape of all regular inputs, even if won't be
        # actually used, as each batch of data will be processed again
        transformed_normal = pipeline.fit_and_process(preprocessor_fit_df)
        preproc_handler.save_data()
        preproc_handler.report(pipeline)

    # TODO: REVIEW STATES OF TRAINING
    # The two steps below are pushed with an empty body: nothing is preprocessed at this point
    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_TRAIN):
        pass

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_TEST):
        pass

    preprocessing_listener.save_status()
    preprocessing_end = unix_time_millis()

    # Models are trained one after the other, all sharing the pipeline fitted above
    for modeling_set in modeling_sets:
        model_start = unix_time_millis()

        # Settings env variable that may be accessed in user defined code
        remoterun.set_dku_env_var_and_sys_env_var(doctor_constants.DKU_CURRENT_ANALYSIS_ID, modeling_set["fullId"]["taskLoc"]["analysisId"])
        remoterun.set_dku_env_var_and_sys_env_var(doctor_constants.DKU_CURRENT_MLTASK_ID, modeling_set["fullId"]["taskLoc"]["mlTaskId"])

        prediction_train_model_keras(transformed_normal, train_df_orig, test_df, pipeline, modeling_set["modelingParams"],
                                     core_params, per_feature, modeling_set["model_folder_context"], modeling_set["listener"],
                                     preproc_handler.target_map, pipeline.generated_features_mapping)

        end = unix_time_millis()
        # Merge the shared preprocessing steps with this model's own steps before writing the train info
        listeners_json = preprocessing_listener.merge(modeling_set["listener"])
        utils.write_done_traininfo(modeling_set["model_folder_context"], start, model_start, end,
                                   listeners_json,
                                   end_preprocessing_time=preprocessing_end)

    return "ok"


def train_clustering_models_nosave(
        core_params,
        split_desc,
        preprocessing_set):
    """Regular (mode 1) clustering train:
      - Load the "full" source dataset described by split_desc
      - Collect statistics, then fit and apply the preprocessing pipeline on it
      - Save the fitted preprocessing data and write preprocessing_report.json
      - For each modeling set, sequentially:
         - Call clustering_train_score_save
         - Write the training info

    :param core_params: only passed to the clustering diagnostics registration
    :param split_desc: split description used to load the source dataset
    :param preprocessing_set: dict holding at least "run_folder", "split_folder", "description",
                              "preprocessing_params" and "modelingSets"; it is mutated (folder contexts
                              and listeners are added to it)
    :return: the string "ok"
    """
    # Adds the "*_folder_context" entries read just below
    build_folder_contexts(preprocessing_set)
    split_folder_context = preprocessing_set["split_folder_context"]
    preprocessing_folder_context = preprocessing_set["preprocessing_folder_context"]
    start = unix_time_millis()
    default_diagnostics.register_clustering_callbacks(core_params)

    modeling_sets = preprocessing_set["modelingSets"]
    preprocessing_listener = ProgressListener()

    # Fill all the listeners ASAP to have correct progress data
    preprocessing_listener.add_future_steps(step_constants.CLUSTERING_REGULAR_PREPROCESSING_STEPS)
    for modeling_set in modeling_sets:
        # One child listener per model, stored on the modeling set for the training loop below
        listener = preprocessing_listener.new_child(ModelStatusContext(modeling_set["model_folder_context"], start))
        listener.add_future_steps(step_constants.ALL_CLUSTERING_TRAIN_STEPS)
        modeling_set["listener"] = listener

    logger.info("START TRAIN :" + preprocessing_set["description"])
    preprocessing_params = preprocessing_set["preprocessing_params"]

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_LOADING_SRC):
        source_df = df_from_split_desc(split_desc, "full", split_folder_context, preprocessing_params["per_feature"])
        diagnostics.on_load_train_dataset_end(df=source_df)
        logger.info("Loaded source df: shape=(%d,%d)" % source_df.shape)

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING):
        collector = ClusteringPreprocessingDataCollector(source_df, preprocessing_params)
        collector_data = collector.build()

        # The handler is built with an empty dict as first argument
        preproc_handler = ClusteringPreprocessingHandler({},
                                                        preprocessing_set["preprocessing_params"],
                                                        preprocessing_folder_context)

        preproc_handler.collector_data = collector_data
        pipeline = preproc_handler.build_preprocessing_pipeline()

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_SRC):
        # Copy of the index taken before the dataframe goes through the pipeline
        source_df_index = source_df.index.copy()
        # TODO: fit_and_process should take an update_fn argument
        transformed_source = pipeline.fit_and_process(source_df)
        # Saves fitted resources and collector data
        preproc_handler.save_data()
        # Report on work
        report = {}
        pipeline.report_fit(report, {})

        preprocessing_folder_context.write_json("preprocessing_report.json", report)
        diagnostics.on_preprocess_train_dataset_end(multiframe=transformed_source["TRAIN"])

    preprocessing_listener.save_status()

    preprocessing_end = unix_time_millis()

    # Models are trained one after the other, all sharing the preprocessed source computed above
    for modeling_set in modeling_sets:
        model_start = unix_time_millis()
        # Replace the listener's context with a new ModelStatusContext built from the same folder context and start time
        modeling_set["listener"].context = ModelStatusContext(modeling_set["model_folder_context"], start)

        clustering_train_score_save(transformed_source, source_df_index,
                                    preprocessing_set["preprocessing_params"],
                                    modeling_set["modelingParams"],
                                    modeling_set["model_folder_context"],
                                    modeling_set["listener"],
                                    pipeline)

        end = unix_time_millis()

        # Merge the shared preprocessing steps with this model's own steps before writing the train info
        listeners_json = preprocessing_listener.merge(modeling_set["listener"])
        utils.write_done_traininfo(modeling_set["model_folder_context"], start, model_start, end,
                                   listeners_json,
                                   end_preprocessing_time=preprocessing_end)

    return "ok"


def clustering_rescore(
        split_desc,
        preprocessing_folder,
        model_folder,
        split_folder):
    """Re-run the scoring of an already trained clustering model.

    The "full" dataset is reloaded and passed through a freshly built preprocessing pipeline
    (``fit_and_process``), the saved model is loaded and used to predict cluster labels, and a
    ``ClusteringModelScorer`` is run on the result.

    :param split_desc: JSON string describing the split (parsed with ``dkujson.loads``)
    :param preprocessing_folder: folder holding rpreprocessing_params.json and collector_data.json
    :param model_folder: folder holding rmodeling_params.json, user_meta.json and the saved model
    :param split_folder: folder used to load the source dataset
    :return: the string "ok"
    """
    pp_context = build_folder_context(preprocessing_folder)
    model_context = build_folder_context(model_folder)
    split_context = build_folder_context(split_folder)
    pp_params = pp_context.read_json("rpreprocessing_params.json")
    model_params = model_context.read_json("rmodeling_params.json")
    meta = model_context.read_json("user_meta.json")

    parsed_split_desc = dkujson.loads(split_desc)
    full_df = df_from_split_desc(parsed_split_desc, "full", split_context, pp_params["per_feature"])
    logger.info("Loaded source df: shape=(%d,%d)" % full_df.shape)

    handler = ClusteringPreprocessingHandler({}, pp_params, pp_context)
    handler.collector_data = pp_context.read_json("collector_data.json")
    rescoring_pipeline = handler.build_preprocessing_pipeline()

    # Copy of the index taken before the dataframe goes through the pipeline
    full_df_index = full_df.index.copy()
    transformed = rescoring_pipeline.fit_and_process(full_df)

    logger.info("Loading the clustering model")
    model = load_model_from_folder(model_context, is_prediction=False)

    # Best effort: an AttributeError (e.g. a model without post_process) is silently ignored
    try:
        logger.info("Post-processing the model")
        model.post_process(meta)
    except AttributeError:
        pass

    features, _ = prepare_multiframe(transformed["TRAIN"], model_params)
    from dataiku.doctor.clustering.anomaly_detection import DkuIsolationForest
    if not isinstance(model, DkuIsolationForest):
        labels = model.predict(features)
        scores = None
    else:
        labels, scores = model.predict_with_anomaly_score(features)

    logger.info("Rescoring the clustering model")
    scorer = ClusteringModelScorer(model, transformed, full_df_index, labels, pp_params, model_params,
                                   rescoring_pipeline, model_context, anomaly_scores=scores)
    scorer.score()

    return "ok"


def create_ensemble(split_desc, core_params, model_folder, preprocessing_folder, split_folder, model_folders, preprocessing_folders):
    """Build, save and score an ensemble model out of already fitted models.

    :param split_desc: JSON string describing the split (parsed with ``dkujson.loads``)
    :param core_params: JSON string of the core params (parsed with ``dkujson.loads``)
    :param model_folder: folder of the ensemble model; rmodeling_params.json is read from it, and iperf.json,
                         the pickled model, the scoring results and actual_params.json are written to it
    :param preprocessing_folder: preprocessing folder of the ensemble; an empty preprocessing_report.json is written to it
    :param split_folder: folder used to load the train and test sets
    :param model_folders: JSON-encoded list of the folders of the ensembled models
    :param preprocessing_folders: JSON-encoded list of the preprocessing folders of the ensembled models
    :return: the string "ok"
    """
    start = unix_time_millis()
    model_folder_context = build_folder_context(model_folder)
    preprocessing_folder_context = build_folder_context(preprocessing_folder)
    split_folder_context = build_folder_context(split_folder)

    listener = ProgressListener(context=ModelStatusContext(model_folder_context, start))
    listener.add_future_steps(step_constants.ENSEMBLE_STEPS)

    split_desc = dkujson.loads(split_desc)
    core_params = dkujson.loads(core_params)

    prediction_type = core_params["prediction_type"]
    check_classical_prediction_type(prediction_type)

    # Completely disable diagnostics for ensemble models, as the user already has diagnostics from the previous trainings
    diagnostics.disable()

    weight_method = core_params.get("weight", {}).get("weightMethod", None)
    with_sample_weight = weight_method in {"SAMPLE_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"}
    # TODO: update downstream
    with_class_weight = weight_method in {"CLASS_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"}
    preprocessing_folder_contexts = [build_folder_context(pp_folder) for pp_folder in dkujson.loads(preprocessing_folders)]
    model_folder_contexts = [build_folder_context(m_folder) for m_folder in dkujson.loads(model_folders)]
    modeling_params = model_folder_context.read_json("rmodeling_params.json")
    ensemble_params = modeling_params["ensemble_params"]
    logger.info("creating ensemble")
    with listener.push_step(step_constants.ProcessingStep.STEP_ENSEMBLING):
        from dataiku.doctor.prediction.ensembles import ensemble_from_fitted
        # The per-feature settings of the first ensembled preprocessing are used to load the train set
        train = df_from_split_desc(split_desc, "train", split_folder_context,
                                   ensemble_params["preprocessing_params"][0]["per_feature"],
                                   prediction_type)
        iperf = {
            "modelInputNRows" : train.shape[0], #todo : not the right count as may have dropped ...
            "modelInputNCols" : -1, # makes no sense for an ensemble as may have different preprocessings
            "modelInputIsSparse" : False
        }
        model_folder_context.write_json("iperf.json", iperf)
        clf = ensemble_from_fitted(core_params, ensemble_params, preprocessing_folder_contexts, model_folder_contexts, train, with_sample_weight, with_class_weight)

    logger.info("saving model")
    with listener.push_step(step_constants.ProcessingStep.STEP_SAVING):
        to_pkl(clf, model_folder_context)

    logger.info("scoring model")
    with listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
        test = df_from_split_desc(split_desc, "test", split_folder_context, ensemble_params["preprocessing_params"][0]["per_feature"], prediction_type)
        # this is annoying, but we have to use one of the previous preprocessings in order to get the target
        prep_folder_context = preprocessing_folder_contexts[0]
        rppp = prep_folder_context.read_json("rpreprocessing_params.json")
        collector_data = prep_folder_context.read_json("collector_data.json")
        nan_support = PredictionAlgorithmNaNSupport(ensemble_params["modeling_params"][0], rppp, source_dss_version=get_source_dss_version(model_folder_contexts[0]))
        preprocessing_handler = PredictionPreprocessingHandler.build(core_params, rppp, prep_folder_context, nan_support=nan_support)
        preprocessing_handler.collector_data = collector_data
        pipe = preprocessing_handler.build_preprocessing_pipeline(with_target=True)
        # Only process (no fit) the test set with this pipeline
        transformed_test = pipe.process(test)

        y = transformed_test["target"]

        if with_sample_weight:
            sample_weight = transformed_test["weight"]
        else:
            sample_weight = None

        # Now that the CLF with scorable pipelines has been saved, set it in "pipelines with target" mode
        # to be able to compute metrics
        clf.set_with_target_pipelines_mode(True)

        # The ensemble predicts on the raw (unprocessed) test dataframe
        pred = clf.predict(test)
        probas = None if prediction_type == doctor_constants.REGRESSION else clf.predict_proba(test)

        # Source value -> mapped value, taken from the first ensembled preprocessing (classification only)
        target_map = None if prediction_type == doctor_constants.REGRESSION else \
            {t["sourceValue"]: t["mappedValue"] for t in ensemble_params["preprocessing_params"][0]["target_remapping"]}

        # Pick the scorer matching the prediction type; the last branch handles every remaining type
        if prediction_type == doctor_constants.REGRESSION:
            prediction_result = PredictionResult(pred)
            scorer = RegressionModelScorer(modeling_params, prediction_result, y, model_folder_context, test_unprocessed=transformed_test['UNPROCESSED'],
                                           test_X=transformed_test['TRAIN'], test_df_index=test.index.copy(), test_sample_weight=sample_weight)
        elif prediction_type == doctor_constants.BINARY_CLASSIFICATION:
            decisions_and_cuts = DecisionsAndCuts.from_probas(probas, target_map)
            scorer = BinaryClassificationModelScorer(modeling_params, model_folder_context, decisions_and_cuts,
                                                     y, target_map, test_unprocessed=transformed_test['UNPROCESSED'], test_X=transformed_test['TRAIN'],
                                                     test_df_index=test.index.copy(), test_sample_weight=sample_weight)
        else:
            prediction_result = ClassificationPredictionResult(target_map, probas=probas, unmapped_preds=pred)
            scorer = MulticlassModelScorer(modeling_params, model_folder_context, prediction_result, y.astype(int), target_map,
                                           test_unprocessed=transformed_test['UNPROCESSED'], test_X=transformed_test['TRAIN'],
                                           test_df_index=test.index.copy(), test_sample_weight=sample_weight)
        scorer.score()
        scorer.save()

    listener.save_status()
    end = unix_time_millis()
    model_folder_context.write_json("actual_params.json", {"resolved": modeling_params})
    # An empty preprocessing report is written for the ensemble
    preprocessing_folder_context.write_json("preprocessing_report.json", {})
    # "end" is passed twice and the preprocessing end time is set to "start"
    utils.write_done_traininfo(model_folder_context, start, end, end, listener.to_jsonifiable(), end_preprocessing_time=start)

    return "ok"


def compute_pdp(job_id, split_desc, core_params, preprocessing_folder, model_folder, split_folder, fmi, modellike_folder=None,
                computation_parameters=None, postcompute_folder=None, train_split_desc=None,
                train_split_folder=None):
    """Command delegating to ``pdp.compute`` (posttraining partial dependency module).

    ``postcompute_folder``, ``train_split_desc`` and ``train_split_folder`` are accepted but not forwarded.

    :return: the string "ok"
    """
    pdp.compute(
        job_id,
        split_desc,
        core_params,
        preprocessing_folder,
        model_folder,
        split_folder,
        modellike_folder,
        computation_parameters,
        fmi,
    )
    return "ok"


def compute_subpopulation(job_id, split_desc, core_params, preprocessing_folder, model_folder, split_folder, fmi, modellike_folder=None,
                          computation_parameters=None, postcompute_folder=None):
    """Command delegating to ``subpopulation.compute``, with diagnostics disabled beforehand.

    ``modellike_folder`` is accepted but not forwarded.

    :return: whatever ``subpopulation.compute`` returns
    """
    diagnostics.disable()
    outcome = subpopulation.compute(
        job_id,
        split_desc,
        core_params,
        preprocessing_folder,
        model_folder,
        split_folder,
        computation_parameters,
        postcompute_folder,
        fmi,
    )
    return outcome


# Model-less subpopulation is the Subpopulation computation run inside a Model Evaluation. It should be possible to
# run it again even if the SM has been deleted.
# However, because of custom preprocessings, we faced issues and added back the model folder. Should be fixed with sc-101989
def compute_modelless_subpopulation(job_id, model_evaluation, features, modelevaluation_folder,
                                    resolved_preprocessing_params, iperf, computation_parameters=None, model_folder=None):
    """Command delegating to ``subpopulation.compute_modelless``, with diagnostics disabled beforehand.

    :return: whatever ``subpopulation.compute_modelless`` returns
    """
    diagnostics.disable()
    outcome = subpopulation.compute_modelless(
        job_id,
        model_evaluation,
        features,
        modelevaluation_folder,
        model_folder,
        iperf,
        resolved_preprocessing_params,
        computation_parameters,
    )
    return outcome


def compute_global_explanations(job_id, split_desc, core_params, preprocessing_folder, model_folder, split_folder, fmi,
                                modellike_folder=None, computation_parameters=None, postcompute_folder=None,
                                train_split_desc=None, train_split_folder=None):
    """Command delegating to ``global_explanations.compute``.

    ``modellike_folder``, ``computation_parameters``, ``postcompute_folder`` and ``train_split_folder``
    are accepted but not forwarded.

    :return: the string "ok"
    """
    global_explanations.compute(
        job_id,
        split_desc,
        core_params,
        preprocessing_folder,
        model_folder,
        split_folder,
        fmi,
        train_split_desc,
    )
    return "ok"


def compute_individual_explanations(job_id, split_desc, core_params, preprocessing_folder, model_folder, split_folder, fmi,
                                    modellike_folder=None, computation_parameters=None, postcompute_folder=None,
                                    train_split_desc=None, train_split_folder=None):
    """Command delegating to ``individual_explanations.compute``.

    ``modellike_folder`` is accepted but not forwarded.

    :return: the string "ok"
    """
    individual_explanations.compute(
        job_id,
        split_desc,
        core_params,
        preprocessing_folder,
        model_folder,
        split_folder,
        computation_parameters,
        fmi,
        postcompute_folder,
        train_split_desc,
        train_split_folder,
    )
    return "ok"

def compute_information_criteria(job_id, split_desc, core_params, preprocessing_folder, model_folder, split_folder, fmi,
                                 modellike_folder=None, computation_parameters=None, postcompute_folder=None,
                                 train_split_desc=None, train_split_folder=None):
    """Store the information criteria of the saved model in its iperf.json.

    The model is unpickled from ``model_folder``, ``get_information_criteria()`` is called on it and the
    result is written under the "informationCriteria" key of iperf.json. Only ``model_folder`` is used.

    :return: the string "ok"
    """
    iperf_filename = "iperf.json"
    folder_context = build_folder_context(model_folder)
    model = from_pkl(folder_context)
    intrinsic_perf = folder_context.read_json(iperf_filename)
    intrinsic_perf["informationCriteria"] = model.get_information_criteria()
    folder_context.write_json(iperf_filename, intrinsic_perf)
    return "ok"


def compute_timeseries_residuals(job_id, split_desc, core_params, preprocessing_folder, model_folder, split_folder, fmi,
                                 modellike_folder=None, computation_parameters=None, postcompute_folder=None,
                                 train_split_desc=None, train_split_folder=None):
    """Compute the residuals of a saved time series forecasting model.

    The "full" dataset is loaded, resampled, enriched with rolling windows and preprocessed, then the
    unpickled model computes its residuals into the "residuals" subfolder of the model folder.
    Progress is reported through ``PercentageProgress(job_id)``.

    :return: the string "ok"
    """
    model_ctx = build_folder_context(model_folder)
    preprocessing_ctx = build_folder_context(preprocessing_folder)
    pp_params = preprocessing_ctx.read_json("rpreprocessing_params.json")
    progress_listener = ProgressListener()
    m_params = model_ctx.read_json("rmodeling_params.json")
    algo = TimeseriesForecastingAlgorithm.build(m_params["algorithm"])
    timeseries_preprocessing = TimeseriesPreprocessing(
        preprocessing_ctx,
        core_params,
        pp_params,
        m_params,
        progress_listener,
        algo.EXTERNAL_FEATURES_COMPATIBILITY.supports_external_features(),
    )
    split_ctx = build_folder_context(split_folder)
    # The same name is rebound at every stage so that intermediate dataframes are not kept alive
    timeseries_df = load_train_set(core_params, pp_params, split_desc, "full", split_ctx, use_diagnostics=False)

    resolved = model_ctx.read_json("actual_params.json")["resolved"]
    optimized_params = regridify_optimized_params(resolved, m_params)
    timeseries_df = resample_for_training(
        timeseries_df,
        split_desc["schema"],
        pp_params[doctor_constants.TIMESERIES_SAMPLING],
        core_params,
        pp_params,
        algo.EXTERNAL_FEATURES_COMPATIBILITY.supports_past_only_external_features(),
        m_params["metrics"]["evaluationMetric"] == MAPE.name,
    )

    # Windowing is performed once and for all on the full dataframe
    if m_params.get("isShiftWindowsCompatible", False):
        windows = get_windows_list(pp_params)
    else:
        windows = []
    timeseries_df = add_rolling_windows_for_training(timeseries_df, core_params, windows, pp_params, preprocessing_ctx)

    transformed_df = timeseries_preprocessing.fit_and_process(
        timeseries_df,
        step_constants.ProcessingStep.STEP_PREPROCESS_FULL,
        algo.ONE_MODEL_FOR_MULTIPLE_TS,
        save_data=False,
    )
    model = from_pkl(model_ctx)
    actual_resolved_params = algo.get_actual_params(optimized_params, model, fit_params=None)['resolved']
    min_scoring_size = algo.get_min_size_for_scoring(
        actual_resolved_params,
        pp_params,
        core_params[doctor_constants.PREDICTION_LENGTH],
    )
    residuals_ctx = model_ctx.get_subfolder_context("residuals")
    residuals_ctx.create_if_not_exist()
    model.compute_residuals(transformed_df, min_scoring_size, residuals_ctx, PercentageProgress(job_id))
    return "ok"

def compute_timeseries_permutation_importance_computation(job_id, split_desc, core_params, preprocessing_folder, model_folder, split_folder, fmi,
                                                          modellike_folder=None, computation_parameters=None, postcompute_folder=None,
                                                          train_split_desc=None, train_split_folder=None):
    """Command forwarding all its arguments, unchanged, to ``compute_post_train_permutation_importance``.

    :return: whatever ``compute_post_train_permutation_importance`` returns
    """
    outcome = compute_post_train_permutation_importance(
        job_id,
        split_desc,
        core_params,
        preprocessing_folder,
        model_folder,
        split_folder,
        fmi,
        modellike_folder,
        computation_parameters,
        postcompute_folder,
        train_split_desc,
        train_split_folder,
    )
    return outcome


# This is weird, right ?! This command re-runs the training, but through the training recipe, not through one of the
# train_XXXX commands used from the LAB. Why this convoluted approach ? Because, at the moment, the only way to run
# something on a saved model version is to run a kernel, so it is much simpler to piggyback on this logic than to
# implement a new way of running scripts (such as the train recipe) on a saved model version.
# If such a possibility is implemented in the future, please remove this approach.
def compute_retraining(exec_folder, operation_mode):
    """Run the regular train recipe entry point with the given execution folder and operation mode.

    :return: the string "ok"
    """
    from dataiku.doctor.prediction.reg_train_recipe import main as run_train_recipe
    run_train_recipe(exec_folder, operation_mode)
    return "ok"


def train_prediction_deephub(core_params, preprocessing_set, split_desc, ml_overrides_params):
    """Train the (single) DeepHub model of ``preprocessing_set``.

    Builds a ``DeepHubTrainingParams`` from the first modeling set and hands it to ``launch_training``.
    ``ml_overrides_params`` is not used.

    :return: the string "ok"
    """
    from dataiku.doctor.deephub.launch_training import launch_training
    from dataiku.doctor.deephub.deephub_params import DeepHubTrainingParams

    pp_folder = preprocessing_set["run_folder"]
    data_split_folder = preprocessing_set["split_folder"]
    # we know for sure that there is only one model for deephub
    only_modeling_set = preprocessing_set["modelingSets"][0]
    model_context = build_folder_context(only_modeling_set["run_folder"])
    training_params = DeepHubTrainingParams(
        core_params,
        only_modeling_set["modelingParams"],
        preprocessing_set["preprocessing_params"],
        split_desc,
        pp_folder,
        model_context,
        data_split_folder,
        preprocessing_set["tmp_folder"],
    )

    launch_training(training_params)
    return "ok"


def train_prediction_timeseries(core_params, preprocessing_set, split_desc, ml_overrides_params):
    """Train the time series forecasting models of ``preprocessing_set``.

    Folder contexts are added to ``preprocessing_set`` (in place) and the actual work is delegated to
    the timeseries ``launch_training``. ``ml_overrides_params`` is not used.

    :return: the string "ok"
    """
    from dataiku.doctor.timeseries.train.launch_training import launch_training

    m_sets = preprocessing_set["modelingSets"]
    pp_params = preprocessing_set["preprocessing_params"]
    sampling_params = pp_params[doctor_constants.TIMESERIES_SAMPLING]
    build_folder_contexts(preprocessing_set)
    pp_context = preprocessing_set["preprocessing_folder_context"]
    split_context = preprocessing_set["split_folder_context"]

    launch_training(
        core_params,
        m_sets,
        pp_params,
        sampling_params,
        pp_context,
        split_context,
        split_desc,
    )

    return "ok"


def train_causal_prediction(core_params, preprocessing_set, split_desc, ml_overrides_params):
    """Train the causal prediction models of ``preprocessing_set``.

    Folder contexts are added to ``preprocessing_set`` (in place) and the actual work is delegated to
    the causal ``launch_training``. ``ml_overrides_params`` is not used.

    :return: the string "ok"
    """
    from dataiku.doctor.causal.train.launch_training import launch_training as run_causal_training

    m_sets = preprocessing_set["modelingSets"]
    pp_params = preprocessing_set["preprocessing_params"]
    build_folder_contexts(preprocessing_set)
    pp_context = preprocessing_set["preprocessing_folder_context"]
    split_context = preprocessing_set["split_folder_context"]

    run_causal_training(
        core_params,
        m_sets,
        pp_params,
        pp_context,
        split_context,
        split_desc,
    )

    return "ok"


def ping():
    """Liveness command: always answers the string "pong"."""
    reply = "pong"
    return reply
