from copy import deepcopy

import logging
import math
from abc import ABCMeta
from abc import abstractmethod

import numpy as np
from six import add_metaclass
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor

from dataiku.doctor.prediction.linear_coefficients_computation import Denormalizer
from dataiku.doctor.utils.skcompat import dku_recursive_partial_dependence
from dataikuscoring.utils.prediction_result import AbstractPredictionResult

logger = logging.getLogger(__name__)

# Names of the files written to the output folder by PredictionModelScorer.save()
PREDICTED_FILENAME = "predicted.csv"
PERF_FILENAME = "perf.json"
PERF_WITHOUT_OVERRIDES_FILENAME = "perf_without_overrides.json"


class PredictionModelIntrinsicScorer(object):
    """Holds a trained model together with the data it was trained on.

    The helper visible here uses that context to attach the model's raw
    feature importances to a result dict.
    """

    def __init__(self, modeling_params, clf, train_X, train_y, out_folder_context, prepared_X, with_sample_weight):
        # Result dict; starts empty
        self.ret = {}

        # Model and modeling choices
        self.modeling_params = modeling_params
        self.clf = clf

        # Training data (train_X must expose .columns(); see add_raw_feature_importance_if_exists)
        self.train_X = train_X
        self.train_y = train_y
        self.prepared_X = prepared_X
        self.with_sample_weight = with_sample_weight

        self.out_folder_context = out_folder_context

    def add_raw_feature_importance_if_exists(self, clf, ret):
        """Store ``clf``'s normalized feature importances under ``ret["rawImportance"]``.

        Feature names are taken from ``self.train_X.columns()``. ``ret`` is mutated
        in place and left untouched when no importance could be computed.
        """
        logger.info("Computing feature importance")
        importances = compute_variables_importance(self.train_X.columns(), clf)
        if not importances:
            return
        ret["rawImportance"] = importances


@add_metaclass(ABCMeta)
class PredictionModelScorer(object):
    """Abstract base for scorers that evaluate a prediction model's performance.

    Subclasses implement ``_do_score``, which is expected to set ``perf_data``
    (and optionally ``predicted_df``). ``save`` refuses to run until ``perf_data``
    is set, then writes it (and the predictions) to ``out_folder_context``.
    """

    def __init__(self, modeling_params, test_unprocessed=None, test_X=None, assertions=None):
        """
        :param dict modeling_params: modeling choices of the current ML task (see PredictionModelingParams.java in backend)
        :param pandas.DataFrame | None test_unprocessed: the "UNPROCESSED" value returned from processing the test
            dataset via pipeline.process(). Required for the custom metric x_valid parameter.
        :param dataiku.doctor.multiframe.MultiFrame | None test_X: the "TRAIN" value returned from processing the test
            dataset via pipeline.process(). If None, no data will be written on disk (e.g. training recipes).
        :param dataiku.doctor.preprocessing.assertions.MLAssertions | None assertions: collection of assertions based
            on ML performance metrics
        """
        self.modeling_params = modeling_params
        self.test_unprocessed = test_unprocessed
        self.scorer_without_overrides = None
        self.test_prediction_result = None

        self.test_X = test_X
        has_test_X = test_X is not None
        self.test_X_columns = test_X.columns() if has_test_X else None
        self.test_X_index = test_X.index if has_test_X else None

        self.assertions = assertions
        # The pipeline is fitted on the train set, so the processed test columns are the same as
        # the processed train columns; the test set is simply the most convenient source for them.
        self.ret = {"metrics": {}, "processed_feature_names": self.test_X_columns}
        self.out_folder_context = None
        self.predicted_df = None
        self.perf_data = None

    def add_metric(self, measure, value, description=""):
        """Record ``value`` under ``measure`` in ``self.ret["metrics"]``, replacing any previous entry."""
        entry = dict(measure=measure, value=value, description=description)
        self.ret["metrics"][measure] = entry

    def assert_score_called(self):
        """Fail with an AssertionError while ``perf_data`` has not been produced yet."""
        assert self.perf_data is not None, "Scoring data is not yet available"

    def score(self, with_assertions=True, treat_metrics_failure_as_error=True):
        """Score the pre-override predictions first (if a dedicated scorer is set), then this scorer's own.

        :return: whatever ``_do_score`` returns
        """
        pre_overrides = self.scorer_without_overrides
        if pre_overrides is not None:
            logger.info("Start scoring predictions prior to being overridden")
            pre_overrides.score(with_assertions, treat_metrics_failure_as_error)
            logger.info("Done scoring pre-overrides predictions, will now score overridden ones")
        return self._do_score(with_assertions, treat_metrics_failure_as_error)

    def save(self, dump_predicted=True):
        """Write predictions (optional), performance data and pre-override performance to the output folder.

        :param bool dump_predicted: whether to write ``predicted_df`` (tab-separated) when it is available
        """
        self.assert_score_called()
        folder = self.out_folder_context
        assert folder is not None, "Missing output folder to save predictions and performance"

        should_dump = dump_predicted and self.predicted_df is not None
        if should_dump:
            with folder.get_file_path_to_write(PREDICTED_FILENAME) as predicted_file_path:
                self.predicted_df.to_csv(predicted_file_path, sep="\t", header=True, index=False, encoding='utf-8')

        folder.write_json(PERF_FILENAME, self.perf_data)

        pre_overrides = self.scorer_without_overrides
        if pre_overrides is not None:
            folder.write_json(PERF_WITHOUT_OVERRIDES_FILENAME, pre_overrides.perf_data)

    @abstractmethod
    def _do_score(self, with_assertions, treat_metrics_failure_as_error=True):
        """Compute the actual scores; implemented by subclasses."""
        pass


@add_metaclass(ABCMeta)
class ClassicalPredictionModelScorer(PredictionModelScorer):
    """Scorer that compares predictions with a ground-truth target on a test set."""

    def __init__(self, modeling_params, out_folder_context, align_with_not_declined, test_y, test_unprocessed=None,
                 test_X=None, test_df_index=None, test_sample_weight=None, assertions=None):
        """
        :param dict modeling_params: modeling choices of the current ML task (see PredictionModelingParams.java in backend)
        :param out_folder_context: output folder context where the predicted data and perf.json are written by save()
        :param align_with_not_declined: callable used to align test_y, test_sample_weight, test_unprocessed,
            the test_X index and the assertion masks with the accepted (not declined) predictions
        :param pandas.Series test_y: 1-dimensional array representing the ground truth target on the test set
        :param pandas.DataFrame | None test_unprocessed: the "UNPROCESSED" value returned from processing the test
            dataset via pipeline.process(). Required for the custom metric x_valid parameter.
        :param dataiku.doctor.multiframe.MultiFrame | None test_X: the "TRAIN" value returned from processing the test
            dataset via pipeline.process(). If None, no data will be written on disk (e.g. training recipes).
        :param test_df_index: pandas index of the input dataframe of the original test set, prior to any processing
        :param Series test_sample_weight: 1-dimensional array representing sample weights on the test set
        :param MLAssertions assertions: collection of assertions based on ML performance metrics
        """
        super(ClassicalPredictionModelScorer, self).__init__(modeling_params, test_unprocessed, test_X, assertions)

        align = align_with_not_declined
        self.out_folder_context = out_folder_context
        self.test_df_index = test_df_index  # kept unaligned: it indexes the original test set

        # Align everything that must match the accepted predictions (call order preserved)
        self.test_y = align(test_y)
        self.test_sample_weight = align(test_sample_weight)
        self.test_unprocessed = align(self.test_unprocessed)
        self.test_X_index = align(self.test_X_index)
        self.assertions = align_assertions_masks_with_not_declined(self.assertions, align)


@add_metaclass(ABCMeta)
class BaseCVModelScorer(object):
    """Aggregates the results of several already-scored per-fold scorers (cross-validation)."""

    # Cannot or should not be aggregated. Not referenced in this class; presumably consulted
    # by subclasses -- NOTE(review): confirm.
    DISCARDED_METRICS_FOR_AGG = {"overridesMetrics"}

    def __init__(self, scorers):
        """
        :param list[PredictionModelScorer] scorers: one scorer per fold; their ``perf_data`` and
            ``test_prediction_result`` are read here, so they must already have been scored
        """
        self.scorers = scorers
        self.perfdatas = [x.perf_data for x in scorers]
        self.test_prediction_result = AbstractPredictionResult.concat([x.test_prediction_result for x in scorers])

        # Also aggregate the pre-override scorers, but only when every fold has one
        self._scorer_without_overrides = None
        scorers_without_overrides = [scorer.scorer_without_overrides for scorer in scorers]
        if all(sc is not None for sc in scorers_without_overrides):
            self._scorer_without_overrides = self.__class__(scorers_without_overrides)

    def score(self):
        """Average the folds' ``globalMetrics`` and return ``{"globalMetrics": ...}`` (also kept in ``self.ret``).

        Scalar metrics are averaged directly; list-valued metrics are averaged position by position,
        using the list length of the first fold. The folds' own perf data is left untouched, so this
        method can safely be called more than once.
        """
        self.ret = {}

        # Compute global metrics (mean of all folds)
        logger.info("Computing global metrics")

        fold_metrics = [perf["globalMetrics"] for perf in self.perfdatas]
        # Aggregate into a copy: fold_metrics[0] is the first fold's own perf data (shared with
        # self.scorers[0].perf_data); writing the means into it would overwrite that fold's metrics.
        global_metrics = deepcopy(fold_metrics[0])

        for key, first_value in fold_metrics[0].items():
            if isinstance(first_value, list):
                global_metrics[key] = [np.mean([metric[key][i] for metric in fold_metrics])
                                       for i in range(len(first_value))]
            else:
                global_metrics[key] = np.mean([metric[key] for metric in fold_metrics])

        self.ret["globalMetrics"] = global_metrics

        return self.ret

    def score_without_overrides(self):
        """Aggregate the pre-override scores, or return None when not every fold has them."""
        if self._scorer_without_overrides is None:
            return None
        return self._scorer_without_overrides.score()


def align_assertions_masks_with_not_declined(assertions, aligning_method):
    """Return a deep copy of ``assertions`` with every ``mask`` passed through ``aligning_method``.

    :param assertions: iterable collection of objects exposing a ``mask`` attribute, or None
    :param aligning_method: callable mapping a mask to its aligned version
    :return: the aligned copy, or None when ``assertions`` is None; the input is not mutated
    """
    if assertions is not None:
        # Copy first so the caller's original masks are not overwritten
        realigned = deepcopy(assertions)
        for item in realigned:
            item.mask = aligning_method(item.mask)
        return realigned
    return None


def trim_curve(curve, distance_threshold=0.05):
    """Thin out a curve by dropping points that are too close to the previously kept one.

    Given an iterable of ``(x, y)`` points, yield the first point, then every point at which the
    path length accumulated since the last kept point reaches ``distance_threshold``, and finally
    the last point if it was not already kept. The final kept segment may therefore be shorter
    than ``distance_threshold``.

    :param curve: iterable of ``(x, y)`` pairs (materialized into a list)
    :param float distance_threshold: accumulated path length required before keeping a point
    :return: generator over the kept points, in order; yields nothing for an empty curve
    """
    points = list(curve)
    if not points:
        # Nothing to trim; avoids an IndexError on points[0]
        return
    yield points[0]
    travelled = 0
    for prev_point, next_point in zip(points, points[1:]):
        dx = next_point[0] - prev_point[0]
        dy = next_point[1] - prev_point[1]
        travelled += math.sqrt(dx ** 2 + dy ** 2)
        if travelled >= distance_threshold:
            yield next_point
            travelled = 0
    if travelled > 0:
        # The tail never reached the threshold: still keep the curve's end point
        yield points[-1]


def compute_variables_importance(features, clf):
    """Extract normalized feature importances from a fitted model, when it exposes them.

    Reads ``clf.feature_importances_``, rescales it so that it sums to 1, and pairs each
    importance with its feature name, dropping importances that are 0 or NaN.

    :param features: feature names, in the same order as ``clf.feature_importances_``
    :param clf: fitted model; its ``feature_importances_`` must be a list or numpy array
    :return: ``{"variables": [...], "importances": [...]}``, or ``{}`` when there is no usable
        importance (attribute missing or unreadable, wrong type, or summing to 0 / NaN)
    :rtype: dict
    """
    if not hasattr(clf, "feature_importances_"):
        logger.info("No feature importance in classifier")
        return {}

    try:
        feature_importances = clf.feature_importances_
    except AttributeError:
        # XGBoost + DART has a feature_importances_ attribute, but trying to access it fails
        logger.info("Not computing feature importances because attribute is present, "
                    "but failed to retrieve it, maybe XGBoost+DART")
        return {}

    # Ensure that 'feature_importances' has the appropriate format
    if not isinstance(feature_importances, (list, np.ndarray)):
        logger.info("Not computing feature importances because `feature_importances_`"
                    " has wrong format: '{}'".format(type(feature_importances)))
        return {}

    if isinstance(feature_importances, list):
        feature_importances = np.array(feature_importances)

    importances_sum = np.sum(feature_importances)

    if np.isnan(importances_sum) or importances_sum == 0.0:
        logger.info("Not computing feature importances because `feature_importances_`"
                    " sums to 0 or NaN")
        return {}

    # Rescaling importances to make them homogeneous to percentage
    # Already done in scikit learn models, but for custom/plugin
    # models, the user is free to do whatever he wants.
    feature_importances = feature_importances / float(importances_sum)

    coefs = {"variables": [], "importances": []}
    for variable, importance in zip(features, feature_importances):
        if importance != 0.0 and not np.isnan(importance):
            coefs["variables"].append(variable)
            coefs["importances"].append(importance)
    return coefs


def build_partial_dependence_plot(model, train_X, train_y, rescalers):
    """Compute one partial dependence plot per processed feature of ``train_X``.

    Features are handled independently: if the computation fails for one feature, a warning is
    logged and that feature gets empty ``featureBins`` / ``data`` instead of aborting the others.

    :param model: fitted model passed to ``dku_recursive_partial_dependence``
    :param train_X: processed training data (must expose ``columns()`` and ``as_np_array()``)
    :param train_y: training target; its mean is used to recentre tree / forest regressors
    :param rescalers: rescalers used to map grid values back to the original feature scale
    :return: list of ``{"feature", "featureBins", "data"}`` dicts, one per column, in column order
    """
    denorm = Denormalizer(rescalers)
    feature_names = train_X.columns()
    X = train_X.as_np_array()
    offset = np.mean(train_y)
    # partial_dependence on a RandomForestRegressor or a DecisionTreeRegressor must have the mean
    # of `y` subtracted so the result is centred on y=0. GBT trees are assumed to be centred already.
    needs_centering = isinstance(model, (RandomForestRegressor, DecisionTreeRegressor))

    def pdp_for(index, feature):
        try:
            pdp, axes = dku_recursive_partial_dependence(model, index, X, grid_resolution=100)
            bins = [denorm.denormalize_feature_value(feature, value) for value in axes[0]]
            values = [point - offset for point in pdp] if needs_centering else pdp
        except Exception as e:
            logger.warning("Failed to compute partial dependence plot by recursive method: {}".format(e))
            bins, values = [], []
        return {"feature": feature, "featureBins": bins, "data": list(values)}

    return [pdp_for(index, feature) for index, feature in enumerate(feature_names)]
