import logging
from abc import abstractmethod
from abc import ABCMeta

import numpy as np
import pandas as pd
import six

from dataiku.doctor.prediction.explanations.engine import ExplainingEngine
from dataiku.doctor.prediction.explanations.engine import SimpleExplanationResult
from dataiku.doctor.prediction.explanations.engine import GlobalExplanationResult
from dataiku.doctor.prediction.explanations.score_to_explain import OneDimensionScoreToExplain

# Fixed seed of the RandomState used in `get_permutations`, so that generated permutations are reproducible
RANDOM_SEED = 1337

# Module-level logger, named after this module
logger = logging.getLogger(__name__)


class ShapleyExplainingEngine(ExplainingEngine):
    """
    Explaining engine relying on Shapley values.

    It chains two collaborators: a scores builder that builds and scores the two "frankensteins", and an
    explanations extractor that turns the two resulting scores into explanations.
    See `build_shapley_frankenstein_masks` for more details on the internals of the underlying algorithm.
    """

    def __init__(self, frankenstein_score_builder, shapley_explanations_extractor):
        """
        Store both collaborators and configure the extractor from the builder (background size and
        columns to explain).

        :type frankenstein_score_builder: ShapleyFrankensteinScoresBuilder
        :type shapley_explanations_extractor: ShapleyExplanationsExtractor
        """
        self._frankenstein_score_builder = frankenstein_score_builder
        self._shapley_explanations_extractor = shapley_explanations_extractor

        # The extractor needs to know how the scores it will receive are laid out
        background_size = frankenstein_score_builder.background_rows_df.shape[0]
        shapley_explanations_extractor.set_background_size(background_size)
        shapley_explanations_extractor.set_columns_to_explain(frankenstein_score_builder.columns_to_explain)

        # Two frankensteins are built, but sequentially, so only one of them counts towards the peak
        self._estimated_peak_num_cells = frankenstein_score_builder.get_generated_number_cells_per_frankenstein()

    def get_estimated_peak_number_cells_generated_per_row_explained(self):
        """
        :return: the number of cells of one frankenstein, as reported by the scores builder at construction time
        """
        return self._estimated_peak_num_cells

    def explain(self, df):
        """
        Score the two frankensteins built from `df`, then extract the explanations from those two scores.

        :param df: rows to explain
        :return: whatever the explanations extractor returns for those scores
        """
        scores = self._frankenstein_score_builder.build_scores(df)
        baseline_score, swapped_columns_score = scores
        return self._shapley_explanations_extractor.extract_explanations(df, baseline_score, swapped_columns_score)


@six.add_metaclass(ABCMeta)
class ShapleyFrankensteinScoresBuilder(object):
    """
    Base class for the objects responsible for building and scoring Shapley frankensteins.

    The constructor resolves the positions of the columns to explain in the background data and draws the
    permutations; subclasses implement the actual building/scoring.
    """

    def __init__(self, background_rows_df, columns_to_explain):
        """
        Class responsible for building and scoring shapley Frankensteins

        :type background_rows_df: pd.DataFrame
        :type columns_to_explain: list[str]
        :raises IndexError: if one of the columns to explain is not a column of `background_rows_df`
        """
        self.background_rows_df = background_rows_df
        self._background_size = background_rows_df.shape[0]
        self._total_num_cols = background_rows_df.shape[1]
        self.columns_to_explain = columns_to_explain

        column_names = background_rows_df.columns
        self._columns_to_explain_indices = []
        for col_name in columns_to_explain:
            # `np.isin` is the supported replacement of `np.in1d`, which is deprecated since NumPy 2.0
            matching_positions = np.isin(column_names, col_name).nonzero()[0]
            if matching_positions.size == 0:
                # Same exception type as the former out-of-bounds lookup, but with an actionable message
                raise IndexError("Column to explain %r not found in background rows" % (col_name,))
            # Keep the first match, should the column name appear several times
            self._columns_to_explain_indices.append(matching_positions[0])

        # One permutation of all the columns per background row
        self.permutations = get_permutations(self._background_size, self._total_num_cols)

    @abstractmethod
    def build_scores(self, df):
        """
        :type df: pd.DataFrame
        :rtype: (ScoreToExplain, ScoreToExplain)
        """

    @abstractmethod
    def get_generated_number_cells_per_frankenstein(self):
        """
        :rtype: int
        """


@six.add_metaclass(ABCMeta)
class ShapleyExplanationsExtractor(object):
    """
    Base class for the objects turning the two frankenstein scores into explanations.

    The columns to explain and the background size are not known at construction time: they start as None
    and are provided afterwards through the setters.
    """

    def __init__(self):
        self._background_size = None
        self._columns_to_explain = None

    def set_columns_to_explain(self, columns_to_explain):
        """
        :param columns_to_explain: columns for which explanations are extracted
        """
        self._columns_to_explain = columns_to_explain

    def set_background_size(self, background_size):
        """
        :param background_size: number of background rows the scores were built with
        """
        self._background_size = background_size

    @abstractmethod
    def extract_explanations(self, observations_df, baseline_score, swapped_columns_score):
        """
        :type observations_df: pd.DataFrame
        :type baseline_score: ScoreToExplain
        :type swapped_columns_score: ScoreToExplain
        :rtype: ExplainingResult
        """


class DefaultShapleyFrankensteinScoresBuilder(ShapleyFrankensteinScoresBuilder):
    """
    Scores builder that builds the frankensteins directly from the raw values of the dataframes and hands them,
    as dataframes, to the score computer.
    """

    def __init__(self, background_rows_df, columns_to_explain, score_computer):
        """
        :type background_rows_df: pd.DataFrame
        :type columns_to_explain: list[str]
        :type score_computer: dataiku.doctor.prediction.explanations.engine.ScoreComputer
        """
        super(DefaultShapleyFrankensteinScoresBuilder, self).__init__(background_rows_df, columns_to_explain)
        self._score_computer = score_computer

        array_builders = build_shapley_frankenstein_masks(
            self._total_num_cols, self._columns_to_explain_indices, self._background_size, self.permutations)
        self._frankenstein_builder, self._frankenstein_to_compare_builder = array_builders

    def get_generated_number_cells_per_frankenstein(self):
        """
        Log and return the size of the first frankenstein builder.

        :rtype: int
        """
        builder = self._frankenstein_builder
        logger.info("%s cells will be generated by %s for each row" % (builder.get_size(), builder))
        return builder.get_size()

    def _get_frankenstein_score(self, df, df_score, frankenstein_builder):
        """
        Build one frankenstein from `df` and the background rows, then score it.

        :param df: rows to explain
        :param df_score: score of `df` itself, forwarded to the score computer
        :param frankenstein_builder: array builder holding the mask to apply
        :return: the result of the score computer on the frankenstein
        """
        observation_values = df.values
        background_values = self.background_rows_df.values
        matching_indices_in_df, frankenstein_arr = frankenstein_builder.build_frankenstein_array(
            observation_values, background_values)

        # Back to a dataframe with the same columns and dtypes as the rows to explain
        untyped_frankenstein_df = pd.DataFrame(data=frankenstein_arr, columns=df.columns)
        frankenstein_df = untyped_frankenstein_df.astype(df.dtypes)
        logger.info("Built frankenstein of shape {}".format(frankenstein_df.shape))
        return self._score_computer(frankenstein_df, df_score, matching_indices_in_df)

    def build_scores(self, df):
        """
        Score `df`, then score both frankensteins built from it.

        :type df: pd.DataFrame
        :return: (score of the frankenstein, score of the frankenstein to compare)
        """
        df_score = self._score_computer(df)
        baseline_score = self._get_frankenstein_score(df, df_score, self._frankenstein_builder)
        swapped_score = self._get_frankenstein_score(df, df_score, self._frankenstein_to_compare_builder)
        return baseline_score, swapped_score


class Preprocessor(object):
    """
    Interface of the callables that preprocess a dataframe into an array.

    This base implementation has no body: calling it returns None.
    """

    def __call__(self, df):
        """
        :type df: pd.DataFrame
        :rtype: np.ndarray
        """


class PreprocessShapleyFrankensteinScoresBuilder(ShapleyFrankensteinScoresBuilder):
    """
    Scores builder working on preprocessed data: the background is preprocessed once at construction time,
    the frankenstein masks are projected onto the preprocessed features, and the frankensteins are then built
    from preprocessed arrays.
    """

    def __init__(self, background_rows_df, columns_to_explain,
                 preprocess, score_computer_from_preprocess, features_to_cols_mapping):
        """
        :type background_rows_df: pd.DataFrame
        :type columns_to_explain: list[str]
        :type preprocess: Preprocessor
        :type score_computer_from_preprocess: ScoreComputer
        :type features_to_cols_mapping: np.ndarray
        """
        super(PreprocessShapleyFrankensteinScoresBuilder, self).__init__(background_rows_df, columns_to_explain)
        self._score_computer_from_preprocess = score_computer_from_preprocess
        self._preprocess = preprocess
        self._features_to_cols_mapping = features_to_cols_mapping
        # The background never changes: preprocess it only once
        self._preprocessed_background = self._preprocess(background_rows_df)

        columns_builders = build_shapley_frankenstein_masks(
            self._total_num_cols, self._columns_to_explain_indices, self._background_size, self.permutations)
        columns_builder, columns_to_compare_builder = columns_builders

        self._frankenstein_features_builder = self._project_frankenstein_to_features_space(columns_builder)
        self._frankenstein_to_compare_features_builder = self._project_frankenstein_to_features_space(
            columns_to_compare_builder)

    def _project_frankenstein_to_features_space(self, frankenstein_builder):
        """
        Projecting a frankenstein array builder to the feature space, i.e. going from a
        (num_cols_to_explain, background_size, total_num_columns) mask to a
        (num_cols_to_explain, background_size, preprocessed_num_columns) mask

        Each feature takes the mask value of the column that `features_to_cols_mapping` gives for it.

        :type frankenstein_builder: ShapleyFrankensteinArrayBuilder
        :rtype: ShapleyFrankensteinArrayBuilder
        """
        # Same mapping for every column to explain and every background row, hence the two broadcast axes
        broadcast_mapping = self._features_to_cols_mapping[np.newaxis, np.newaxis, :]
        features_mask = np.take_along_axis(frankenstein_builder.mask, broadcast_mapping, axis=2)
        return ShapleyFrankensteinArrayBuilder(features_mask)

    def get_generated_number_cells_per_frankenstein(self):
        """
        Log and return the size of the first (projected) frankenstein builder.

        :rtype: int
        """
        builder = self._frankenstein_features_builder
        logger.info("%s cells will be generated by %s for each row" % (builder.get_size(), builder))
        return builder.get_size()

    def _get_frankenstein_score(self, df_preprocessed, df_score, frankenstein_builder):
        """
        Build one frankenstein from the preprocessed rows and the preprocessed background, then score it.

        :param df_preprocessed: preprocessed rows to explain
        :param df_score: score of the rows to explain, forwarded to the score computer
        :param frankenstein_builder: array builder holding the mask to apply
        :return: the result of the score computer on the frankenstein
        """
        built = frankenstein_builder.build_frankenstein_array(df_preprocessed, self._preprocessed_background)
        matching_indices_in_arr, frankenstein_arr = built
        logger.info("Built frankenstein of shape {}".format(frankenstein_arr.shape))
        return self._score_computer_from_preprocess(frankenstein_arr, df_score, matching_indices_in_arr)

    def build_scores(self, df):
        """
        Preprocess and score `df`, then score both frankensteins built from the preprocessed rows.

        :type df: pd.DataFrame
        :return: (score of the frankenstein, score of the frankenstein to compare)
        """
        df_preprocessed = self._preprocess(df)
        df_score = self._score_computer_from_preprocess(df_preprocessed)
        baseline_score = self._get_frankenstein_score(
            df_preprocessed, df_score, self._frankenstein_features_builder)
        swapped_score = self._get_frankenstein_score(
            df_preprocessed, df_score, self._frankenstein_to_compare_features_builder)
        return baseline_score, swapped_score


class ShapleyIndividualExplanationsExtractor(ShapleyExplanationsExtractor):
    """
    Extractor producing one Shapley value per explained row and per column to explain.
    """

    def extract_explanations(self, observations_df, baseline_score, swapped_columns_score):
        """
        Average, over the background rows, the score differences between the two frankensteins.

        :type observations_df: pd.DataFrame
        :type baseline_score: ScoreToExplain
        :type swapped_columns_score: ScoreToExplain
        :rtype: SimpleExplanationResult
        """
        # Marginal contribution of each column: score without the column swapped minus score with it swapped
        marginal_contributions = baseline_score.score - swapped_columns_score.score

        # Fortran-order reshape, so that the first axis is the one varying fastest in the flat scores
        # Shape: (background_size, len(columns), number of remaining entries, expected to be len(observations_df))
        per_background_row = marginal_contributions.reshape(
            self._background_size, len(self._columns_to_explain), -1, order="F")

        # Average over the background rows to get the Shapley value of each column for each row
        # Shape after transposition: (len(observations_df), len(columns))
        #
        # With overrides declined outcome the model can output nan predictions, therefore the mean must be
        # computed on the valid predictions only.
        shapley_values = np.nanmean(per_background_row, axis=0).T

        # One line per explained row, one column per column to explain
        shapley_df = pd.DataFrame(data=shapley_values,
                                  index=observations_df.index,
                                  columns=self._columns_to_explain,
                                  dtype=np.float64)
        return SimpleExplanationResult(shapley_df)


class ShapleyGlobalExplanationsExtractor(ShapleyIndividualExplanationsExtractor):
    """
    Extractor wrapping the individual explanations into a GlobalExplanationResult, under the single
    "unique" key.
    """

    def extract_explanations(self, observations_df, baseline_score, swapped_columns_score):
        """
        :type observations_df: pd.DataFrame
        :type baseline_score: ScoreToExplain
        :type swapped_columns_score: ScoreToExplain
        :rtype: GlobalExplanationResult
        """
        global_result = GlobalExplanationResult()
        individual_explanations = super(ShapleyGlobalExplanationsExtractor, self).extract_explanations(
            observations_df, baseline_score, swapped_columns_score)
        global_result.add_explanations("unique", individual_explanations)
        return global_result


class MulticlassShapleyGlobalExplanationsExtractor(ShapleyIndividualExplanationsExtractor):
    """
    Extractor computing the individual explanations once per class, and gathering them in a
    GlobalExplanationResult keyed by class.
    """

    def __init__(self, classes):
        """
        :param classes: the classes, in the same order as the columns of the per-class scores
        """
        super(MulticlassShapleyGlobalExplanationsExtractor, self).__init__()
        self._classes = classes

    def extract_explanations(self, observations_df, baseline_score, swapped_columns_score):
        """
        :type observations_df: pd.DataFrame
        :type baseline_score: MulticlassScoreToExplain
        :type swapped_columns_score: MulticlassScoreToExplain
        :rtype: GlobalExplanationResult
        """
        per_class_result = GlobalExplanationResult()
        for position, class_name in enumerate(self._classes):
            # Column #position of the per-class scores holds the scores of this class
            class_baseline = OneDimensionScoreToExplain(baseline_score.per_class_score[:, position])
            class_swapped = OneDimensionScoreToExplain(swapped_columns_score.per_class_score[:, position])
            explanations_of_class = super(MulticlassShapleyGlobalExplanationsExtractor, self).extract_explanations(
                observations_df, class_baseline, class_swapped)
            per_class_result.add_explanations(class_name, explanations_of_class)
        return per_class_result


def get_permutations(n_permutations, n_columns):
    """
    Draw reproducible random permutations of the column indices.

    The random generator is re-seeded with RANDOM_SEED on every call, so the same arguments always give the
    same permutations.

    :param int n_permutations: Number of permutations to generate
    :param int n_columns: Number of columns
    :return: Array of shape (n_permutations, n_columns), where each row is a random permutation of the
             integers 0 .. n_columns - 1. The shape also holds when n_permutations is 0.
    """
    random_state = np.random.RandomState(RANDOM_SEED)
    # A single draw of the whole matrix consumes the random stream row after row, i.e. exactly like one draw
    # of `n_columns` values per permutation would, and keeps a 2D result even when there is no permutation.
    random_keys = random_state.uniform(0, 1, (n_permutations, n_columns))
    # The order of the random keys of a row is a random permutation of its column indices
    return np.argsort(random_keys, axis=1)


def build_shapley_frankenstein_masks(total_num_columns, col_indices_to_explain, background_size, permutations):
    """
    Builds masks required to build Shapley frankenstein arrays.

    The Shapley algorithm relies on:
     * Randomly swapping values of some columns of the data with values taken from _background_ data,
       in 2 similar but different ways,
     * Then getting the prediction of those newly created rows,
     * Deriving the explanations by comparing the predictions of the 2 ways to swap.
    More details in the original paper: https://arxiv.org/abs/1705.07874

    In order to efficiently compute those new rows and their prediction, we build _frankenstein_ arrays that are
    concatenations of all those rows for all the columns to explain, and then run the prediction only once on the
    frankenstein. To do so, we first build masks that tell, for each cell of the frankenstein, whether it needs
    to be taken from the original data, or swapped with some cell from the background.

    In order to better rely on numpy broadcasting, we do not build 2D masks but 3D masks, such that the values
    in the 2D mask [i, :, :] represent the mask for column to explain number i.

    The pseudo algorithm for building the frankenstein, if it was done iteratively, would be as follows:
    frankenstein = []
    frankenstein_c = []
    for each row in observation_df:            => [200, 'Mr Smith', 'male', 22.0, 7.25]
      for each col in cols_to_explain:         => "Sex"
        for i in range(background_size):       => [190,   'j. Doe', 'male', 38.0,   70]
          permutation = get_permutation()      => [3, 0, 2, 1, 4]
          col_index = get_col_index()          => 2
          # swap all cols found in the permutation before col_index
          c_to_swap = get_c_to_swap()          => [3, 0]
          n_row = do_swap(row, c_to_swap,      => [190, 'Mr Smith', 'male', 38.0, 7.25]
                          back_rows[i])             ^                        ^
          frankenstein.add(n_row)
          # for comparison, swap the same cols plus col_index itself
          c_to_swap_c = get_c_to_swap_c()      => [3, 0, 2]
          n_row_c = do_swap(row, c_to_swap_c,  => [190, 'Mr Smith', 'male', 38.0, 7.25]
                            back_rows[i])           ^                  ^      ^
          frankenstein_c.add(n_row_c)
    (in this example both rows hold 'male', so swapping "Sex" leaves the row visually unchanged)

    :param int total_num_columns: number of columns in the data
    :param list[int] col_indices_to_explain: list of indices of columns in the data to explain
    :param int background_size: size of the background (e.g. number of Monte Carlo steps)
    :param np.ndarray permutations: permutations for randomly swapping values,
                                    must be of shape (background_size, total_num_columns)
    :rtype: (ShapleyFrankensteinArrayBuilder, ShapleyFrankensteinArrayBuilder)
    :returns Two ShapleyFrankensteinArrayBuilder with mask arrays of shape
                (num_cols_to_explain, background_size, total_num_cols)
             where the [i, :, :] array represents the mask to apply for column to explain #i:
              * it's True when the value needs to be swapped, i.e. taken from the background
              * else, it is taken from the original data
    """
    assert permutations.shape == (background_size, total_num_columns)
    masks_shape = (len(col_indices_to_explain), background_size, total_num_columns)
    swap_before_masks = np.zeros(masks_shape, dtype=bool)
    swap_including_masks = np.zeros(masks_shape, dtype=bool)

    # Python-level loop over the columns to explain only; the other dimensions are handled by numpy
    for explained_rank, explained_col in enumerate(col_indices_to_explain):
        # Worked example with total_num_columns = 4, explained_col = 2 and
        #   permutations = [[1, 0, 3, 2],
        #                   [2, 1, 3, 0],
        #                   [1, 2, 3, 0]]
        #
        # Counting, along each permutation, how many times the explained column was met so far gives
        #   [[0, 0, 0, 1],
        #    [1, 1, 1, 1],
        #    [0, 1, 1, 1]]
        # (there is a single match per row, so the count stays at 1 from the match to the end of the row)
        times_explained_col_met = np.cumsum(permutations == explained_col, axis=1)

        # The zeros are the positions, in each permutation, located strictly before the explained column:
        #   background_idx = [0, 0, 0, 2]
        #   position_idx   = [0, 1, 2, 0]
        background_idx, position_idx = np.nonzero(times_explained_col_met == 0)

        # Those are positions in the permutations: reading the permutations at those positions gives the
        # actual columns to swap, which are then flagged in the mask of this column to explain
        #   [[ True,  True, False,  True],
        #    [False, False, False, False],
        #    [False,  True, False, False]]
        swap_before_mask = swap_before_masks[explained_rank]  # view: writes land in the 3D mask
        swap_before_mask[background_idx, permutations[background_idx, position_idx]] = True

        # The mask to compare with is the same one, except that the explained column is swapped as well
        swap_including_mask = swap_including_masks[explained_rank]  # view as well
        swap_including_mask[:, :] = swap_before_mask
        swap_including_mask[:, explained_col] = True

    return (ShapleyFrankensteinArrayBuilder(swap_before_masks),
            ShapleyFrankensteinArrayBuilder(swap_including_masks))


class ShapleyFrankensteinArrayBuilder(object):
    """
    Builds frankenstein arrays out of a 3D boolean mask: True cells are taken from the background rows,
    the other ones from the data.
    """

    def __init__(self, mask):
        """
        :param np.ndarray mask: Frankenstein mask of size (num_cols_to_explain, background_size, total_num_columns)
        """
        assert mask.ndim == 3, "Wrong shape for Frankenstein mask"
        self.mask = mask
        num_cols_to_explain, background_size, total_num_columns = mask.shape
        self._num_cols_to_explain = num_cols_to_explain
        self._background_size = background_size
        self._total_num_columns = total_num_columns

    def get_size(self):
        """
        :return: total number of cells of the mask
        """
        return self.mask.size

    def __str__(self):
        template = "FrankensteinBuilder(total_num_columns={}, background_size={}, num_cols_to_explain={})"
        return template.format(self._total_num_columns, self._background_size, self._num_cols_to_explain)

    def build_frankenstein_array(self, data, background_rows):
        """
        Build frankenstein array and return it, along with the matching indices in the data it was built from.

        Each row of `data` yields num_cols_to_explain * background_size consecutive rows in the frankenstein.

        :type data: np.ndarray
        :type background_rows: np.ndarray
        :return two numpy arrays:
                    * the matching indices in the original array
                    * the frankenstein array
        :rtype: (np.ndarray, np.ndarray)
        """
        assert data.shape[1] == self._total_num_columns, "Wrong number of columns"
        assert background_rows.shape[1] == self._total_num_columns, "Wrong number of columns"
        assert background_rows.shape[0] == self._background_size, "Wrong number of rows in background"

        # Index, in `data`, of the row each frankenstein row was derived from
        rows_per_observation = self._background_size * self._num_cols_to_explain
        matching_indices_in_array = np.repeat(np.arange(data.shape[0]), rows_per_observation)

        # Broadcast everything to (num_rows, num_cols_to_explain, background_size, total_num_columns)
        broadcast_mask = self.mask[np.newaxis, :, :, :]
        broadcast_background = background_rows[np.newaxis, np.newaxis, :, :]
        broadcast_data = data[:, np.newaxis, np.newaxis, :]
        stacked_frankenstein = np.where(broadcast_mask, broadcast_background, broadcast_data)

        # Flatten all leading axes, to get one frankenstein row per line
        frankenstein_array = stacked_frankenstein.reshape(-1, self._total_num_columns)
        return matching_indices_in_array, frankenstein_array
