"""
Reducers try to make counterfactuals simpler
"""

from abc import ABCMeta
from abc import abstractmethod

import numpy as np
import pandas as pd
from six import add_metaclass


@add_metaclass(ABCMeta)
class BaseReducer(object):
    """
    Common base of the reducers: it stores the feature domains and knows how to
    check samples against them. Subclasses have to provide `reduce`.
    """
    def __init__(self, feature_domains):
        """
        :param feature_domains: iterable of objects exposing `feature_name` and
                                `check_validity(values)`, or None when there is no constraint
        """
        self.feature_domains = feature_domains

    @abstractmethod
    def reduce(self, x_ref, y_ref, counterfactuals):
        """
        To be implemented by subclasses; the base implementation does nothing.
        """

    def check_is_valid(self, samples):
        """
        Check, column by column, if samples satisfy the feature_domains constraints

        :param samples: pd.DataFrame
        :return: a dataframe with the same index and columns as samples, holding the result
                 of `check_validity` for each constrained column; columns without a feature
                 domain (all of them when feature_domains is None) are filled with True
        """
        validity = pd.DataFrame(True, index=samples.index, columns=samples.columns)
        # No constraints at all is handled as an empty collection of domains
        domains = () if self.feature_domains is None else self.feature_domains
        for domain in domains:
            name = domain.feature_name
            validity[name] = domain.check_validity(samples[name].values)
        return validity


class FeatureImportanceReducer(BaseReducer):
    """
    Reducer that tries to reset as many features as possible to corresponding values
    of the reference, without considering their distance to the reference

    Throughout the class, a "frankenstein" (`_fk` in the names) is a big stacked frame holding
    several variants of every counterfactual, so that the model is called once on all of them.
    """
    def __init__(self, target, model, feature_domains):
        """
        :param target: class the reduced counterfactuals must be predicted as. It is used both as a
                       column index in the output of `model.predict_proba` and as a value compared
                       to the output of `model.predict`. None means "any class other than y_ref".
        :param model: object exposing `predict` and `predict_proba`
        :param feature_domains: constraints stored by BaseReducer (may be None)
        """
        super(FeatureImportanceReducer, self).__init__(feature_domains)
        self.target = target
        self.model = model

    @staticmethod
    def _build_fk_from_fk_mask(ref, counterfactuals, fk_mask):
        """
        Build a frankenstein out of a replacement mask.

        :param ref: pd.DataFrame, only its first row is used
        :param counterfactuals: pd.DataFrame of shape (n_rows, n_cols)
        :param fk_mask: array of shape (n_variants, n_rows, n_cols); where it is truthy the value
                        of the reference is taken, elsewhere the value of the counterfactual is kept
        :return: pd.DataFrame of shape (n_variants * n_rows, n_cols) with the columns and dtypes of
                 `counterfactuals`, such that row `v*n_rows + i` is the v-th variant of row i
        """
        reset_proposals_fk = np.where(fk_mask,
                                      ref.values[0][np.newaxis, np.newaxis, :],
                                      counterfactuals.values[np.newaxis, :, :])
        # Stack the variants on top of each other: (n_variants, n_rows, n_cols) -> (n_variants * n_rows, n_cols)
        reset_proposals_fk = reset_proposals_fk.reshape(-1, ref.shape[1])
        return pd.DataFrame(reset_proposals_fk, columns=counterfactuals.columns).astype(counterfactuals.dtypes)

    @staticmethod
    def _build_importance_fk(ref, counterfactuals, resettable_cols):
        """
        For each row i and for each resettable column j, we build a new sample by just
        replacing column j with the value of the reference for column j.
        We build a (n_rows * n_cols_to_reset, n_cols) frankenstein `importance_fk`,
        such that row `j*len(counterfactuals) + i` represents the new sample for row i
        and the j-th resettable column.

        :param ref: pd.DataFrame, only its first row is used
        :param counterfactuals: pd.DataFrame of shape (n_rows, n_cols)
        :param resettable_cols: 1D integer array with the positions of the resettable columns
        """
        n_rows, n_cols = counterfactuals.shape
        n_cols_to_reset = len(resettable_cols)
        importance_fk_mask = np.zeros((n_cols_to_reset, n_rows, n_cols), dtype=bool)
        # Variant k resets only the k-th resettable column, for every row
        importance_fk_mask[np.arange(n_cols_to_reset), :, resettable_cols] = True
        return FeatureImportanceReducer._build_fk_from_fk_mask(ref, counterfactuals, importance_fk_mask)

    def _get_feature_importance(self, ref, y_ref, counterfactuals, resettable_cols):
        """
        We only need to compute the feature importance of the resettable columns. To compute this
        (n_rows, n_resettable_cols) feature importance array, we need (n_row * n_resettable_cols)
        predictions, i.e. build a (n_rows * n_resettable_cols, n_cols) frankenstein `importance_fk`, such
        that row `j*len(counterfactuals) + i` represents the new sample for row i and j-th resettable column.

        :return: array of shape (n_rows, n_resettable_cols)
        """
        importance_fk = self._build_importance_fk(ref, counterfactuals, resettable_cols)
        frankensteins_probas = self.model.predict_proba(importance_fk)
        n_rows = counterfactuals.shape[0]
        return self._get_feature_importance_from_fk_probas(frankensteins_probas, y_ref, n_rows, resettable_cols)

    def _get_feature_importance_from_fk_probas(self, frankensteins_probas, y_ref, n_rows, resettable_cols):
        """
        We get `feature_importance` such that feat_imp[i, j] contains the feature
        importance for row i and the j-th resettable column.
        For a given frankenstein sample that reset a specific feature, if the proba for
        the target class is low, then it means that resetting this feature had a big
        impact on the prediction.
        Therefore, we define the feature importance as `-1 * probas_of_target_class`.

        When no target is set, every class other than `y_ref` is acceptable, so the proba
        taken into account is the highest one among those classes.

        :param frankensteins_probas: array of shape (n_resettable_cols * n_rows, n_classes)
        :return: array of shape (n_rows, n_resettable_cols)
        """
        if self.target is not None:
            targets = [self.target]
        else:
            targets = [i for i in range(frankensteins_probas.shape[1]) if i != y_ref]
        probas_for_target = frankensteins_probas[:, targets].max(axis=1).astype(float)
        # The frankenstein is stacked column-major (one block of n_rows per resettable column),
        # hence the reshape followed by a transposition to get one line per row.
        return -probas_for_target.reshape((len(resettable_cols), n_rows)).T

    @staticmethod
    def _get_ordered_cols_to_reset(feature_importance, resettable_cols):
        """
        :param feature_importance: array of shape (n_rows, n_resettable_cols)
        :param resettable_cols: 1D integer array with the positions of the resettable columns
        :return: array of shape (n_rows, n_resettable_cols) holding, for each row, the positions
                 of the resettable columns sorted by increasing importance
        """
        # Least important features should be reset first.
        return resettable_cols[np.argsort(feature_importance, axis=1)]

    @staticmethod
    def _build_reset_proposals_fk(ref, counterfactuals, ordered_cols_to_reset):
        """
        With ordered_cols_to_reset ordered from the least important to the most important col,
        we build our combinations of replacements of the counterfactuals by the reference values.
        For each row, there are n_cols_to_reset + 1 proposals:
         - no replacement
         - replacing the least important column
         - replacing the 2 least important columns
         - ...
         - replacing all the cols to reset
        We build a (n_rows * (n_cols_to_reset + 1), n_cols) frankenstein `reset_proposals_fk`,
        such that row `j*len(counterfactuals) + i` represents the j-th proposal for row i
        """
        n_rows, n_cols = counterfactuals.shape
        n_cols_to_reset = ordered_cols_to_reset.shape[1]
        # Proposal 0 is left all-False: nothing is replaced
        partially_reset_mask = np.zeros((n_cols_to_reset + 1, n_rows, n_cols), dtype=bool)
        for n_reset in range(1, n_cols_to_reset + 1):
            # Proposal n_reset resets what proposal n_reset - 1 already resets, plus the next
            # least important column of each row, so each mask is derived from the previous one.
            partially_reset_mask[n_reset] = partially_reset_mask[n_reset - 1]
            np.put_along_axis(partially_reset_mask[n_reset],
                              ordered_cols_to_reset[:, n_reset - 1:n_reset],
                              True,
                              axis=1)
        return FeatureImportanceReducer._build_fk_from_fk_mask(ref, counterfactuals, partially_reset_mask)

    def _get_most_reset_counterfactuals_indices(self, y_ref, reset_proposals_fk, max_n_cols_to_reset):
        """
        To retrieve the proposal with the maximum number of replacements that is still valid, we need to find the
        highest index that is valid.
        For example, if we have 6 proposals (i.e. 5 resettable columns), for a given sample, we have the
        following valid:
              valid = [1, 1, 1, 0, 1, 0]
        this means that replacements 0, 1, 2 and 4 are valid, while replacements 3 and 5 are not
        in that case we want to get as index: 4, that would mean replacing the 4 least important columns
        To retrieve this information from the `valid` array, we can find the index of the last "1":
              new_index = (len(valid) - 1) - np.argmax(valid[::-1])
                        = (6 - 1) - 1
                        = 4
        If no proposal at all is valid for a sample, index 0 (no replacement) is returned for it, so that
        the sample is left untouched rather than replaced by a proposal that was not validated.

        :param reset_proposals_fk: pd.DataFrame of shape (n_rows * (max_n_cols_to_reset + 1), n_cols)
        :return: integer array of shape (n_rows,) with the index of the selected proposal for each row
        """
        # warning: in DSS, the prediction is not always probas.argmax()
        partially_reset_preds = self.model.predict(reset_proposals_fk)

        if self.target is not None:
            validity_mask = np.isin(partially_reset_preds, self.target)
        else:
            validity_mask = ~np.isin(partially_reset_preds, y_ref)
        # One line per proposal, one column per counterfactual
        validity_by_proposal = validity_mask.reshape(max_n_cols_to_reset + 1, -1)
        last_valid_indices = max_n_cols_to_reset - np.argmax(validity_by_proposal[::-1, :], axis=0)
        # np.argmax returns 0 on an all-False column, which would wrongly select the fully reset
        # proposal: fall back on proposal 0 for the rows that have no valid proposal.
        return np.where(validity_by_proposal.any(axis=0), last_valid_indices, 0)

    def _reduce_l0_norm(self, ref, y_ref, counterfactuals):
        """
        Trying to replace produced counterfactuals with reference values as much as possible to see whether we can have
        counterfactuals that are closer to the reference

        Because trying out all combinations of replacement would be intractable, we use a quicker algorithm:
        * First, for each counterfactual sample, we compute a "feature importance" order, so that we know which feature
          has the most impact on the prediction for this sample
        * Then, we replace the least impactful column by the reference, and see if it still produces a counterfactual,
          then we replace the 2 least impactful columns by the reference, and so on, until replacing n_resettable_cols
        * For each counterfactual sample, we take the replacement proposal with the most replacements that still
          produces a counterfactual

        In order to optimize the speed of the function, we try as much as possible to leverage vectorized operations
        from numpy and pandas, and we try to make as few calls to `predict` as possible. That leads us to build
        `frankensteins` (with `_fk` in the name), big arrays where we try out several combinations.

        :type ref: pd.DataFrame
        :type y_ref: int
        :type counterfactuals: pd.DataFrame
        :return: Modified counterfactuals with less difference compared to `ref`, in order to produce results close to
                 the reference. Rows are in the same order as in `counterfactuals`, with a default integer index.
                 When no column can be reset, `counterfactuals` itself is returned.
        """
        # The values of a given column cannot be reset if the corresponding value in the reference doesn't
        # respect the constraints of feature_domains.
        # Here, we find the list of columns that can potentially be reset.
        resettable_cols_mask = self.check_is_valid(ref).iloc[0]
        resettable_cols = resettable_cols_mask.values.nonzero()[0]
        if len(resettable_cols) == 0:
            return counterfactuals

        feature_importance = self._get_feature_importance(ref, y_ref, counterfactuals, resettable_cols)

        ordered_cols_to_reset = self._get_ordered_cols_to_reset(feature_importance, resettable_cols)

        # With these cols to reset ordered from the least important to the most important, we can build all our
        # combinations of replacements of the counterfactuals by the reference values.
        reset_proposals_fk = self._build_reset_proposals_fk(ref, counterfactuals, ordered_cols_to_reset)

        # Retrieve the proposals with the maximum number of replacements that are still valid counterfactuals
        most_reset_counterfactuals_indices = self._get_most_reset_counterfactuals_indices(y_ref,
                                                                                          reset_proposals_fk,
                                                                                          len(resettable_cols))

        # Finally, we can use the computed indices to retrieve the final counterfactuals from the reset_proposals_fk
        n_rows, n_cols = counterfactuals.shape
        n_proposals = len(resettable_cols) + 1
        final_counterfactuals = np.take_along_axis(
            reset_proposals_fk.values.reshape((n_proposals, n_rows, n_cols)),
            # One proposal index per row, broadcast over the columns
            most_reset_counterfactuals_indices[np.newaxis, :, np.newaxis],
            axis=0
        ).reshape((-1, n_cols))
        return pd.DataFrame(final_counterfactuals, columns=counterfactuals.columns).astype(counterfactuals.dtypes)

    def reduce(self, x_ref, y_ref, counterfactuals):
        """
        Reduce the counterfactuals by resetting as many features as possible to the reference values.

        :type x_ref: pd.DataFrame
        :type y_ref: int
        :type counterfactuals: pd.DataFrame
        :return: pd.DataFrame with the reduced counterfactuals
        """
        return self._reduce_l0_norm(x_ref, y_ref, counterfactuals)
