from __future__ import division

from abc import ABCMeta
from abc import abstractmethod
from functools import partial
from warnings import catch_warnings
from warnings import simplefilter

import numpy as np
import pandas as pd
from six import add_metaclass
from scipy.optimize import linear_sum_assignment

from dataiku.doctor.exploration.emu.algorithms.utils import init_df_from
from dataiku.doctor.exploration.emu.reducers import FeatureImportanceReducer
from dataiku.doctor.utils.skcompat import get_kmeans_estimator


class CounterfactualsGenerator(object):
    """
    Draws candidate samples around a reference point and keeps only those whose
    prediction qualifies them as counterfactuals.
    """

    def __init__(self, sample_generator, model, target):
        """
        :param SampleGenerator sample_generator: produces candidate samples for a radius interval
        :param model: trained model exposing ``predict(df)``
        :param target: accepted class(es) for the counterfactuals; if None, any prediction
            different from the reference prediction is accepted
        """
        self.sample_generator = sample_generator
        self.model = model
        self.target = target

    def generate_counterfactuals(self, x_ref, y_ref, min_radius, max_radius):
        """
        Sample candidates around ``x_ref`` for radii between ``min_radius`` and
        ``max_radius``, then keep only the ones predicted as counterfactuals.

        :param pd.DataFrame x_ref: reference sample (single row)
        :param y_ref: model prediction for the reference sample
        :param float min_radius: lower bound of the sampling radius
        :param float max_radius: upper bound of the sampling radius
        :rtype: pd.DataFrame
        """
        x_cf_candidates = self.sample_generator.generate_samples(x_ref, y_ref, min_radius, max_radius)
        return self.drop_samples_with_invalid_preds(x_cf_candidates, y_ref)

    @staticmethod
    def merge_counterfactuals(x_cf, x_cf_new):
        """
        Append ``x_cf_new`` to ``x_cf`` and drop duplicate rows.

        Neither input is modified; a new DataFrame is always returned.

        :param pd.DataFrame x_cf: counterfactuals collected so far
        :param pd.DataFrame x_cf_new: newly generated counterfactuals (may be empty)
        :rtype: pd.DataFrame
        """
        if x_cf_new.size > 0:
            x_cf = pd.concat([x_cf, x_cf_new])
        # Not in-place: when nothing was concatenated, x_cf is still the caller's frame
        # and must not be mutated behind its back.
        return x_cf.drop_duplicates()

    def drop_samples_with_invalid_preds(self, df, y_ref):
        """
        Keep only the rows of ``df`` that are counterfactuals.

        A row is a counterfactual if its prediction differs from ``y_ref`` (when
        ``self.target`` is None) or belongs to ``self.target`` otherwise.

        :param pd.DataFrame df: candidate samples
        :param y_ref: model prediction for the reference sample
        :rtype: pd.DataFrame
        """
        predictions = self.model.predict(df)
        if self.target is None:
            is_counterfactual = y_ref != predictions
        else:
            # np.isin supersedes np.in1d (deprecated in NumPy 2.0); ravel keeps in1d's
            # flattening semantics.
            is_counterfactual = np.isin(np.ravel(predictions), self.target)
        return df.iloc[is_counterfactual]


##########################################################################
# SAMPLE GENERATORS
##########################################################################

@add_metaclass(ABCMeta)
class SampleGenerator(object):
    """
    Base class for generators that draw candidate points around a reference
    sample, using per-feature handlers and a radius interval.
    """

    def __init__(self, handlers, batch_size=100):
        """
        :param handlers: per-feature handlers, keyed by feature name
        :param int batch_size: batch size used by ``generate_samples`` implementations
        """
        self.batch_size = batch_size
        self.handlers = handlers

    @abstractmethod
    def generate_samples(self, x_ref, y_ref, min_radius, max_radius):
        """
        Draw a batch of candidate samples around ``x_ref`` for radii between
        ``min_radius`` and ``max_radius``.
        """


class HyperSphereSampleGenerator(SampleGenerator):
    """
    Samples points around the reference on spherical shells whose radius is
    drawn uniformly between ``min_radius`` and ``max_radius``.
    """

    def generate_samples(self, x_ref, y_ref, min_radius, max_radius):
        """
        Draw ``self.batch_size`` samples around ``x_ref``.

        Each sample gets a radius drawn uniformly in [min_radius, max_radius) and a
        randomly signed per-feature offset vector whose Euclidean norm equals that
        radius. Each feature handler then turns its coordinate of the offset vector
        into a concrete feature value.

        :rtype: pd.DataFrame
        """
        n_samples = self.batch_size
        radii = np.random.uniform(min_radius, max_radius, size=n_samples)

        # Random non-negative weight for every (sample, feature) pair...
        offsets = np.random.random((n_samples, x_ref.shape[1]))
        # ...rescaled so that each row's sum of squares equals its radius squared...
        squared = offsets ** 2
        row_scale = radii ** 2 / squared.sum(axis=1)
        offsets = (squared * row_scale[:, None]) ** .5
        # ...with an independent random sign on every coordinate.
        offsets = offsets * np.random.choice([-1, 1], size=offsets.shape)

        samples = init_df_from(x_ref)
        reference_row = x_ref.iloc[0]
        for column, (feature_name, handler) in enumerate(self.handlers.items()):
            samples[feature_name] = handler.generate_cf_values(reference_row[feature_name], y_ref, offsets[:, column])
        return samples


##########################################################################
# SEARCH STRATEGIES
##########################################################################


@add_metaclass(ABCMeta)
class BaseSphereSearch(object):
    """
    Strategy that locates the radius interval in which counterfactuals can be
    generated around a reference sample.
    """

    def __init__(self, counterfactuals_generator, max_n_iter=6):
        """
        :param CounterfactualsGenerator counterfactuals_generator: generates counterfactuals for a given radius interval
        :param int max_n_iter: maximum number of steps used to narrow down (min_radius, max_radius)
        """
        self.max_n_iter = max_n_iter
        self.counterfactuals_generator = counterfactuals_generator

    @abstractmethod
    def find_min_max_radius(self, x_ref, y_ref, return_cf=True):
        """
        Return the smallest (min_radius, max_radius) interval found such that
        counterfactuals can still be generated for radii within it.

        When ``return_cf`` is True, implementations presumably also return the
        counterfactuals encountered during the search.
        """


class ActiveSphereSearch(BaseSphereSearch):
    def find_min_max_radius(self, x_ref, y_ref, return_cf=True):
        """
        Approximate the minimal radius that yields counterfactuals, as an interval
        (min_radius, max_radius).

        Radii are normalized to [0, 1]. A radius near 0 makes the generator sample
        points close to the reference, and a radius near 1 makes it sample far away.
        Finding the smallest productive radius lets us return counterfactuals that
        are as close to the reference as possible.

        The search is dichotomic. At each step the interval width is halved. We then
        try the lower half (min_radius, max_radius - width). If it produces
        counterfactuals, max_radius moves down. Otherwise min_radius moves up. After
        ``i`` steps, max_radius - min_radius == 1 / 2**i. With the default
        max_n_iter of 6, the final width is 1/64 ~= 0.016.

        :return: (min_radius, max_radius), plus the counterfactuals found during the
            search as a third element when ``return_cf`` is True
        """
        collected = init_df_from(x_ref) if return_cf else None
        generator = self.counterfactuals_generator

        lower, upper = 0.0, 1.0  # normalized radii: start from the full [0, 1] range
        step = upper - lower
        for _ in range(self.max_n_iter):
            step /= 2
            candidate_upper = upper - step
            found = generator.generate_counterfactuals(x_ref, y_ref, lower, candidate_upper)
            if found.size == 0:
                # Nothing in the lower half: the minimal radius lies above it.
                lower += step
                continue
            # Still productive with a smaller upper bound: keep it.
            upper = candidate_upper
            if return_cf:
                collected = generator.merge_counterfactuals(collected, found)

        if not return_cf:
            return lower, upper
        return lower, upper, collected


##########################################################################
# GROWING SPHERES ALGORITHMS
##########################################################################


@add_metaclass(ABCMeta)
class BaseGrowingSphere(object):
    """
    Growing-Spheres algorithm to generate counterfactuals for classification models.

    Concrete subclasses provide the pluggable parts:
    - the search strategy that finds a (min_radius, max_radius) interval,
    - the sample generator used to draw candidates,
    - the reducer applied to the counterfactuals found.
    """
    def __init__(self,
                 feature_domains,
                 handlers,
                 model,
                 measure_distance,
                 target=None,
                 with_clustering=True,
                 max_n_iter=40):
        """
        :param FeatureDomains feature_domains: constraints
        :param handlers: to sample values for individual features
        :type handlers: dict[str, (BaseNumericalCFFeatureHandler | BaseCategoryCFFeatureHandler | FrozenFeatureHandler)]
        :param sklearn.base.BaseEstimator model: a trained model
        :param function measure_distance: function that returns the distance between two samples
        :param target: class for the counterfactuals, if None: any class but the reference's is considered as target
        :param bool with_clustering: flag to activate clustering step
        :param int max_n_iter: maximum number of iterations when trying to generate counterfactuals
        """
        self.feature_domains = feature_domains
        self.model = model
        self.measure_distance = measure_distance
        self.target = target
        self.with_clustering = with_clustering
        self.max_n_iter = max_n_iter
        sample_generator = self._get_sample_generator(handlers)
        self.counterfactuals_generator = CounterfactualsGenerator(sample_generator, model, target)

    @abstractmethod
    def _get_search_strategy(self):
        """
        :rtype: BaseSphereSearch
        """

    @abstractmethod
    def _get_reducer(self):
        """
        :rtype: BaseReducer
        """

    @abstractmethod
    def _get_sample_generator(self, handlers):
        """
        :rtype: SampleGenerator
        """

    def _generate_counterfactuals_from_radius(self, x_ref, y_ref, min_radius, max_radius, n_points):
        """
        Sample counterfactuals with radii in [min_radius, max_radius].

        Stops once at least ``n_points`` unique rows are collected, or after
        ``self.max_n_iter`` batches.

        :rtype: pd.DataFrame
        """
        x_cf = init_df_from(x_ref)
        iteration = 0
        while x_cf.shape[0] < n_points and iteration < self.max_n_iter:
            x_cf_new = self.counterfactuals_generator.generate_counterfactuals(x_ref, y_ref, min_radius, max_radius)
            x_cf = self.counterfactuals_generator.merge_counterfactuals(x_cf, x_cf_new)
            iteration += 1
        return x_cf

    def _filter_with_clustering(self, x_cf, n_points):
        """
        Select ``n_points`` diverse counterfactuals.

        The preprocessed samples are clustered with k-means into ``n_points``
        clusters, and one distinct sample close to each centroid is kept. Only
        called when ``x_cf`` has more than ``n_points`` rows.

        :rtype: pd.DataFrame
        """
        preprocessed_x_cf = self.model.preprocess(x_cf)
        with catch_warnings():
            simplefilter("ignore")  # we don't care if some clusters are duplicates
            kmeans = get_kmeans_estimator(n_clusters=n_points, random_state=0).fit(preprocessed_x_cf)
        # Presumably (n_samples, n_points): distance of every sample to every centroid.
        distances_to_centroids = kmeans.transform(preprocessed_x_cf)

        # Sometimes one sample is the closest to two centroids. In that case we want
        # the second closest one, and so on. linear_sum_assignment solves this by
        # matching each centroid to a distinct sample while minimizing the total
        # distance.
        indices_closest_to_centroids = linear_sum_assignment(distances_to_centroids)[0]
        return x_cf.iloc[indices_closest_to_centroids]

    def _filter_n_closest_points(self, x_cf, x_ref, n_points):
        """
        Return the (at most) ``n_points`` rows of ``x_cf`` closest to the reference
        sample, ordered by increasing ``self.measure_distance``.

        :rtype: pd.DataFrame
        """
        if x_cf.shape[0] == 0:
            # np.vectorize with a `signature` raises on size-0 inputs when no otypes
            # are given; there is nothing to rank anyway.
            return x_cf
        distances = np.vectorize(partial(self.measure_distance, x_ref.iloc[[0]]), signature='(x)->()')(x_cf)
        return x_cf.iloc[np.argsort(distances)[:n_points]]

    def generate_counterfactuals(self, x_ref, n_points):
        """
        Generate Growing-Spheres counterfactuals for an input sample.

        The steps are:
        1. Find a radius interval with the search strategy.
        2. Sample more counterfactuals inside that interval.
        3. Reduce them and drop duplicates.
        4. Optionally keep a diverse subset via clustering.
        5. Keep the ``n_points`` closest to ``x_ref``.

        :param x_ref: original sample - reference point (df with one single row)
        :param int n_points: desired number of counterfactuals (max number of counterfactuals to return)
        :return: at most ``n_points`` counterfactuals (possibly empty)
        :rtype: pd.DataFrame
        """
        y_ref = self.model.predict(x_ref)[0].astype(int)
        search_strategy = self._get_search_strategy()
        min_radius, max_radius, x_cf = search_strategy.find_min_max_radius(x_ref, y_ref)
        x_cf_new = self._generate_counterfactuals_from_radius(x_ref, y_ref, min_radius, max_radius, n_points)
        x_cf = self.counterfactuals_generator.merge_counterfactuals(x_cf, x_cf_new)

        if x_cf.size == 0:
            return x_cf

        reducer = self._get_reducer()
        # Not in-place: the reducer's output may share data with another frame.
        x_cf = reducer.reduce(x_ref, y_ref, x_cf).drop_duplicates()

        # Clustering only helps when there are more candidates than requested and
        # more than one point is requested.
        if x_cf.shape[0] > n_points > 1 and self.with_clustering:
            x_cf = self._filter_with_clustering(x_cf, n_points)

        return self._filter_n_closest_points(x_cf, x_ref, n_points)


class ActiveSphere(BaseGrowingSphere):
    """
    Active-Spheres algorithm to generate counterfactuals for classification models.

    Combines three parts:
    - a hypersphere sample generator,
    - the dichotomic ActiveSphereSearch for the radius interval,
    - a FeatureImportanceReducer as post-processing step.
    """
    def _get_sample_generator(self, handlers):
        """Build the generator that samples candidates on hyperspheres around the reference."""
        return HyperSphereSampleGenerator(handlers)

    def _get_search_strategy(self):
        """Build the dichotomic radius search bound to this instance's counterfactuals generator."""
        return ActiveSphereSearch(self.counterfactuals_generator)

    def _get_reducer(self):
        """Build the reducer applied to the generated counterfactuals."""
        return FeatureImportanceReducer(self.target, self.model, self.feature_domains)
