import logging

from dataiku.base.utils import safe_unicode_str
from dataiku.doctor.evaluation.base import is_classification, recreate_preds_from_probas, proba_definitions_from_probas
from dataiku.doctor.exception import CustomMetricException
from dataiku.doctor.prediction.custom_scoring import check_customscore, python2_friendly_exec
from dataiku.doctor.diagnostics import diagnostics


logger = logging.getLogger(__name__)

class ModelParametersForCustomMetric(object):
    """
    Model settings exposed to user-written custom metric functions (their
    ``model_parameters`` argument), built from the evaluation recipe description.

    The class-level attributes below are the defaults seen when the recipe has no
    model. The mutable ones (lists / dict) are re-created per instance in
    ``__init__`` so that two instances never share, and never mutate, the same
    container.
    """
    has_model = True
    prediction_type = None
    prediction_column_name = None
    dont_compute_performance = False
    target_column_name = None
    weight_column_name = None
    user_defined_classes = []
    is_proba_aware = None
    proba_definition = [] # list of {'key': something, 'value': something_else}
    _fake_modeling_params = {}

    def __init__(self, prediction_type, recipe_desc, fake_modeling_params):
        """
        :param prediction_type: prediction type of the evaluated model
        :param recipe_desc: recipe description dict; reads 'hasModel', 'predictionVariable',
            'isProbaAware', 'probas', 'classes', 'dontComputePerformance',
            'targetVariable' and 'weightsVariable'
        :param fake_modeling_params: forwarded to recreate_preds_from_probas; only
            stored when the recipe has a model
        """
        self.prediction_type = prediction_type
        self.has_model = recipe_desc.get('hasModel', True)
        # Fresh containers per instance: the class-level defaults must never be shared
        # (they would otherwise be when the recipe has no model).
        self.user_defined_classes = []
        self.proba_definition = []
        self._fake_modeling_params = {}
        if not self.has_model:
            return

        self._fake_modeling_params = fake_modeling_params
        self.prediction_column_name = recipe_desc.get('predictionVariable', None)
        # Only classification predictions can be proba aware; `and` short-circuits, so
        # is_classification is not called when the recipe is not flagged proba aware.
        self.is_proba_aware = recipe_desc.get('isProbaAware', False) and is_classification(self.prediction_type)
        if self.is_proba_aware:
            self.proba_definition = proba_definitions_from_probas(recipe_desc.get('probas', []))
            self.user_defined_classes = [kv['key'] for kv in self.proba_definition]
        else:
            self.user_defined_classes = recipe_desc.get('classes', [])
        self.dont_compute_performance = recipe_desc.get('dontComputePerformance', False)
        if not self.dont_compute_performance:
            self.target_column_name = recipe_desc.get('targetVariable', None)
            self.weight_column_name = recipe_desc.get('weightsVariable', None)

    def get_preds(self, df):
        """
        Get the predictions from a reference or evaluation dataset.

        If the model is proba aware and performance is computed, the predictions are
        recomputed from the probability columns via recreate_preds_from_probas;
        otherwise the prediction column is returned as-is.

        :param df: pandas DataFrame from which to get the predictions
        :return: the predictions, or None when the recipe has no model
        """
        if not self.has_model:
            return None

        if self.is_proba_aware and not self.dont_compute_performance:
            proba_columns = [kv['value'] for kv in self.proba_definition]
            probas = df[proba_columns].values
            return recreate_preds_from_probas(
                probas,
                self.target_column_name,
                df[self.target_column_name],
                self._get_proba_target_map(),
                self.prediction_type,
                self._fake_modeling_params)
        else:
            return df[self.prediction_column_name]

    def _get_proba_target_map(self):
        """Map each class key of the proba definition to its position (column index in the probas)."""
        return {kv['key']: idx for idx, kv in enumerate(self.proba_definition)}


def calculate_custom_standalone_evaluation_metrics(
        custom_evaluation_metrics,
        eval_df,
        ref_df,
        model_parameters,
        treat_metrics_failure_as_error):
    """
    Compute every custom metric once over the whole evaluation; no per-cut computation is done.

    :param custom_evaluation_metrics: iterable of custom metric definitions
    :param eval_df: evaluation dataframe passed to each metric
    :param ref_df: reference dataframe passed to each metric
    :param model_parameters: model description passed to each metric
    :param treat_metrics_failure_as_error: if True, the first failing metric raises;
        otherwise failures are recorded in the individual results
    :return: list with one result dict per metric definition, in input order
    """
    return [
        _parse_and_compute_custom_evaluation_metric(
            metric_definition,
            eval_df,
            ref_df,
            model_parameters,
            treat_metrics_failure_as_error)
        for metric_definition in custom_evaluation_metrics
    ]

def _parse_and_compute_custom_evaluation_metric(custom_metric, eval_df, ref_df, model_parameters, treat_metrics_failure_as_error):
    """
    Compile one custom metric's code, then evaluate it.

    :param custom_metric: metric definition dict; 'metricCode' holds the source of the
        metric and 'name' is used in log and diagnostic messages
    :param eval_df: evaluation dataframe forwarded to the metric
    :param ref_df: reference dataframe forwarded to the metric
    :param model_parameters: model description forwarded to the metric
    :param treat_metrics_failure_as_error: if True, a parse failure is re-raised unchanged;
        otherwise it is recorded in the result and reported as a diagnostic
    :return: result dict with the definition under 'metric'; on parse failure it also holds
        'didSucceed': False and 'error', otherwise it is the dict returned by
        _execute_parsed_custom_metric_function
    """
    custom_metric_result = {
        'metric': custom_metric
    }

    try:
        custom_metric_function = _get_custom_scorefunc(custom_metric["metricCode"])
    except Exception as e:
        if treat_metrics_failure_as_error:
            # Bare `raise` keeps the original traceback; `raise e` would reset it under Python 2.
            raise
        custom_metric_result["didSucceed"] = False
        custom_metric_result['error'] = safe_unicode_str(e)
        logger.warning("Custom metric function '{}' failed to parse".format(custom_metric['name']), exc_info=True)

        diagnostics.add_or_update(
            diagnostics.DiagnosticType.ML_DIAGNOSTICS_MODELING_PARAMETERS,
            "Calculation of '{}' failed: unable to parse metric code".format(custom_metric['name'])
        )

        return custom_metric_result

    return _execute_parsed_custom_metric_function(
        custom_metric_function,
        custom_metric_result,
        eval_df,
        ref_df,
        model_parameters,
        treat_metrics_failure_as_error=treat_metrics_failure_as_error)


def _get_custom_scorefunc(code):
    """
    Execute user-supplied custom metric code and return a guarded wrapper around its ``score`` function.

    The wrapper calls ``score(eval_df, ref_df, model_parameters, **kwargs)`` and validates the
    returned value with ``check_customscore(val, allow_naninf=False)``.

    :param code: Python source defining a top-level ``score`` callable
    :return: the wrapper; it re-raises TypeError unchanged and wraps any other failure
        into CustomMetricException
    :raises ValueError: if ``code`` is empty, or if it does not define ``score``
    """
    dic = {}
    if not code:
        raise ValueError("You must write the custom metric code")
    # NOTE(security): this deliberately executes user-authored code with the
    # process's privileges; it must only ever receive code from trusted project users.
    python2_friendly_exec(code, dic, dic)
    if "score" not in dic:
        raise ValueError("Custom evaluation function not defined")
    fn = dic["score"]

    def try_wrapper(eval_df, ref_df, model_parameters, **kwargs):
        try:
            val = fn(eval_df, ref_df, model_parameters, **kwargs)
            check_customscore(val, allow_naninf=False)
        except TypeError as e:
            logger.error("Unexpected type for score: %s", e)
            # Bare `raise` keeps the user's traceback; `raise e` would reset it under Python 2.
            raise
        except Exception as e:
            logger.exception("Custom scoring function failed")
            raise CustomMetricException("Custom scoring function failed: %s" % e)
        return val
    return try_wrapper


def _execute_parsed_custom_metric_function(
        custom_metric_function,
        custom_metric_result,
        eval_df,
        ref_df,
        model_parameters,
        treat_metrics_failure_as_error=True):
    """
    Run an already-compiled custom metric function and record its outcome.

    :param custom_metric_function: callable invoked as f(eval_df, ref_df, model_parameters)
    :param custom_metric_result: result dict holding the metric definition under 'metric'
        (whose 'name' is used in messages); updated in place and returned
    :param eval_df: evaluation dataframe forwarded to the function
    :param ref_df: reference dataframe forwarded to the function
    :param model_parameters: model description forwarded to the function
    :param treat_metrics_failure_as_error: if True, a failure is re-raised unchanged; otherwise
        it is recorded ('didSucceed': False, 'error') and reported as a diagnostic
    :return: custom_metric_result, holding 'value' and 'didSucceed': True on success
    """
    custom_metric = custom_metric_result['metric']
    try:
        res = custom_metric_function(eval_df, ref_df, model_parameters)
        custom_metric_result["value"] = res
        custom_metric_result["didSucceed"] = True
    except Exception as e:
        if treat_metrics_failure_as_error:
            # Bare `raise` keeps the metric's traceback; `raise e` would reset it under Python 2.
            raise
        custom_metric_result["didSucceed"] = False
        custom_metric_result['error'] = safe_unicode_str(e)
        logger.warning("Custom metric function '{}' failed to execute".format(custom_metric['name']), exc_info=True)

        diagnostics.add_or_update(
            diagnostics.DiagnosticType.ML_DIAGNOSTICS_MODELING_PARAMETERS,
            "Calculation of '{}' failed".format(custom_metric['name'])
        )

    return custom_metric_result
