# encoding: utf-8

"""
Execute an evaluation recipe in PyRegular mode
Must be called in a Flow environment
"""
import logging
import sys

import pandas as pd

from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.base.utils import safe_exception
from dataiku.core import debugging
from dataiku.core import dkujson as dkujson
from dataiku.core import doctor_constants
from dataiku.core import schema_handling
from dataiku.doctor.evaluation.base import EvaluateRecipe, compute_custom_evaluation_metrics_df
from dataiku.doctor.evaluation.base import add_evaluation_columns
from dataiku.doctor.evaluation.base import compute_metrics_df
from dataiku.doctor.evaluation.base import handle_percentiles_and_cond_outputs
from dataiku.doctor.evaluation.base import load_input_dataframe
from dataiku.doctor.evaluation.base import process_input_df_skip_predict
from dataiku.doctor.evaluation.base import run_binary_scoring
from dataiku.doctor.evaluation.base import run_multiclass_scoring
from dataiku.doctor.evaluation.base import run_regression_scoring
from dataiku.doctor.evaluation.base import sample_and_store_dataframe
from dataiku.doctor.prediction.classification_scoring import binary_classification_predict
from dataiku.doctor.prediction.classification_scoring import compute_assertions_for_decision
from dataiku.doctor.prediction.classification_scoring import multiclass_predict
from dataiku.doctor.prediction.classification_scoring import save_classification_statistics
from dataiku.doctor.prediction.overrides.ml_overrides_params import OVERRIDE_INFO_COL
from dataiku.doctor.prediction.overrides.ml_overrides_results import OverridesResultsMixin
from dataiku.doctor.prediction.reg_scoring_recipe import load_model_partitions
from dataiku.doctor.prediction.regression_scoring import compute_assertions_for_regression
from dataiku.doctor.prediction.regression_scoring import regression_predict
from dataiku.doctor.prediction.regression_scoring import save_regression_statistics
from dataiku.doctor.preprocessing.assertions import MLAssertion
from dataiku.doctor.preprocessing.assertions import MLAssertions
from dataiku.doctor.preprocessing.assertions import cast_assertions_masks_bool
from dataiku.doctor.utils.gpu_execution import get_gpu_config_from_recipe_desc, log_nvidia_smi_if_use_gpu
from dataiku.doctor.utils.scoring_recipe_utils import generate_part_df_and_model_params
from dataiku.doctor.utils.scoring_recipe_utils import get_input_parameters
from dataikuscoring.utils.prediction_result import AbstractPredictionResult

logger = logging.getLogger(__name__)


class RegularEvaluateRecipe(EvaluateRecipe):
    """
    Evaluation recipe runner for "regular" Python prediction models.

    The underscore-prefixed methods defined here fetch the input dataset and the model partitions, load and
    sample the input dataframe, produce the prediction/output dataframes and build the metrics dataframes.
    They are presumably driven by the ``EvaluateRecipe`` base class, which is not visible in this file.
    """

    def __init__(self, model_folder, input_dataset_smartname, managed_folder_smart_id, output_dataset_smartname, metrics_dataset_smartname,
                 recipe_desc, script, preparation_output_schema, cond_outputs=None, preprocessing_params=None,
                 evaluation_store_folder=None, evaluation_dataset_type=None, api_node_logs_config=None,
                 diagnostics_folder=None, fmi=None):
        """
        All arguments are forwarded unchanged to ``EvaluateRecipe.__init__``.

        :param str model_folder: Path to the folder of the saved model.
        :param str input_dataset_smartname: Smart name of the input dataset to evaluate.
        :param str managed_folder_smart_id: Smart ID of the managed folder storing images.
        :param str output_dataset_smartname: Smart name of the output dataset.
        :param str metrics_dataset_smartname: Smart name of the metrics dataset.
        :param dict recipe_desc: Parameters of the recipe (EvaluationRecipePayloadParams).
        :param dict script: The preparation script (SerializedShakerScript).
        :param dict preparation_output_schema: Output schema of the dataset (Schema).
        :param dict cond_outputs: Optional. Conditional outputs (List<ConditionalOutput>).
        :param dict preprocessing_params: Preprocessing parameters (ResolvedPreprocessingParams).
        :param str evaluation_store_folder: Optional. Path to the Model Evaluation Store folder.
        :param str evaluation_dataset_type: Type of the evaluated dataset (e.g., 'CLASSIC').
        :param dict api_node_logs_config: Configuration for API node logging.
        :param str diagnostics_folder: Optional. Path to a folder where diagnostic files are stored.
        :param str fmi: Optional. Full Model ID (FMI).
        """
        super(RegularEvaluateRecipe, self).__init__(
            model_folder, input_dataset_smartname, managed_folder_smart_id, output_dataset_smartname,
            metrics_dataset_smartname, recipe_desc, script, preparation_output_schema,
            cond_outputs, preprocessing_params, evaluation_store_folder,
            evaluation_dataset_type, api_node_logs_config, diagnostics_folder, fmi)

        # Filled by _fetch_input_dataset_and_model_params
        self.input_dataset = None
        self.core_params = None
        self.columns = None
        self.dtypes = None
        self.parse_date_columns = None
        self.with_sample_weight = None
        self.partition_dispatch = None
        self.partitions = None
        self.target_mapping = None

        # Filled by _compute_output_and_pred_df
        self.y = None
        self.unprocessed = None
        self.sample_weight = None
        self.modeling_params = None
        self.assertions = None
        self.overrides_metrics = None

    def _fetch_input_dataset_and_model_params(self):
        """
        Resolve the input dataset, the core params of the model and its partitions, and store them on ``self``.
        """
        input_parameters = get_input_parameters(self.model_folder_context, self.input_dataset_smartname,
                                                self.preparation_output_schema, self.script,
                                                self.managed_folder_smart_id)
        (self.input_dataset, self.core_params, self.feature_preproc,
         self.columns, self.dtypes, self.parse_date_columns) = input_parameters

        self.prediction_type = self.core_params['prediction_type']
        # Both target attributes hold the target variable declared in the core params
        self.target_column_in_dataset = self.model_target_column = self.core_params.get("target_variable")

        weight_method = self.core_params.get("weight", {}).get("weightMethod")
        self.with_sample_weight = weight_method in {"SAMPLE_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"}

        loaded_partitions = load_model_partitions(
            self.model_folder_context,
            self.core_params,
            get_gpu_config_from_recipe_desc(self.recipe_desc),
            self.fmi,
            for_eval=not self.dont_compute_performance
        )
        self.partition_dispatch, self.partitions, _ = loaded_partitions

        self.input_dataset.preparation_requested_output_schema = self.preparation_output_schema

        if self.recipe_desc.get('skipScoring', False):
            # When scoring is skipped, the target map is read from the 6th item of the "NP" partition entry
            self.target_mapping = self.partitions["NP"][5].target_map

    def _get_input_df(self, parse_date_columns=None):
        """
        Load the input dataframe, sampled according to the 'selection' of the recipe (full dataset by default).

        Note: the ``parse_date_columns`` argument is not used, the value stored on ``self`` is passed instead.
        """
        sampling = self.recipe_desc.get('selection', {"samplingMethod": "FULL"})
        return load_input_dataframe(
            input_dataset=self.input_dataset,
            sampling=sampling,
            columns=self.columns,
            dtypes=self.dtypes,
            parse_date_columns=self.parse_date_columns,
        )

    def _get_sample_dfs(self, input_df):
        """
        Sample and store ``input_df``, then build the matching sample of the output.

        :return: tuple (copy of the sample taken before assertion masks are cast, sample of the output)
        """
        schema = {"columns": schema_handling.get_schema_from_df(input_df)}
        limit_sampling = self.recipe_desc.get('limitSampling', True)
        sample_input_df = sample_and_store_dataframe(self.model_evaluation_store_folder_context, input_df, schema,
                                                     limit_sampling=limit_sampling)
        # Copy taken before the assertion mask columns are cast on the sample
        sample_input_df_copy_unnormalized = sample_input_df.copy()
        cast_assertions_masks_bool(sample_input_df)

        if self.recipe_desc.get('skipScoring', False):
            # Predictions are read from the prediction/probability columns of the input
            if not self.dont_compute_performance:
                sample_input_df = sample_input_df.dropna(subset=[self.model_target_column])
            sample_pred_df = sample_input_df[[self.prediction_column] + self.proba_columns]
        else:
            # No MES folder is passed here, since only the scoring part is wanted
            sample_pred_df = process_input_df(sample_input_df, self.partition_dispatch, self.core_params,
                                              self.partitions, self.with_sample_weight, self.recipe_desc,
                                              self.cond_outputs,
                                              evaluation_store_folder_context=None,
                                              dont_compute_performance=True)[0]

        # Keep the input columns that are neither prediction columns nor ML assertions mask columns
        predicted_columns = sample_pred_df.columns
        mask_prefix = MLAssertion.ML_ASSERTION_MASK_PREFIX
        clean_kept_columns = []
        for column in sample_input_df_copy_unnormalized.columns:
            if column in predicted_columns or column.startswith(mask_prefix):
                continue
            clean_kept_columns.append(column)

        kept_input_df = sample_input_df_copy_unnormalized[clean_kept_columns]
        sample_output_df = pd.concat([kept_input_df, sample_pred_df], axis=1)

        return sample_input_df_copy_unnormalized, sample_output_df

    def _compute_output_and_pred_df(self, input_df, input_df_copy_unnormalized):
        """
        Compute the prediction dataframe and store the evaluation by-products (target, unprocessed data,
        weights, assertions...) on ``self``.

        :return: tuple (output dataframe built by ``_get_output_from_pred``, prediction dataframe)
        """
        skip_scoring = self.recipe_desc.get('skipScoring', False)

        if skip_scoring:
            np_model_folder_context, _, _, np_pipeline, self.modeling_params, _ = self.partitions["NP"]
            if self.prediction_type == doctor_constants.BINARY_CLASSIFICATION:
                # Force the threshold returned by _get_used_threshold instead of an auto-optimized one
                self.modeling_params["autoOptimizeThreshold"] = False
                self.modeling_params["forcedClassifierThreshold"] = self._get_used_threshold()

            skip_predict_results = process_input_df_skip_predict(
                input_df,
                np_model_folder_context,
                np_pipeline,
                self.modeling_params,
                self.target_mapping,
                self.prediction_type,
                self.prediction_column,
                self.proba_columns,
                self.with_sample_weight,
                self.dont_compute_performance,
                self.recipe_desc,
                self.cond_outputs,
                self.model_evaluation_store_folder_context)
            pred_df, self.y, self.unprocessed, self.sample_weight, self.assertions = skip_predict_results
        else:
            (pred_df, self.y, self.unprocessed, self.sample_weight, self.modeling_params,
             self.target_mapping, self.assertions, self.overrides_metrics) = process_input_df(
                input_df,
                self.partition_dispatch,
                self.core_params,
                self.partitions,
                self.with_sample_weight,
                self.recipe_desc,
                self.cond_outputs,
                self.model_evaluation_store_folder_context,
                self.dont_compute_performance)

        output_df = self._get_output_from_pred(input_df_copy_unnormalized, pred_df)
        return output_df, pred_df

    def _compute_metrics_df(self, output_df, pred_df):
        """
        Build the metrics dataframe, extended column-wise with the assertions metrics (when
        'computeAssertions' is set) and with the overrides metrics (when some were computed).
        """
        recipe_desc = self.recipe_desc
        metrics_df = compute_metrics_df(
            self.core_params["prediction_type"], self.target_mapping, self.modeling_params, output_df,
            recipe_desc.get("metrics"), recipe_desc.get("customMetrics"), self.y, self.unprocessed,
            recipe_desc.get("outputProbabilities"), self.sample_weight,
            treat_metrics_failure_as_error=recipe_desc.get("treatPerfMetricsFailureAsError", True))

        if recipe_desc.get("computeAssertions", False):
            assertions_metrics_df = compute_assertions_df(self.core_params["prediction_type"], pred_df,
                                                          self.assertions, self.target_mapping)
            metrics_df = pd.concat([metrics_df, assertions_metrics_df], axis=1)

        if self.overrides_metrics is not None:
            # Overrides metrics are stored as a JSON string in a single cell
            serialized_overrides_metrics = dkujson.dumps(self.overrides_metrics)
            overrides_metrics_df = pd.DataFrame({"overridesMetrics": [serialized_overrides_metrics]})
            metrics_df = pd.concat([metrics_df, overrides_metrics_df], axis=1)

        return metrics_df

    def _compute_custom_evaluation_metrics_df(self, output_df, pred_df, unprocessed_input_df, ref_sample_df):
        """
        Build the dataframe of the custom evaluation metrics declared in the recipe ('customEvaluationMetrics').

        Note: ``pred_df`` is not used by this implementation.
        """
        recipe_desc = self.recipe_desc
        return compute_custom_evaluation_metrics_df(
            output_df,
            recipe_desc.get("customEvaluationMetrics"),
            self.prediction_type,
            self.target_mapping,
            self.y,
            unprocessed_input_df,
            ref_sample_df,
            recipe_desc.get("outputProbabilities"),
            self.sample_weight,
            treat_metrics_failure_as_error=recipe_desc.get("treatPerfMetricsFailureAsError", True))

def process_input_df(input_df, partition_dispatch, core_params, partitions, with_sample_weight, recipe_desc,
                     cond_outputs, evaluation_store_folder_context, dont_compute_performance):
    """
    Score ``input_df`` partition by partition and gather what is needed to evaluate the predictions.

    Each (partition dataframe, model parameters) pair yielded by ``generate_part_df_and_model_params`` is run
    through its pipeline and scored according to ``core_params["prediction_type"]``. When
    ``evaluation_store_folder_context`` is not None, either the prediction statistics are saved (performance
    not computed) or the full scoring is run (performance computed). The per-partition results are then
    re-assembled.

    :param input_df: dataframe to score
    :param partition_dispatch: truthy when the per-partition results must be concatenated
    :param dict core_params: core params of the model, read for "prediction_type"
    :param partitions: model partitions, forwarded to ``generate_part_df_and_model_params``
    :param with_sample_weight: whether the "weight" part of the transformed data is used
    :param dict recipe_desc: parameters of the recipe
    :param cond_outputs: conditional outputs, only used for binary classification
    :param evaluation_store_folder_context: where statistics/scoring results are written, or None to skip it
    :param dont_compute_performance: when truthy, no target nor weight is returned and no scoring is run
    :return: tuple (pred_df, y, unprocessed, sample_weight, modeling_params, target_mapping, assertions,
             overrides_metrics). ``modeling_params`` is the one of the last processed partition.
    :raises ValueError: when the prediction type is not binary, multiclass or regression
    :raises Exception: when partitions are dispatched but none of them could be scored
    """
    compute_performance = not dont_compute_performance
    store_results = evaluation_store_folder_context is not None

    pred_parts = []
    target_parts = []
    weight_parts = []
    unprocessed_parts = []
    assertions_parts = []
    prediction_result_parts = []

    for part_df, part_params, _ in generate_part_df_and_model_params(input_df, partition_dispatch, core_params,
                                                                     partitions, raise_if_not_found=False):

        model_folder_context, preprocessing_params, model, pipeline, modeling_params, preprocessing_handler = part_params

        logger.info("Processing it")
        transformed = pipeline.process(part_df)
        logger.info("Predicting it")
        fail_on_metrics_error = recipe_desc.get("treatPerfMetricsFailureAsError", True)
        prediction_type = core_params["prediction_type"]

        if prediction_type == doctor_constants.BINARY_CLASSIFICATION:

            # Threshold: either forced by the recipe, or the active one saved with the model
            if recipe_desc["overrideModelSpecifiedThreshold"]:
                threshold = recipe_desc.get("forcedClassifierThreshold")
            else:
                user_meta = model_folder_context.read_json("user_meta.json")
                threshold = user_meta.get("activeClassifierThreshold")

            # For ensemble model, we need to indicate that we have target, so that a target-aware pipeline is
            # selected. See 0c87605 for more information
            scoring_data = binary_classification_predict(
                model,
                pipeline,
                modeling_params,
                preprocessing_handler.target_map,
                threshold,
                part_df,
                output_probas=recipe_desc["outputProbabilities"],
                ensemble_has_target=compute_performance)

            part_pred_df = scoring_data.pred_and_proba_df

            if store_results and not compute_performance:  # Only save statistics if we don't run the scoring
                _save_part_classification_statistics(scoring_data, preprocessing_handler,
                                                     evaluation_store_folder_context)

            # Probability percentile & Conditional outputs
            handle_percentiles_and_cond_outputs(part_pred_df, recipe_desc, cond_outputs, model_folder_context,
                                                preprocessing_handler.target_map)

            if store_results and compute_performance:
                # The scoring must use the very threshold used for the predictions
                modeling_params["autoOptimizeThreshold"] = False
                modeling_params["forcedClassifierThreshold"] = threshold
                run_binary_scoring(modeling_params,
                                   scoring_data.decisions_and_cuts,
                                   transformed["target"].astype(int),
                                   preprocessing_handler.target_map,
                                   transformed["weight"] if with_sample_weight else None,
                                   evaluation_store_folder_context,
                                   assertions=transformed.get("assertions", None),
                                   test_unprocessed=transformed["UNPROCESSED"],
                                   test_X=transformed["TRAIN"],
                                   treat_metrics_failure_as_error=fail_on_metrics_error)

        elif prediction_type == doctor_constants.MULTICLASS:
            # For ensemble model, we need to indicate that we have target, so that a target-aware pipeline is
            # selected. See 0c87605 for more information.
            scoring_data = multiclass_predict(model, pipeline, modeling_params, preprocessing_handler.target_map,
                                              part_df, output_probas=recipe_desc["outputProbabilities"],
                                              ensemble_has_target=compute_performance)

            part_pred_df = scoring_data.pred_and_proba_df

            if store_results and not compute_performance:  # Only save statistics if we don't run the scoring
                _save_part_classification_statistics(scoring_data, preprocessing_handler,
                                                     evaluation_store_folder_context)

            if store_results and compute_performance:
                run_multiclass_scoring(modeling_params, scoring_data.prediction_result,
                                       transformed["target"].astype(int),
                                       preprocessing_handler.target_map,
                                       transformed["weight"] if with_sample_weight else None,
                                       evaluation_store_folder_context,
                                       assertions=transformed.get("assertions", None),
                                       test_unprocessed=transformed["UNPROCESSED"],
                                       test_X=transformed["TRAIN"],
                                       treat_metrics_failure_as_error=fail_on_metrics_error)

        elif prediction_type == doctor_constants.REGRESSION:
            # For ensemble model, we need to indicate that we have target, so that a target-aware pipeline is
            # selected. See 0c87605 for more information.
            scoring_data = regression_predict(model, pipeline, modeling_params, part_df,
                                              ensemble_has_target=compute_performance)
            part_pred_df = scoring_data.preds_df

            if store_results and not compute_performance:  # Only save statistics if we don't run the scoring
                save_regression_statistics(part_pred_df.iloc[:, 0], evaluation_store_folder_context)

            if store_results and compute_performance:
                run_regression_scoring(modeling_params, scoring_data.prediction_result, transformed["target"],
                                       transformed["weight"] if with_sample_weight else None,
                                       evaluation_store_folder_context,
                                       assertions=transformed.get("assertions", None),
                                       test_unprocessed=transformed["UNPROCESSED"],
                                       test_X=transformed["TRAIN"],
                                       treat_metrics_failure_as_error=fail_on_metrics_error)
        else:
            raise ValueError("bad prediction type %s" % prediction_type)

        # Accumulate the results of this partition
        pred_parts.append(part_pred_df)
        prediction_result_parts.append(scoring_data.prediction_result)
        if 'target' in transformed:
            target_parts.append(transformed["target"])
        unprocessed_parts.append(transformed["UNPROCESSED"])
        if with_sample_weight and compute_performance:
            weight_parts.append(transformed["weight"])

        part_assertions = transformed.get("assertions", None)
        if part_assertions is not None:
            assertions_parts.append(transformed["assertions"])

    # Re-patch partitions together
    if partition_dispatch:
        if not pred_parts:
            raise Exception("All partitions found in dataset are unknown to the model, cannot evaluate it")

        pred_df = pd.concat(pred_parts, axis=0)
        if compute_performance:
            y = pd.concat(target_parts, axis=0)
            sample_weight = pd.concat(weight_parts, axis=0) if with_sample_weight else None
        else:
            y = None
            sample_weight = None
        unprocessed = pd.concat(unprocessed_parts, axis=0)
        assertions = MLAssertions.concatenate_assertions_list(assertions_parts)
        prediction_result = AbstractPredictionResult.concat(prediction_result_parts)
    else:
        # Without dispatch, only the first (expected to be the only) part is used
        pred_df = pred_parts[0]
        if compute_performance:
            y = target_parts[0]
            sample_weight = weight_parts[0] if with_sample_weight else None
        else:
            y = None
            sample_weight = None
        unprocessed = unprocessed_parts[0]
        assertions = assertions_parts[0] if assertions_parts else None
        prediction_result = prediction_result_parts[0]

    # add error information to pred_df
    target_mapping = {}
    if core_params["prediction_type"] in [doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS]:
        # The target map of the last processed partition is used
        for label, class_id in preprocessing_handler.target_map.items():
            target_mapping[label] = int(class_id)
    if y is not None:
        pred_df = add_evaluation_columns(core_params["prediction_type"], pred_df, y, recipe_desc["outputs"],
                                         target_mapping)

    overrides_metrics = None
    if isinstance(prediction_result, OverridesResultsMixin):
        overrides_metrics = prediction_result.compute_and_return_overrides_metrics().to_dict()
        pred_df[OVERRIDE_INFO_COL] = prediction_result.compute_and_return_info_column()

    aligned_y = prediction_result.align_with_not_declined(y)
    aligned_unprocessed = prediction_result.align_with_not_declined(unprocessed)
    aligned_sample_weight = prediction_result.align_with_not_declined(sample_weight)

    return (pred_df, aligned_y, aligned_unprocessed, aligned_sample_weight, modeling_params, target_mapping,
            assertions, overrides_metrics)


def _save_part_classification_statistics(scoring_data, preprocessing_handler, evaluation_store_folder_context):
    """
    Save the classification statistics of one scored partition (first predictions column, and probabilities
    when available) into ``evaluation_store_folder_context``. No sample weight is used.
    """
    predictions = scoring_data.preds_df.iloc[:, 0]
    probas_df = scoring_data.probas_df
    save_classification_statistics(predictions,
                                   evaluation_store_folder_context,
                                   probas=None if probas_df is None else probas_df.values,
                                   sample_weight=None,
                                   target_map=preprocessing_handler.target_map)


def compute_assertions_df(prediction_type, pred_df, assertions, target_map):
    """
    Evaluate the ML assertions against the "prediction" column of ``pred_df``.

    :param prediction_type: prediction type of the model (binary, multiclass or regression)
    :param pred_df: dataframe holding the predictions in its "prediction" column
    :param assertions: assertions to evaluate, or None when there is none
    :param target_map: mapping applied to the predictions of a classification to match what
                       ``compute_assertions_for_decision`` expects
    :return: dataframe with the "passingAssertionsRatio" and "assertionsMetrics" columns: one row holding the
             ratio and the JSON-serialized metrics, or no row at all when ``assertions`` is None
    :raises: the exception built by ``safe_exception`` from a ValueError when the prediction type is unknown
    """
    if assertions is None:
        logger.info("No assertion provided. skipping computation")
        # An empty dataframe keeps the schema compatible
        empty_columns = ["passingAssertionsRatio", "assertionsMetrics"]
        return pd.DataFrame(columns=empty_columns)

    logger.info("Evaluating {} assertions".format(len(assertions)))
    if prediction_type == doctor_constants.REGRESSION:
        predicted_values = pred_df["prediction"].values
        assertions_metrics = compute_assertions_for_regression(predicted_values, assertions)
    elif prediction_type in {doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS}:
        # pred_df holds the actual predicted classes: map them back to the index of the classes, as
        # required by `compute_assertions_for_decision`
        predicted_class_ids = pred_df["prediction"].map(target_map).values
        assertions_metrics = compute_assertions_for_decision(predicted_class_ids, assertions, target_map)
    else:
        # TODO @deepHub add evaluation recipe
        raise safe_exception(ValueError, u"Unknown prediction type: {}".format(prediction_type))

    logger.info("Done evaluating assertions")

    # Metrics are indexed by their name, hence the name is left out of each serialized metric
    metrics_by_name = {}
    for metric in assertions_metrics:
        metrics_by_name[metric.name] = metric.to_dict(with_name=False)

    return pd.DataFrame({"passingAssertionsRatio": [assertions_metrics.passing_assertions_ratio],
                         "assertionsMetrics": [dkujson.dumps(metrics_by_name)]})


if __name__ == "__main__":
    # Script entry point: every input of the recipe comes from the command line.
    #   argv[1..5]   model folder and smart names/ids, passed as-is
    #   argv[6..10]  paths to JSON files (recipe desc, script, schema, cond outputs, preprocessing params)
    #   argv[11..15] evaluation store folder, dataset type, API node logs config (inline JSON),
    #                diagnostics folder and full model id
    debugging.install_handler()
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
    read_dku_env_and_set()
    with ErrorMonitoringWrapper():
        cli_args = sys.argv
        runner = RegularEvaluateRecipe(
            model_folder=cli_args[1],
            input_dataset_smartname=cli_args[2],
            managed_folder_smart_id=cli_args[3],
            output_dataset_smartname=cli_args[4],
            metrics_dataset_smartname=cli_args[5],
            recipe_desc=dkujson.load_from_filepath(cli_args[6]),
            script=dkujson.load_from_filepath(cli_args[7]),
            preparation_output_schema=dkujson.load_from_filepath(cli_args[8]),
            cond_outputs=dkujson.load_from_filepath(cli_args[9]),
            preprocessing_params=dkujson.load_from_filepath(cli_args[10]),
            evaluation_store_folder=cli_args[11],
            evaluation_dataset_type=cli_args[12],
            api_node_logs_config=dkujson.loads(cli_args[13]),
            diagnostics_folder=cli_args[14],
            fmi=cli_args[15])
        log_nvidia_smi_if_use_gpu(recipe_desc=runner.recipe_desc)
        runner.run()
