import logging
import numpy as np
import pandas as pd

from dataiku.base.utils import safe_unicode_str, get_argspec
from dataiku.doctor.exception import CustomMetricException
from dataiku.doctor.prediction.custom_scoring import check_customscore, python2_friendly_exec
from dataiku.doctor.diagnostics import diagnostics


# Module-level logger named after this module; used by all helpers below.
logger = logging.getLogger(__name__)

def _get_custom_scorefunc(code):
    """
    Build a callable that evaluates a user-written custom metric.

    ``code`` is executed (via ``python2_friendly_exec``) in a fresh namespace and must
    define a name ``score``. NOTE: this executes arbitrary code; ``code`` is expected to
    be the metric author's own definition.

    The returned callable has the signature
    ``(y_valid, preds_or_probas, eval_df, output_df, ref_sample_df=None, sample_weight=None, **kwargs)``.
    It forwards the first five arguments to ``score``, plus ``sample_weight`` when
    ``get_argspec(score)[0]`` contains a parameter of that name, plus ``**kwargs``. It then
    validates the result with ``check_customscore(..., allow_naninf=False)``.

    Exceptions raised by ``score`` are logged and re-raised as ``CustomMetricException``.
    If the first evaluation fails with a ``CustomMetricException`` or ``ValueError`` and
    ``preds_or_probas`` is a ``np.ndarray``, evaluation is retried once with the data
    wrapped in a ``pd.Series``. If that retry also fails, the original error is re-raised.

    :raises ValueError: if ``code`` is empty or does not define ``score``
    """
    if not code:
        raise ValueError("You must write the custom metric code")
    namespace = {}
    python2_friendly_exec(code, namespace, namespace)
    if "score" not in namespace:
        raise ValueError("Custom evaluation function not defined")
    user_score = namespace["score"]

    def _call_user_score(y_valid, preds_or_probas, eval_df, output_df, ref_sample_df=None, sample_weight=None, **kwargs):
        # Run the user's score function once. Anything it raises (including argspec
        # inspection failures) becomes a CustomMetricException. The result check sits
        # outside the try block, so its own errors propagate unchanged.
        call_args = [y_valid, preds_or_probas, eval_df, output_df, ref_sample_df]
        try:
            # Only hand over sample_weight when the user's function declares that parameter.
            if 'sample_weight' in get_argspec(user_score)[0]:
                call_args.append(sample_weight)
            score_value = user_score(*call_args, **kwargs)
        except Exception as err:
            logger.exception("Custom scoring function failed")
            raise CustomMetricException("Custom scoring function failed: %s" % err)
        check_customscore(score_value, allow_naninf=False)
        return score_value

    def type_wrapper(y_valid, preds_or_probas, eval_df, output_df, ref_sample_df=None, sample_weight=None, **kwargs):
        # https://app.shortcut.com/dataiku/story/138579
        # Backwards compatibility for custom metric evaluation: if the metric fails on a
        # np.ndarray, retry once with the same data wrapped in a pd.Series.
        got_ndarray = isinstance(preds_or_probas, np.ndarray)
        try:
            val = _call_user_score(y_valid, preds_or_probas, eval_df, output_df, ref_sample_df, sample_weight, **kwargs)
        except (CustomMetricException, ValueError) as first_error:
            if not got_ndarray:
                raise first_error
            logger.warning("Evaluation failed with preds_or_probas as np.array. Try as pd.Series")
            try:
                val = _call_user_score(y_valid, pd.Series(preds_or_probas), eval_df, output_df, ref_sample_df, sample_weight, **kwargs)
                logger.warning("Evaluation of custom metric recovered from failure with preds_or_probas as pd.Series successful. Please consider updating your metric definition.")
            except Exception as series_error:
                logger.error("Evaluation failed with fallback preds_or_probas as pd.Series: %s" % (series_error))
                # Surface the original failure rather than the fallback's.
                raise first_error
        except TypeError as type_error:
            logger.error("Unexpected type for score: %s" % type_error)
            raise type_error
        return val
    return type_wrapper


def _execute_parsed_custom_metric_function(
        custom_metric_function,
        custom_metric_result,
        valid_y,
        preds_or_probas,
        eval_df,
        output_df,
        ref_sample_df=None,
        sample_weight=None,
        treat_metrics_failure_as_error=True):
    custom_metric = custom_metric_result['metric']
    try:
        res = custom_metric_function(valid_y, preds_or_probas, eval_df, output_df, ref_sample_df, sample_weight)

        custom_metric_result["value"] = res
        custom_metric_result["didSucceed"] = True
    except Exception as e:
        if treat_metrics_failure_as_error:
            raise e
        else:
            custom_metric_result["didSucceed"] = False
            custom_metric_result['error'] = safe_unicode_str(e)
            logger.warning("Custom metric function '{}' failed to execute".format(custom_metric['name']), exc_info=True)

            diagnostics.add_or_update(
                diagnostics.DiagnosticType.ML_DIAGNOSTICS_MODELING_PARAMETERS,
                "Calculation of '{}' failed".format(custom_metric_result['metric']['name'])
            )

    return custom_metric_result


def calculate_custom_evaluation_metrics(
        custom_evaluation_metrics,
        valid_y,
        preds,
        probas,
        eval_df,
        output_df,
        ref_sample_df,
        sample_weight,
        is_classification,
        treat_metrics_failure_as_error):
    """
    Calculates custom metric values at evaluation time. Does not perform any per-cut calculations.

    Each metric definition is parsed and evaluated independently. In a classification
    task, a metric whose ``needsProbability`` flag is truthy receives ``probas``. Every
    other metric, and every metric in a non-classification task, receives ``preds``.

    :param custom_evaluation_metrics: iterable of metric definition dicts
    :param treat_metrics_failure_as_error: forwarded to
        ``_parse_and_compute_custom_evaluation_metric``; controls whether a failing metric
        raises or is recorded as a failed result
    :return: list with one result per metric, in input order
    """
    custom_metrics_results = []
    for custom_metric in custom_evaluation_metrics:
        # Probability-based metrics only apply to classification; fall back to preds otherwise.
        needs_probas = is_classification and custom_metric.get("needsProbability", False)
        preds_or_probas = probas if needs_probas else preds
        custom_metrics_results.append(_parse_and_compute_custom_evaluation_metric(
            custom_metric,
            valid_y,
            preds_or_probas,
            eval_df,
            output_df,
            ref_sample_df,
            sample_weight,
            treat_metrics_failure_as_error))
    return custom_metrics_results

def _parse_and_compute_custom_evaluation_metric(custom_metric, valid_y, preds_or_probas, eval_df, output_df, ref_sample_df, sample_weight, treat_metrics_failure_as_error):
    """
    Parse one custom metric definition and evaluate it.

    :param custom_metric: metric definition dict; ``"metricCode"`` holds the user code and
        ``"name"`` is used in failure messages
    :param treat_metrics_failure_as_error: if True, a parse failure propagates unchanged, with
        its original traceback. If False, it is recorded as ``didSucceed=False`` plus
        ``error``, logged, and reported as a diagnostic.
    :return: result dict with the definition under ``"metric"``. On parse failure it is
        returned immediately; otherwise it is filled by
        ``_execute_parsed_custom_metric_function``.
    """
    custom_metric_result = {
        'metric': custom_metric
    }

    try:
        custom_metric_function = _get_custom_scorefunc(custom_metric["metricCode"])
    except Exception as e:
        if treat_metrics_failure_as_error:
            # Bare raise keeps the original traceback (``raise e`` loses it on Python 2).
            raise
        custom_metric_result["didSucceed"] = False
        custom_metric_result['error'] = safe_unicode_str(e)
        logger.warning("Custom metric function '{}' failed to parse".format(custom_metric['name']), exc_info=True)

        diagnostics.add_or_update(
            diagnostics.DiagnosticType.ML_DIAGNOSTICS_MODELING_PARAMETERS,
            "Calculation of '{}' failed: unable to parse metric code".format(custom_metric['name'])
        )
        return custom_metric_result

    return _execute_parsed_custom_metric_function(
        custom_metric_function,
        custom_metric_result,
        valid_y,
        preds_or_probas,
        eval_df,
        output_df,
        ref_sample_df,
        sample_weight=sample_weight,
        treat_metrics_failure_as_error=treat_metrics_failure_as_error)
