# encoding: utf-8

"""
Execute an evaluation recipe in PyRegular mode
Must be called in a Flow environment
"""
from os import environ
import logging
import sys

import pandas as pd

from dataiku.doctor.evaluation.base import EvaluateRecipe, load_input_dataframe, sample_and_store_dataframe, \
    compute_metrics_df, process_input_df_skip_predict, compute_custom_evaluation_metrics_df
from dataiku.base.utils import ErrorMonitoringWrapper

from dataiku.core import debugging, doctor_constants, schema_handling
from dataiku.core import dkujson as dkujson
from dataiku import Dataset

from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.doctor.preprocessing.dataframe_preprocessing import PreprocessingPipeline, PreprocessingResult
from dataiku.doctor.utils.api_logs import SAGEMAKER_EVALUATION_DATASET_TYPE, API_NODE_EVALUATION_DATASET_TYPE, CLOUD_API_NODE_EVALUATION_DATASET_TYPE, \
    CLOUD_API_NODE_LOGS_FEATURE_PREFIX
from dataiku.doctor.utils.api_logs import decode_and_normalize_sagemaker_logs
from dataiku.doctor.utils.api_logs import API_NODE_LOGS_FEATURE_PREFIX
from dataiku.external_ml.mlflow.pyfunc_evaluate_common import process_input_df
from dataiku.external_ml.mlflow.pyfunc_read_meta import read_user_meta
from dataiku.external_ml.mlflow.pyfunc_common import load_external_model_meta
from dataiku.modelevaluation.data_types import cast_as_string

logger = logging.getLogger(__name__)


class PyfuncEvaluateRecipe(EvaluateRecipe):

    def __init__(self, model_folder, input_dataset_smartname, output_dataset_smartname, metrics_dataset_smartname,
                 recipe_desc,
                 cond_outputs=None, preprocessing_params=None, model_evaluation_store_folder=None,
                 evaluation_dataset_type=None,
                 api_node_logs_config=None, diagnostics_folder=None, fmi=None, proxy=None):
        """
        Build the evaluation recipe runner for an MLflow (pyfunc) model.

        Most arguments are forwarded, positionally, to the parent ``EvaluateRecipe`` constructor. Three parent
        slots are filled with ``None`` because this recipe has no value for them.
        NOTE(review): judging from the comments in the ``__main__`` section of this file, those slots are
        presumably the managed folder smart id, the script and the preparation output schema - confirm against
        the parent signature.

        :param proxy: when truthy, stored in the ``_PROXY_MODEL_PROXY`` environment variable of this process
        """
        parent_positional_args = (
            model_folder,
            input_dataset_smartname,
            None,  # slot not used by this recipe
            output_dataset_smartname,
            metrics_dataset_smartname,
            recipe_desc,
            None,  # slot not used by this recipe
            None,  # slot not used by this recipe
            cond_outputs,
            preprocessing_params,
            model_evaluation_store_folder,
            evaluation_dataset_type,
            api_node_logs_config,
            diagnostics_folder,
            fmi,
        )
        super(PyfuncEvaluateRecipe, self).__init__(*parent_positional_args)

        # The metrics dataset schema is inferred, and the model metadata is only loaded later on
        self.infer_metrics_dataset_schema = True
        self.mlflow_imported_model = None

        if proxy:
            environ["_PROXY_MODEL_PROXY"] = proxy

    def _fetch_input_dataset_and_model_params(self):
        """
        Load the input dataset handle, the imported MLflow model metadata, the MLflow model itself and the
        model-level parameters (target column, prediction type, dtypes, modeling params, target mapping) that
        the rest of the recipe reads from ``self``.

        :raises ValueError: if the imported model metadata has no "targetColumnName" entry
        """
        import mlflow

        self.input_dataset = Dataset(self.input_dataset_smartname)
        self.input_dataset.preparation_requested_output_schema = {
            "columns": self.input_dataset.read_schema()}  # No preparation for mlflow models

        self.mlflow_imported_model = load_external_model_meta(self.model_folder_context)

        # This check must run BEFORE the first "targetColumnName" lookup: otherwise a missing key surfaces as a
        # bare KeyError and this explicit error message can never be reached.
        if "targetColumnName" not in self.mlflow_imported_model:
            raise ValueError("Model has no target column name configured")

        target_column_name = self.mlflow_imported_model["targetColumnName"]
        self.target_column_in_dataset = target_column_name
        self.model_target_column = target_column_name

        self.prediction_type = self.mlflow_imported_model["predictionType"]
        is_classification = self.prediction_type in [doctor_constants.BINARY_CLASSIFICATION,
                                                     doctor_constants.MULTICLASS]

        (self.columns, _, self.parse_date_columns) = Dataset.get_dataframe_schema_st(self.input_dataset.read_schema(),
                                                                                     parse_dates=True,
                                                                                     infer_with_pandas=False)

        # Only the target dtype is forced: object for classification targets, float otherwise
        self.dtypes = {self.target_column_in_dataset: object if is_classification else float}

        # MLflow models cannot be partitioned, but we still keep a bit of the partition formalism for more
        # uniformity with DSS models
        with self.model_folder_context.get_folder_path_to_read() as folder_path:
            self.mlflow_model = mlflow.pyfunc.load_model(folder_path)
        self.partition_dispatch = False
        self.with_sample_weight = False
        self.modeling_params = {
            "metrics": self.mlflow_imported_model["metricParams"],
            # Force to False in MLflow models
            "autoOptimizeThreshold": False,
            # Init to default value in the recipe
            "forcedClassifierThreshold": read_user_meta(self.model_folder_context).get("activeClassifierThreshold")
        }

        if is_classification:
            # Class ids are cast to int, labels are kept as stored in the model metadata
            self.target_mapping = {
                label: int(class_id)
                for label, class_id in self.mlflow_imported_model["labelToIntMap"].items()
            }

    def _get_input_df(self):
        """
        Load the input dataset as a dataframe, using the sampling configured under the recipe's 'selection' key
        (full dataset when that key is absent).

        When the evaluation dataset type is the SageMaker one, the loaded dataframe is passed through
        ``decode_and_normalize_sagemaker_logs`` and that result is returned instead.
        """
        sampling_params = self.recipe_desc.get('selection', {"samplingMethod": "FULL"})
        loaded_df = load_input_dataframe(
            input_dataset=self.input_dataset,
            sampling=sampling_params,
            columns=self.columns,
            dtypes=self.dtypes,
            parse_date_columns=self.parse_date_columns,
        )

        if self.evaluation_dataset_type == SAGEMAKER_EVALUATION_DATASET_TYPE:
            return decode_and_normalize_sagemaker_logs(loaded_df, self.mlflow_imported_model, self.prediction_type,
                                                       self.target_column_in_dataset)

        return loaded_df

    def _get_sample_dfs(self, input_df):
        """
        Sample (and store) the input dataframe, then build the matching sample output dataframe.

        With 'skipScoring' enabled in the recipe, predictions are read from the prediction / probability columns
        already present in the sample; otherwise the sample is scored through ``process_input_df``.

        :return: tuple (sample input dataframe, sample output dataframe). In skip-scoring mode with performance
                 computation enabled, the returned sample input has its rows without target dropped.
        """
        sample_schema = {"columns": schema_handling.get_schema_from_df(input_df)}
        stored_sample_df = sample_and_store_dataframe(
            self.model_evaluation_store_folder_context,
            input_df,
            sample_schema,
            limit_sampling=self.recipe_desc.get('limitSampling', True),
        )

        # Snapshot taken before any row is dropped: it is the base of the output dataframe
        raw_sample_df = stored_sample_df.copy()

        if not self.recipe_desc.get('skipScoring', False):
            scoring_result = process_input_df(
                self._partitions_generator(stored_sample_df), self.mlflow_imported_model, self.partition_dispatch,
                self.modeling_params, self.with_sample_weight,
                self.recipe_desc, self.cond_outputs, self.model_evaluation_store_folder_context, True)
            # Only the prediction dataframe is of interest here
            sample_pred_df, _, _, _, _, _ = scoring_result
        else:
            if not self.dont_compute_performance:
                stored_sample_df = stored_sample_df.dropna(subset=[self.model_target_column])
            prediction_related_columns = [self.prediction_column] + self.proba_columns
            sample_pred_df = stored_sample_df[prediction_related_columns]

        # Columns also present in the predictions are taken from the predictions, not from the raw sample
        predicted_column_names = sample_pred_df.columns
        untouched_columns = [name for name in raw_sample_df.columns if name not in predicted_column_names]
        sample_output_df = raw_sample_df[untouched_columns].merge(
            sample_pred_df, left_index=True, right_index=True, how='outer')

        return stored_sample_df, sample_output_df

    def _compute_output_and_pred_df(self, input_df, input_df_copy_unnormalized):
        """
        Compute the prediction dataframe and the output dataframe of the recipe.

        Side effects: sets ``self.y``, ``self.unprocessed`` and ``self.sample_weight``; when scoring is actually
        run, ``self.modeling_params`` is also replaced by the value returned by ``process_input_df``.

        :param input_df: dataframe to score, or to read existing predictions from when 'skipScoring' is set
        :param input_df_copy_unnormalized: copy of the input handed to ``_get_output_from_pred`` to build the output
        :return: tuple (output dataframe, prediction dataframe). For the SageMaker evaluation dataset type the
                 output dataframe is ``input_df`` itself.
        """
        if not self.recipe_desc.get('skipScoring', False):
            scoring_result = process_input_df(
                self._partitions_generator(input_df), self.mlflow_imported_model, self.partition_dispatch,
                self.modeling_params, self.with_sample_weight,
                self.recipe_desc, self.cond_outputs, self.model_evaluation_store_folder_context, self.dont_compute_performance)
            pred_df, self.y, self.unprocessed, self.sample_weight, self.modeling_params, _ = scoring_result
        else:
            # Recreate pipeline for mlflow to reuse the common method
            class PipelineMlflow(PreprocessingPipeline):
                # Minimal pipeline: no real preprocessing, it only hands back the (optionally filtered) input
                # together with a target that was computed beforehand.
                target = None
                target_column_name = None
                dont_compute_performance = None

                def __init__(self, steps, target_column_name, dont_compute_performance):
                    super(PipelineMlflow, self).__init__(steps)
                    self.target_column_name = target_column_name
                    self.dont_compute_performance = dont_compute_performance

                def set_target(self, target):
                    self.target = target

                def process(self, input_df, retain=None):
                    # Rows without target are only dropped when performance has to be computed
                    if self.dont_compute_performance:
                        kept_rows = input_df
                    else:
                        kept_rows = input_df.dropna(subset=[self.target_column_name])

                    result = PreprocessingResult()
                    result["UNPROCESSED"] = kept_rows
                    result["TRAIN"] = None
                    result["assertions"] = None
                    result["weight"] = None
                    result["target"] = self.target
                    return result

                def fit_and_process(self, input_df, *args, **kwargs):
                    raise NotImplementedError

                def report_fit(self, ret_obj, core_params):
                    raise NotImplementedError

            passthrough_pipeline = PipelineMlflow(steps=[], target_column_name=self.model_target_column,
                                                  dont_compute_performance=self.dont_compute_performance)
            if not self.dont_compute_performance:
                target_values = input_df[self.model_target_column].dropna()
                if self.prediction_type in {doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS}:
                    # Classification targets are mapped from their label to their class id
                    target_values = cast_as_string(target_values).map(self.mlflow_imported_model["labelToIntMap"])
                passthrough_pipeline.set_target(target_values)

            skip_predict_result = process_input_df_skip_predict(
                input_df, self.model_folder_context, passthrough_pipeline, self.modeling_params,
                self.mlflow_imported_model["labelToIntMap"],
                self.prediction_type, self.prediction_column, self.proba_columns, None, self.dont_compute_performance,
                self.recipe_desc, None, self.model_evaluation_store_folder_context)
            pred_df, self.y, self.unprocessed, self.sample_weight, _ = skip_predict_result

        if self.evaluation_dataset_type == SAGEMAKER_EVALUATION_DATASET_TYPE:
            output_df = input_df
        else:
            output_df = self._get_output_from_pred(input_df_copy_unnormalized, pred_df)
        return output_df, pred_df

    def _fix_output_dataset_schema(self, output_dataset, output_df):
        """
        Write the schema of the output dataset, giving priority to the column definitions of the input dataset.

        The output schema of an MLflow model is not known before the recipe runs, so it is derived from the output
        dataframe. Types derived that way are pandas-inferred, which is not wanted for columns that come straight
        from the input dataset: for those, the original input column definition is used instead, and only the
        extra columns keep their inferred definition.
        E.g. a bigint target column of the input dataset is loaded as object in the classification case, and
        would be stored back as string in the output dataset if its original definition was not enforced.

        Best effort: on any failure a warning is logged and ``self.infer_output_dataset_schema`` is set to True.
        """
        try:
            original_columns = {}
            for input_column in self.input_dataset.read_schema():
                original_columns[input_column["name"]] = input_column

            merged_schema = schema_handling.get_schema_from_df(output_df)
            for position, inferred_column in enumerate(merged_schema):
                column_name = inferred_column["name"]
                if column_name in original_columns:
                    merged_schema[position] = original_columns[column_name]

            output_dataset.write_schema(merged_schema)
        except Exception as e:
            logger.warning("Error while trying to write output dataset schema : %s", str(e))
            self.infer_output_dataset_schema = True

    def _resolve_output_probabilities(self, pred_df):
        """
        Tell whether probabilities should be used for metrics computation.

        Starts from the recipe's "outputProbabilities" setting. For classification models it is forced to False
        when at least one expected ``proba_<label>`` column (one per label of "labelToIntMap") is missing from
        the prediction dataframe.

        :param pred_df: prediction dataframe whose columns are inspected
        :return: the recipe setting (possibly None when not configured), or False when forced off
        """
        output_probabilities = self.recipe_desc.get("outputProbabilities")
        if self.mlflow_imported_model["predictionType"] in [doctor_constants.BINARY_CLASSIFICATION,
                                                            doctor_constants.MULTICLASS]:
            proba_columns = ["proba_{}".format(label_value) for label_value in
                             self.mlflow_imported_model["labelToIntMap"]]
            if not pd.Series(proba_columns).isin(pred_df.columns).all():
                logger.info(
                    "Non probabilistic classification model detected, disabling probabilities during metrics "
                    "computation")
                output_probabilities = False
        return output_probabilities

    def _compute_metrics_df(self, output_df, pred_df):
        """
        Compute the performance metrics dataframe of the evaluation.

        :param output_df: output dataframe of the recipe
        :param pred_df: prediction dataframe, only used to detect whether probability columns are available
        :return: the result of ``compute_metrics_df``
        """
        return compute_metrics_df(self.mlflow_imported_model["predictionType"],
                                  self.mlflow_imported_model.get("labelToIntMap"),
                                  self.modeling_params, output_df, self.recipe_desc.get("metrics"),
                                  self.recipe_desc.get("customMetrics"),
                                  self.y, self.unprocessed, self._resolve_output_probabilities(pred_df),
                                  self.sample_weight,
                                  treat_metrics_failure_as_error=self.recipe_desc.get("treatPerfMetricsFailureAsError", True))

    def _compute_custom_evaluation_metrics_df(self, output_df, pred_df, unprocessed_input_df, ref_sample_df):
        """
        Compute the dataframe of the custom evaluation metrics configured under "customEvaluationMetrics".

        :param output_df: output dataframe of the recipe
        :param pred_df: prediction dataframe, only used to detect whether probability columns are available
        :param unprocessed_input_df: forwarded as is to ``compute_custom_evaluation_metrics_df``
        :param ref_sample_df: forwarded as is to ``compute_custom_evaluation_metrics_df``
        :return: the result of ``compute_custom_evaluation_metrics_df``
        """
        metrics_df = compute_custom_evaluation_metrics_df(
            output_df,
            self.recipe_desc.get("customEvaluationMetrics"),
            self.mlflow_imported_model["predictionType"],
            self.mlflow_imported_model.get("labelToIntMap"),
            self.y,
            unprocessed_input_df,
            ref_sample_df,
            self._resolve_output_probabilities(pred_df),
            self.sample_weight,
            treat_metrics_failure_as_error=self.recipe_desc.get("treatPerfMetricsFailureAsError", True))

        return metrics_df

    def _partitions_generator(self, df):
        yield df, (self.model_folder_context, self.mlflow_model)

    def _fix_feature_and_target_dtypes(self):
        """
        Prefix the keys of ``self.dtypes`` and ``self.target_column_in_dataset`` with the logs feature prefix
        matching the evaluation dataset type (API node logs or cloud API node logs). Nothing changes for any
        other evaluation dataset type.
        """
        prefix_per_dataset_type = (
            (API_NODE_EVALUATION_DATASET_TYPE, API_NODE_LOGS_FEATURE_PREFIX),
            (CLOUD_API_NODE_EVALUATION_DATASET_TYPE, CLOUD_API_NODE_LOGS_FEATURE_PREFIX),
        )
        for dataset_type, prefix in prefix_per_dataset_type:
            if self.evaluation_dataset_type == dataset_type:
                self.dtypes = {prefix + column_name: dtype for column_name, dtype in self.dtypes.items()}
                self.target_column_in_dataset = prefix + self.target_column_in_dataset


if __name__ == "__main__":
    # Script entry point: every recipe parameter is read from the command line, by position
    debugging.install_handler()
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
    read_dku_env_and_set()

    with ErrorMonitoringWrapper():
        cli_args = sys.argv

        model_folder = cli_args[1]
        input_dataset_smartname = cli_args[2]
        # There is no 'managedFolderSmartId' parameter in the MLflow case , otherwise here we would have: sys.argv[3]
        output_dataset_smartname = cli_args[4]
        metrics_dataset_smartname = cli_args[5]
        recipe_desc = dkujson.load_from_filepath(cli_args[6])
        # There is no 'script' parameter in the MLflow case, otherwise here we would have: dkujson.load_from_filepath(sys.argv[7]),
        # There is no 'preparation_output_schema' parameter in the MLflow case, otherwise here we would have: dkujson.load_from_filepath(sys.argv[8]),
        cond_outputs = dkujson.load_from_filepath(cli_args[9])
        preprocessing_params = dkujson.load_from_filepath(cli_args[10])
        model_evaluation_store_folder = cli_args[11]
        evaluation_dataset_type = cli_args[12]
        api_node_logs_config = dkujson.loads(cli_args[13])
        diagnostics_folder = cli_args[14]
        fmi = cli_args[15]
        proxy = cli_args[16]

        runner = PyfuncEvaluateRecipe(
            model_folder,
            input_dataset_smartname,
            output_dataset_smartname,
            metrics_dataset_smartname,
            recipe_desc,
            cond_outputs=cond_outputs,
            preprocessing_params=preprocessing_params,
            model_evaluation_store_folder=model_evaluation_store_folder,
            evaluation_dataset_type=evaluation_dataset_type,
            api_node_logs_config=api_node_logs_config,
            diagnostics_folder=diagnostics_folder,
            fmi=fmi,
            proxy=proxy,
        )
        runner.run()
