from __future__ import unicode_literals

import re
from abc import ABCMeta
from enum import Enum

import numpy as np
import pandas as pd
from six import add_metaclass

from dataiku.core import dkujson
from dataiku.doctor.prediction.overrides.overrides_metrics import OverridesMetrics
from dataiku.doctor.prediction.overrides.utils import dict_with_tuple_keys_to_nested_dict
from dataikuscoring.utils.prediction_result import AbstractPredictionResult
from dataikuscoring.utils.prediction_result import ClassificationPredictionResult
from dataikuscoring.utils.prediction_result import PredictionResult

# Flag value found in the overrides flags for rows that did not match any override
DEFAULT_OVERRIDE_VALUE = "__DKU__DEFAULT__OVERRIDE__VALUE__"
# Keys of the per-row override info JSON built by `OverridesResultsMixin.compute_and_return_info_column`
APPLIED_RULE = "appliedRule"  # flag of the override matched by the row
RAW_RESULT = "rawResult"  # top-level key under which the raw prediction result columns are nested
PREDICTION_CHANGED = "predictionChanged"  # whether the overridden prediction differs from the raw one
RULE_MATCHED = "ruleMatched"  # True for rows that matched an override, False otherwise
RULE_POLICY = "rulePolicy"  # name of the `RulePolicy` member reported for the row


class RulePolicy(Enum):
    """
    Policy reported for a row that matched an override. The member names are what gets written under the
    `rulePolicy` key of the per-row override info.
    """
    DECLINED = 0  # reported for matched rows that are flagged in the declined mask
    ENFORCED = 1  # reported for matched rows that are not flagged in the declined mask


class OverridesResultsMixin(object):
    """
    Holds what is needed to report on overrides next to a prediction result: the per-row override flags, the raw
    prediction result, the overridden predictions, the overrides names and the mask of declined rows.
    """

    def __init__(self, overrides_flags, raw_prediction_result, overridden_preds, overrides_names, declined_mask):
        """
        :type overrides_flags: pd.Series
        :type raw_prediction_result: AbstractPredictionResult
        :type overridden_preds: np.ndarray
        :type overrides_names: list[str]
        :type declined_mask: np.ndarray or None
        :param declined_mask: when None, a mask shaped like `overridden_preds` and filled with False is used instead
        """
        self.overrides_flags = overrides_flags
        self._raw_prediction_result = raw_prediction_result
        self._overridden_preds = overridden_preds  # Used for computing metrics
        self.overrides_names = overrides_names
        self._overrides_metrics = None  # Filled on first call to `compute_and_return_overrides_metrics`
        if declined_mask is None:
            declined_mask = np.full(overridden_preds.shape, False)
        self._declined_mask = declined_mask

    @staticmethod
    def series_to_json(series):
        """
        Serializes a series to json, after turning its tuple keys into nested dicts.

        :type series: pd.Series
        :rtype: str
        """
        nested_dict = dict_with_tuple_keys_to_nested_dict(series.to_dict())
        return dkujson.dumps(nested_dict)

    def compute_and_return_info_column(self):
        """
        :rtype: pd.Series
        :return: a series indexed like `overrides_flags` holding one json string per row. Rows that matched an
        override get the raw result together with the applied rule, whether the prediction changed and the rule
        policy; all other rows only get `{ruleMatched: false}`.
        """
        flags_series = self.overrides_flags

        # The raw result columns get an extra top level so that they end up nested under `rawResult` in the json.
        # Once joined with the flat override columns below, those two-level names are represented as tuples.
        # For classification the joined frame looks like:
        #   (rawResult,    (rawResult,       (rawResult,
        #                   (probabilities,   (probabilities,
        #    prediction)         class0))          class1))   appliedRule   predictionChanged   ruleMatched   rulePolicy
        # 1       class0         0.623075          0.376925    Override 0               False          True     ENFORCED
        # 5       class0         0.624901          0.375099    Override 1                True          True     ENFORCED
        # 8       class1         0.446077          0.553923    Override 0               False          True     ENFORCED
        raw_result_df = self._raw_prediction_result.as_dataframe(for_json_serialization=True)
        raw_result_df.index = flags_series.index
        prepend_column_level_inplace(raw_result_df, RAW_RESULT)

        flag_values = flags_series.values
        is_matched = flag_values != DEFAULT_OVERRIDE_VALUE
        applied_rules = flag_values[is_matched]
        matched_index = flags_series.index[is_matched]

        prediction_changed = self._raw_prediction_result.preds[is_matched] != self._overridden_preds[is_matched]
        # More details in sc-126633 as to why we use this syntax rather than just True
        rule_matched = np.full(matched_index.shape, True)
        policy_names = np.where(self._declined_mask[is_matched], RulePolicy.DECLINED.name, RulePolicy.ENFORCED.name)

        rule_info_df = pd.DataFrame({
            APPLIED_RULE: applied_rules,
            PREDICTION_CHANGED: prediction_changed,
            RULE_MATCHED: rule_matched,
            RULE_POLICY: policy_names
        }, index=matched_index)

        # Inner join: only the rows that matched an override are kept
        matched_info_df = pd.concat([raw_result_df, rule_info_df], axis=1, join="inner")
        matched_info_json = matched_info_df.apply(self.series_to_json, axis=1)

        # Rows that matched no override are added back with a minimal payload
        not_matched_json = dkujson.dumps({RULE_MATCHED: False})
        return matched_info_json.reindex(flags_series.index, fill_value=not_matched_json)

    def compute_and_return_overrides_metrics(self):
        """
        Computes the overrides metrics on first call and returns the stored value on later calls.

        :rtype: OverridesMetrics
        """
        if self._overrides_metrics is not None:
            return self._overrides_metrics

        flag_values = self.overrides_flags.values
        is_matched = flag_values != DEFAULT_OVERRIDE_VALUE
        applied_rules = flag_values[is_matched]
        changed_series = pd.Series(self._raw_prediction_result.preds[is_matched] != self._overridden_preds[is_matched])

        # Each value tells whether the row was changed, so when grouping by applied override:
        #   * "count" is the number of rows matched by the override
        #   * "sum" is the number of rows changed by the override
        per_override_df = changed_series.groupby(applied_rules).agg(["count", "sum"])
        metrics_df = per_override_df.rename(columns={"count": "nb_matching", "sum": "nb_changed"})
        self._overrides_metrics = OverridesMetrics.from_metrics_df(self.overrides_names, metrics_df)
        return self._overrides_metrics


def prepend_column_level_inplace(df, column_level_name):
    """
    Replaces, in place, the columns of `df` by a two-level index whose first level is always `column_level_name`
    and whose second level is the former columns.

    :type df: pd.DataFrame
    :type column_level_name: str
    :return: Nothing
    """
    top_level = [column_level_name]
    former_columns = df.columns
    df.columns = pd.MultiIndex.from_product([top_level, former_columns])


@add_metaclass(ABCMeta)
class AbstractOverridesResults(OverridesResultsMixin, AbstractPredictionResult):
    """
    Abstract prediction result carrying overrides information, with helpers to deal with declined rows.
    """

    def align_with_not_declined(self, array):
        """
        This method will filter only the accepted rows, the rows not dropped by any override using the decline policy

        :type array: np.ndarray
        :rtype: np.ndarray
        :return: `array` without the rows flagged in the declined mask; `array` itself (possibly None) when
        either the array or the declined mask is None
        """
        if array is None or self._declined_mask is None:
            return array
        return array[~self._declined_mask]

    def assert_not_all_declined(self):
        """
        :raises AssertionError: when every row is flagged in the declined mask
        """
        # Raised explicitly rather than through an `assert` statement, so that the check is not stripped when
        # Python runs with optimizations enabled (-O)
        if self._declined_mask.all():
            raise AssertionError("All predictions have been declined, please check your overrides definition")


class OverriddenClassificationPredictionResult(ClassificationPredictionResult, AbstractOverridesResults):
    """
    Classification prediction result built from overridden predictions, which also keeps the raw prediction
    result and the overrides information.
    """

    def __init__(self, raw_prediction_result, overrides_flags, overrides_names, preds, probas=None, declined_mask=None):
        """
        :type raw_prediction_result: ClassificationPredictionResult
        :type overrides_flags: pd.Series
        :type overrides_names: list[str]
        :type preds: np.ndarray
        :type probas: np.ndarray or None
        :type declined_mask: np.ndarray or None
        """
        # The target map is taken from the raw prediction result
        target_map = raw_prediction_result.target_map
        ClassificationPredictionResult.__init__(self, target_map, probas=probas, preds=preds)
        OverridesResultsMixin.__init__(self, overrides_flags, raw_prediction_result, preds, overrides_names,
                                       declined_mask)
        self.raw_prediction_result = raw_prediction_result

    @property
    def probas_not_declined(self):
        """
        :return: the probas restricted to the rows that were not declined
        """
        return self.align_with_not_declined(self.probas)

    @staticmethod
    def _concat(prediction_results):
        """
        Concatenates several overridden results into one; the overrides names are taken from the first result.

        :type prediction_results: list[OverriddenClassificationPredictionResult]
        :rtype: OverriddenClassificationPredictionResult
        """
        if len(prediction_results) == 0:
            raise ValueError("Cannot concat 0 results")

        all_flags = pd.concat([result.overrides_flags for result in prediction_results])
        raw_results = [result.raw_prediction_result for result in prediction_results]
        raw_concat = AbstractPredictionResult.concat(raw_results)
        names = prediction_results[0].overrides_names
        overridden_concat = ClassificationPredictionResult._concat(prediction_results)
        all_declined = np.concatenate([result._declined_mask for result in prediction_results])

        assert isinstance(raw_concat, ClassificationPredictionResult), "Wrong type for concatenated result"
        return OverriddenClassificationPredictionResult(raw_prediction_result=raw_concat,
                                                        overrides_flags=all_flags,
                                                        overrides_names=names,
                                                        preds=overridden_concat.preds,
                                                        probas=overridden_concat.probas,
                                                        declined_mask=all_declined)


class OverriddenPredictionResults(PredictionResult, AbstractOverridesResults):
    """
    Prediction result built from overridden predictions, which also keeps the raw prediction result and the
    overrides information.
    """

    def __init__(self, raw_prediction_result, overrides_flags, overrides_names, preds, declined_mask=None,
                 prediction_intervals=None):
        """
        :type raw_prediction_result: PredictionResult
        :type overrides_flags: pd.Series
        :type overrides_names: list[str]
        :type preds: np.ndarray
        :type declined_mask: np.ndarray or None
        :type prediction_intervals: np.ndarray or None
        """
        PredictionResult.__init__(self, preds, prediction_intervals)
        OverridesResultsMixin.__init__(self, overrides_flags, raw_prediction_result, preds, overrides_names,
                                       declined_mask)
        self.raw_prediction_result = raw_prediction_result

    @staticmethod
    def _concat(prediction_results):
        """
        Concatenates several overridden results into one; the overrides names are taken from the first result,
        and prediction intervals are only concatenated when the first result has some.

        :type prediction_results: list[OverriddenPredictionResults]
        :rtype: OverriddenPredictionResults
        """
        if len(prediction_results) == 0:
            raise ValueError("Cannot concat 0 results")

        all_flags = pd.concat([result.overrides_flags for result in prediction_results])
        raw_results = [result.raw_prediction_result for result in prediction_results]
        raw_concat = AbstractPredictionResult.concat(raw_results)
        names = prediction_results[0].overrides_names
        all_preds = np.concatenate([result.preds for result in prediction_results])
        all_declined = np.concatenate([result._declined_mask for result in prediction_results])
        all_intervals = (np.concatenate([result._prediction_intervals for result in prediction_results])
                         if prediction_results[0].has_prediction_intervals() else None)

        assert isinstance(raw_concat, PredictionResult), "Wrong type for concatenated result"
        return OverriddenPredictionResults(raw_prediction_result=raw_concat,
                                           overrides_flags=all_flags,
                                           overrides_names=names,
                                           preds=all_preds,
                                           declined_mask=all_declined,
                                           prediction_intervals=all_intervals)
