# encoding: utf-8
"""
Execute a prediction training recipe in PyRegular mode
Must be called in a Flow environment
"""
import logging
import sys

from dataiku.base.folder_context import build_folder_context
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import safe_unicode_str
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.doctor import step_constants
from dataiku.doctor import utils
from dataiku.doctor.commands import build_pipeline_and_handler
from dataiku.doctor.deep_learning.keras_utils import tag_special_features
from dataiku.doctor.diagnostics import default_diagnostics
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.diagnostics.diagnostics import DiagnosticsScoringResults
from dataiku.doctor.prediction.classification_fit import classification_fit
from dataiku.doctor.prediction.classification_scoring import ClassificationModelIntrinsicScorer
from dataiku.doctor.prediction.classification_scoring import binary_classification_scorer_with_valid
from dataiku.doctor.prediction.classification_scoring import \
    compute_assertions_and_overrides_for_classification_from_clf
from dataiku.doctor.prediction.classification_scoring import multiclass_scorer_with_valid
from dataiku.doctor.prediction.common import check_classical_prediction_type
from dataiku.doctor.prediction.common import get_monotonic_cst
from dataiku.doctor.prediction.common import needs_hyperparameter_search
from dataiku.doctor.prediction.common import regridify_optimized_params
from dataiku.doctor.prediction.common import PredictionAlgorithmNaNSupport
from dataiku.doctor.prediction.overrides.ml_overrides_params import ml_overrides_params_from_model_folder
from dataiku.doctor.prediction.prediction_model_serialization import ModelSerializer
from dataiku.doctor.prediction.regression_fit import regression_fit_single
from dataiku.doctor.prediction.regression_scoring import RegressionModelIntrinsicScorer
from dataiku.doctor.prediction.regression_scoring import compute_assertions_and_overrides_for_regression_from_clf
from dataiku.doctor.prediction.regression_scoring import regression_scorer_with_valid
from dataiku.doctor.prediction.scorable_model import ScorableModel
from dataiku.doctor.prediction_entrypoints import prediction_train_model_keras
from dataiku.doctor.prediction_entrypoints import prediction_train_model_kfold
from dataiku.doctor.prediction_entrypoints import prediction_train_score_save
from dataiku.doctor.prediction_entrypoints import prediction_train_score_save_ensemble
from dataiku.doctor.preprocessing_collector import ClusteringPreprocessingDataCollector
from dataiku.doctor.preprocessing_collector import PredictionPreprocessingDataCollector
from dataiku.doctor.utils import doctor_constants
from dataiku.doctor.utils import unix_time_millis
from dataiku.doctor.utils.gpu_execution import log_nvidia_smi_if_use_gpu
from dataiku.doctor.utils.listener import ModelStatusContext
from dataiku.doctor.utils.listener import ProgressListener
from dataiku.doctor.utils.split import check_train_test_order
from dataiku.doctor.utils.split import df_from_split_desc
from dataiku.doctor.utils.split import load_test_set
from dataiku.doctor.utils.split import load_train_set
from dataiku.doctor.utils.split import sort_dataframe

# Configure the root logger as soon as this module is imported. Note that
# logging.basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)

# Module-scoped logger used by the training code below.
logger = logging.getLogger(__name__)


def main(exec_folder, operation_mode):
    """Train a prediction model for a saved-model training recipe.

    All inputs are read from ``exec_folder``: ``core_params.json``, ``rpreprocessing_params.json``,
    ``rmodeling_params.json``, the optional ``rassertions.json``, and ``split/split.json``.
    Outputs written by this function go to the same folder, for example ``actual_params.json``,
    ``perf.json`` and the final training info.

    :param exec_folder: path of the execution folder, passed to ``build_folder_context``
    :param operation_mode: one of ``"TRAIN_SPLITTED_ONLY"``, ``"TRAIN_FULL_ONLY"``,
        ``"TRAIN_KFOLD"`` or ``"TRAIN_SPLITTED_AND_FULL"``
    :raises ValueError: if time ordering is enabled but no time variable is specified
    :raises NotImplementedError: if ``operation_mode`` is not one of the values above
    """
    start = unix_time_millis()
    start_train = start

    exec_folder_context = build_folder_context(exec_folder)
    split_folder_context = exec_folder_context.get_subfolder_context("split")
    listener = ProgressListener(context=ModelStatusContext(exec_folder_context, start))
    split_desc = split_folder_context.read_json("split.json")
    core_params = exec_folder_context.read_json("core_params.json")

    model_type = core_params["taskType"]
    preprocessing_params = exec_folder_context.read_json("rpreprocessing_params.json")
    weight_method = core_params.get("weight", {}).get("weightMethod", None)
    with_sample_weight = weight_method in {"SAMPLE_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"}
    with_class_weight = weight_method in {"CLASS_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"}
    modeling_params = exec_folder_context.read_json("rmodeling_params.json")
    ml_overrides_params = ml_overrides_params_from_model_folder(exec_folder_context)
    sort = core_params.get("time", {}).get("enabled", False)
    time_variable = core_params.get("time", {}).get("timeVariable")
    ascending = core_params.get("time", {}).get("ascending")
    if sort and time_variable is None:
        raise ValueError("Time ordering is enabled but no time variable is specified")

    prediction_type = core_params["prediction_type"]
    check_classical_prediction_type(prediction_type)
    # Binary and multiclass tasks share the target map and most classification helpers
    is_classification = prediction_type in {doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS}

    calibration_method = core_params.get("calibration", {}).get("calibrationMethod")
    calibrate_proba = calibration_method in [doctor_constants.SIGMOID, doctor_constants.ISOTONIC]
    calibration_ratio = core_params.get("calibration", {}).get("calibrationDataRatio", doctor_constants.DEFAULT_CALIBRATION_DATA_RATIO)
    calibrate_on_test = core_params.get("calibration", {}).get("calibrateOnTestSet", True)

    nan_support = PredictionAlgorithmNaNSupport(modeling_params, preprocessing_params)

    # For KERAS backend, need to tag special features, because they are only processed with process function,
    # not fit_and_process
    if modeling_params["algorithm"] == "KERAS_CODE":
        tag_special_features(preprocessing_params['per_feature'])

    # Only compute assertions if not Keras or Ensemble
    assertions_params_file = "rassertions.json"
    assertions_list = None
    if exec_folder_context.isfile(assertions_params_file) and modeling_params["algorithm"] not in {"KERAS_CODE", "PYTHON_ENSEMBLE"}:
        assertions_list = exec_folder_context.read_json(assertions_params_file).get("assertions", None)

    # Diagnostics are disabled for ensembles; Keras and other algorithms register different callbacks
    if modeling_params["algorithm"] == "PYTHON_ENSEMBLE":
        diagnostics.disable()
    elif modeling_params["algorithm"] == "KERAS_CODE":
        default_diagnostics.register_keras_callbacks(core_params)
    else:
        default_diagnostics.register_prediction_callbacks(core_params)

    log_nvidia_smi_if_use_gpu(core_params=core_params)

    def do_full_fit_and_save():
        """Fit on 100% of the data, then save the model and actual params and compute intrinsic scores.

        :return: tuple ``(actual_params, full_assertions_metrics, prepared_X, transformed_full,
            full_overrides_metrics)``. Both metrics entries stay None unless the transformed data
            contains assertions or ML overrides are configured.
        """

        assert modeling_params["algorithm"] != "KERAS_CODE", "Training on full data is not supported for KERAS backend"

        with listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TRAIN):
            full_df = load_train_set(core_params, preprocessing_params, split_desc, "full", split_folder_context,
                                     assertions=assertions_list)

        with listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING_PREPROCESSING_DATA):
            # NOTE(review): the split-based modes use PredictionPreprocessingDataCollector; confirm the
            # clustering collector is really intended for the full-data fit.
            collector = ClusteringPreprocessingDataCollector(full_df, preprocessing_params)
            collector_data = collector.build()

            pipeline, preproc_handler = build_pipeline_and_handler(collector_data, core_params, exec_folder_context,
                                                                   preprocessing_params,
                                                                   assertions=assertions_list,
                                                                   nan_support=nan_support)

            # TODO
            target_map = preproc_handler.target_map if is_classification else None

        with listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_FULL):
            transformed_full = pipeline.fit_and_process(full_df)
            diagnostics.on_kfold_step_preprocess_global_end(transformed_full["TRAIN"])
            monotonic_cst = get_monotonic_cst(preprocessing_params, transformed_full["TRAIN"])

            if with_sample_weight:
                assert transformed_full["weight"].values.min() > 0, "Sample weights must be positive"

            preproc_handler.save_data()
            preproc_handler.report(pipeline)

        with listener.push_step(step_constants.ProcessingStep.STEP_FITTING):
            if is_classification:
                (clf, actual_params, prepared_X, iipd) = classification_fit(modeling_params,
                                                                            core_params,
                                                                            transformed_full,
                                                                            model_folder_context=exec_folder_context,
                                                                            target_map=target_map,
                                                                            with_sample_weight=with_sample_weight,
                                                                            with_class_weight=with_class_weight,
                                                                            calibration_method=calibration_method,
                                                                            calibration_ratio=calibration_ratio,
                                                                            # Training uses all of transformed_full, so no test set is left
                                                                            calibrate_on_test=False,
                                                                            monotonic_cst=monotonic_cst)
            elif prediction_type == doctor_constants.REGRESSION:
                (clf, actual_params, prepared_X, iipd) = regression_fit_single(modeling_params, core_params, transformed_full,
                                                                               model_folder_context=exec_folder_context,
                                                                               with_sample_weight=with_sample_weight,
                                                                               monotonic_cst=monotonic_cst)
            diagnostics.on_fitting_end(features=transformed_full["TRAIN"].columns(), clf=clf, prediction_type=prediction_type, train_target=transformed_full["target"])
            model = ScorableModel.build(clf, model_type, prediction_type, modeling_params["algorithm"],
                                        preprocessing_params, ml_overrides_params)

        with listener.push_step(step_constants.ProcessingStep.STEP_SAVING):
            train_X = transformed_full["TRAIN"]
            ModelSerializer.build(model, exec_folder_context, train_X.columns(), calibrate_proba).serialize()
            exec_folder_context.write_json("actual_params.json", actual_params)

        with listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            full_assertions_metrics = None
            full_overrides_metrics = None
            train_y = transformed_full["target"]
            if is_classification:
                # The same intrinsic scorer handles both binary and multiclass tasks
                ClassificationModelIntrinsicScorer(modeling_params, clf, train_X, train_y, target_map,
                                                   pipeline, exec_folder_context, prepared_X, iipd,
                                                   with_sample_weight, calibrate_proba).score()

                if "assertions" in transformed_full or ml_overrides_params is not None:
                    full_assertions_metrics, full_overrides_metrics = compute_assertions_and_overrides_for_classification_from_clf(
                        clf, model_type, modeling_params, prediction_type, preprocessing_params, target_map,
                        transformed_full, ml_overrides_params)
            elif prediction_type == doctor_constants.REGRESSION:
                RegressionModelIntrinsicScorer(modeling_params, clf, train_X, train_y, pipeline, exec_folder_context,
                                               prepared_X, iipd, with_sample_weight).score()
                if "assertions" in transformed_full or ml_overrides_params is not None:
                    full_assertions_metrics, full_overrides_metrics = compute_assertions_and_overrides_for_regression_from_clf(
                        clf, model_type, modeling_params, transformed_full, ml_overrides_params)

        return actual_params, full_assertions_metrics, prepared_X, transformed_full, full_overrides_metrics

    if operation_mode == "TRAIN_SPLITTED_ONLY":

        with listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TRAIN):
            train_df = load_train_set(core_params, preprocessing_params, split_desc, "train", split_folder_context)

        with listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TEST):
            test_df = load_test_set(core_params, preprocessing_params, split_desc, split_folder_context, assertions=assertions_list)

        with listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING_PREPROCESSING_DATA):
            collector = PredictionPreprocessingDataCollector(train_df, preprocessing_params)
            collector_data = collector.build()
            pipeline, preproc_handler = build_pipeline_and_handler(collector_data, core_params, exec_folder_context,
                                                                   preprocessing_params,
                                                                   allow_empty_mf=modeling_params["algorithm"] == "KERAS_CODE",
                                                                   assertions=assertions_list,
                                                                   nan_support=nan_support)

            target_map = preproc_handler.target_map if is_classification else None

        with listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_TRAIN):
            preprocessor_fit_df = train_df

            if sort:
                check_train_test_order(train_df, test_df, time_variable, ascending)
                if needs_hyperparameter_search(modeling_params):
                    logger.info(u"Hyperparameter search enabled: checking that the train set is sorted by '{}'".format(safe_unicode_str(time_variable)))
                    # NOTE(review): the return value is discarded, so this relies on sort_dataframe
                    # sorting train_df in place — confirm.
                    sort_dataframe(train_df, time_variable, ascending)

            # For KERAS backend, we might need to take a subsample of the input_df to prevent from memory errors
            if modeling_params["algorithm"] == "KERAS_CODE":
                train_df_orig = train_df.copy()
                need_subsampling = preprocessing_params["preprocessingFitSampleRatio"] < 1
                if need_subsampling:
                    preprocessor_fit_df = preprocessor_fit_df.sample(
                        frac=preprocessing_params["preprocessingFitSampleRatio"],
                        random_state=preprocessing_params["preprocessingFitSampleSeed"])

            transformed_train = pipeline.fit_and_process(preprocessor_fit_df)

            diagnostics.on_preprocess_train_dataset_end(transformed_train["TRAIN"])

            if with_sample_weight:
                assert transformed_train["weight"].values.min() > 0, "Sample weights must be positive"

            preproc_handler.save_data()
            preproc_handler.report(pipeline)

        # For KERAS backend, cannot process test directly, because it may have special features that may not
        # hold in memory
        if modeling_params["algorithm"] != "KERAS_CODE":
            with listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_TEST):
                test_df_index = test_df.index.copy()
                transformed_test = pipeline.process(test_df)
                diagnostics.on_preprocess_test_dataset_end(transformed_test["TRAIN"])
                if with_sample_weight:
                    assert transformed_test["weight"].values.min() > 0, "Sample weights must be positive"

        if modeling_params["algorithm"] == "PYTHON_ENSEMBLE":
            prediction_train_score_save_ensemble(train_df, test_df,
                                                 core_params, modeling_params, exec_folder_context, listener,
                                                 target_map, pipeline, with_sample_weight)
        elif modeling_params["algorithm"] == "KERAS_CODE":
            # Use a fresh status context for Keras training, and always restore the previous one,
            # even if training raises
            old_context = listener.context
            listener.context = ModelStatusContext(exec_folder_context, start)
            try:
                prediction_train_model_keras(transformed_train, train_df_orig, test_df, pipeline, modeling_params,
                                             core_params, preprocessing_params["per_feature"], exec_folder_context, listener,
                                             preproc_handler.target_map,
                                             pipeline.generated_features_mapping)
            finally:
                listener.context = old_context
        else:
            prediction_train_score_save(transformed_train, transformed_test, test_df_index, core_params, split_desc,
                                        modeling_params, exec_folder_context, exec_folder_context, split_folder_context,
                                        listener, target_map, pipeline,
                                        preprocessing_params, ml_overrides_params)

    elif operation_mode == "TRAIN_FULL_ONLY":
        # Not yet functional ...
        do_full_fit_and_save()

    elif operation_mode == "TRAIN_KFOLD":
        actual_params, assertions_metrics, _, transformed_full, overrides_metrics = do_full_fit_and_save()

        monotonic_cst = get_monotonic_cst(preprocessing_params, transformed_full["TRAIN"])
        full_df_clean = df_from_split_desc(split_desc, "full", split_folder_context, preprocessing_params["per_feature"], prediction_type)
        # Re-use the hyperparameters resolved by the full fit as the grid for the k-fold runs
        optimized_params_grid = regridify_optimized_params(actual_params["resolved"], modeling_params)
        prediction_train_model_kfold(full_df_clean,
                                     core_params, split_desc, preprocessing_params, modeling_params,
                                     optimized_params_grid,
                                     exec_folder_context, exec_folder_context, split_folder_context, listener, with_sample_weight,
                                     with_class_weight,
                                     transformed_full,
                                     assertions_metrics=assertions_metrics,
                                     ml_overrides_params=ml_overrides_params,
                                     overrides_metrics=overrides_metrics,
                                     monotonic_cst=monotonic_cst)

    elif operation_mode == "TRAIN_SPLITTED_AND_FULL":
        actual_params, assertions_metrics, _, _, overrides_metrics = do_full_fit_and_save()
        # Do the split and scoring but don't save data
        with listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TRAIN):
            # no need to load assertions or compute diagnostics as they already have been computed on actual model
            # with full data
            train_df = load_train_set(core_params, preprocessing_params, split_desc, "train", split_folder_context, use_diagnostics=False)

        with listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TEST):
            # no need to load assertions or compute diagnostics as they already have been computed on actual model
            # with full data
            test_df = load_test_set(core_params, preprocessing_params, split_desc, split_folder_context, use_diagnostics=False)

        with listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING_PREPROCESSING_DATA):
            collector = PredictionPreprocessingDataCollector(train_df, preprocessing_params)
            collector_data = collector.build()

            pipeline, preproc_handler = build_pipeline_and_handler(collector_data, core_params, exec_folder_context,
                                                                   preprocessing_params, nan_support=nan_support)

            target_map = preproc_handler.target_map if is_classification else None

        with listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_TRAIN):
            preprocessor_fit_df = train_df
            transformed_train = pipeline.fit_and_process(preprocessor_fit_df)
            diagnostics.on_preprocess_train_dataset_end(transformed_train["TRAIN"])

        with listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_TEST):
            test_df_index = test_df.index.copy()
            transformed_test = pipeline.process(test_df)
            diagnostics.on_preprocess_test_dataset_end(transformed_test["TRAIN"])

        with listener.push_step(step_constants.ProcessingStep.STEP_FITTING):
            monotonic_cst = get_monotonic_cst(preprocessing_params, transformed_train["TRAIN"])
            # Re-fit on the train split with the hyperparameters resolved by the full fit
            optimized_params_grid = regridify_optimized_params(actual_params["resolved"], modeling_params)
            if is_classification:
                clf, _, _, _ = classification_fit(optimized_params_grid,
                                                  core_params,
                                                  transformed_train,
                                                  transformed_test=transformed_test,
                                                  target_map=target_map,
                                                  with_sample_weight=with_sample_weight,
                                                  with_class_weight=with_class_weight,
                                                  calibration_method=calibration_method,
                                                  calibration_ratio=calibration_ratio,
                                                  calibrate_on_test=calibrate_on_test,
                                                  monotonic_cst=monotonic_cst)
            else:
                clf, _, _, _ = regression_fit_single(optimized_params_grid, core_params, transformed_train,
                                                     with_sample_weight=with_sample_weight,
                                                     monotonic_cst=monotonic_cst)

        model = ScorableModel.build(clf, model_type, prediction_type, modeling_params["algorithm"], preprocessing_params, ml_overrides_params)

        with listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            # We don't compute the intrinsic score nor serialize the model as we already did it for the
            # actual model, trained on the full dataset during `do_full_fit_and_save`.
            if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
                scorer = binary_classification_scorer_with_valid(modeling_params, model, transformed_test, exec_folder_context, test_df_index, target_map=target_map, with_sample_weight=with_sample_weight)
            elif prediction_type == doctor_constants.MULTICLASS:
                scorer = multiclass_scorer_with_valid(modeling_params, model, transformed_test, exec_folder_context, test_df_index, target_map=target_map, with_sample_weight=with_sample_weight)
            elif prediction_type == doctor_constants.REGRESSION:
                scorer = regression_scorer_with_valid(modeling_params, model, transformed_test, exec_folder_context, test_df_index, with_sample_weight)
            scorer.score()
            scorer.save()
            diagnostics.on_scoring_end(scoring_results=DiagnosticsScoringResults.build_from_scorer(prediction_type, scorer),
                                       transformed_test=transformed_test, transformed_train=transformed_train, with_sample_weight=with_sample_weight)

            # Adding assertions/overrides metrics afterwards in order not to mess with existing code
            if overrides_metrics is not None or assertions_metrics is not None:
                perf = exec_folder_context.read_json("perf.json")
                if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
                    # Binary classification stores one metrics entry per cut (threshold)
                    if assertions_metrics is not None:
                        perf["perCutData"]["assertionsMetrics"] = [metrics.to_dict() for metrics in assertions_metrics]
                    if overrides_metrics is not None:
                        perf["perCutData"]["overridesMetrics"] = [metrics.to_dict() for metrics in overrides_metrics]
                elif prediction_type in {doctor_constants.MULTICLASS, doctor_constants.REGRESSION}:
                    if assertions_metrics is not None:
                        perf["metrics"]["assertionsMetrics"] = assertions_metrics.to_dict()
                    if overrides_metrics is not None:
                        perf["metrics"]["overridesMetrics"] = overrides_metrics.to_dict()
                exec_folder_context.write_json("perf.json", perf)
    else:
        raise NotImplementedError("Unknown value for operation_mode: {}".format(operation_mode))
    end = unix_time_millis()

    utils.write_done_traininfo(exec_folder_context, start, start_train, end, listener.to_jsonifiable())


if __name__ == "__main__":
    # Root logging is configured by the module-level logging.basicConfig call executed at import
    # time. A second basicConfig call here would be a no-op (basicConfig does nothing once the
    # root logger has handlers), so none is made.
    read_dku_env_and_set()

    # Expected invocation: <script> <exec_folder> <operation_mode>
    with ErrorMonitoringWrapper():
        main(sys.argv[1], sys.argv[2])
