from abc import ABCMeta
from enum import Enum
import logging
import numbers

import numpy as np
import pandas as pd
from six import add_metaclass

from dataiku.doctor.exploration.emu.algorithms import DiverseEvolutionaryOutcomeOptimizer
from dataiku.doctor.exploration.emu.algorithms import EfficientEvolutionaryOutcomeOptimizer
from dataiku.doctor.exploration.emu.generators import BaseGenerator
from dataiku.doctor.exploration.emu.handlers import NumericalNormalOOHandler
from dataiku.doctor.exploration.emu.handlers import CategoryDistributionOOHandler

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class SpecialTarget(Enum):
    """
    Non-numeric optimization objectives accepted as a generator ``target``.

    ``MIN`` asks for points with the lowest model prediction and ``MAX`` for
    points with the highest one.
    """
    MIN = "MIN"  # minimize the model prediction
    MAX = "MAX"  # maximize the model prediction


@add_metaclass(ABCMeta)
class BaseOOGenerator(BaseGenerator):
    """
    Base class for outcome optimization generators.

    It turns the optimization ``target`` into a loss function for the search
    algorithm and seeds the search population with the reference point plus
    fully non-null training rows.
    """
    def __init__(self, model, target=None):
        """
        Create an outcome optimization generator.

        :param model: A trained model exposing ``predict`` (subclasses also
            use its ``preprocess`` attribute).
        :param SpecialTarget or int or float target: the objective, either
            ``SpecialTarget.MIN``, ``SpecialTarget.MAX`` or a real number to reach
        """
        super(BaseOOGenerator, self).__init__(model)
        self.target = target
        # Filled by `fit`; used to seed the population in `optimize`.
        self.non_null_samples = None

    def fit(self, X, y, feature_domains):
        """
        Keep the rows of ``X`` without any missing value, then delegate to
        ``BaseGenerator.fit``.

        :param pd.DataFrame X: training features
        :param y: training target, forwarded to the parent ``fit``
        :param feature_domains: feature domains, forwarded to the parent ``fit``
        """
        self.non_null_samples = X.loc[X.notna().all(axis=1)]
        super(BaseOOGenerator, self).fit(X, y, feature_domains)

    def _get_loss_function(self):
        """
        Get the loss function that uses the model and depends on the objective.

        Lower loss is better (which is why ``MAX`` negates the prediction):

        - ``SpecialTarget.MIN``: the raw prediction;
        - ``SpecialTarget.MAX``: the negated prediction;
        - a real number: absolute distance between the target and the prediction.

        Numeric targets are detected with ``numbers.Real`` so that NumPy
        scalars such as ``np.int64`` or ``np.float32`` (which do not subclass
        ``int``/``float``) are accepted too.

        :return: function that evaluates a point and returns a numeric value
        :raises ValueError: if the target is neither MIN, MAX nor a real number
        """
        # Capture the current target/predictor so the returned loss stays
        # consistent even if `self.target` is reassigned later.
        target = self.target
        predict = self.model.predict
        if target is SpecialTarget.MIN:
            return predict
        if target is SpecialTarget.MAX:
            return lambda x: -predict(x)
        if isinstance(target, numbers.Real):
            return lambda x: np.abs(target - predict(x))
        raise ValueError("Target should either be MIN, MAX, or a numeric value")

    def optimize(self, reference, n=50):
        """
        Generate points whose model prediction gets closer to ``target``.

        The initial population is made of ``n`` rows sampled with replacement
        from the reference point and the non-null training rows kept by ``fit``.

        :param pd.DataFrame reference: reference point (df with one single row)
        :param int n: number of optima to find (also used as the population size)
        :return: the result of ``self.algorithm.find_optima``
        """
        # The reference cannot contain missing values, so we concatenate it just in case `non_null_samples` is empty.
        history = pd.concat([reference, self.non_null_samples]).sample(n=n, replace=True)
        return self.algorithm.find_optima(history, pop_size=n)


@add_metaclass(ABCMeta)
class EvolutionaryOOGenerator(BaseOOGenerator):
    """
    Outcome optimization generator relying on an evolutionary strategy.

    Categorical features get a ``CategoryDistributionOOHandler`` and numerical
    features a ``NumericalNormalOOHandler``.
    """
    def _get_new_categorical_handler(self, feature_domain):
        """
        Build the handler for one categorical feature.

        :type feature_domain: CategoricalFeatureDomain
        :rtype: BaseCategoryFeatureHandler
        """
        handler = CategoryDistributionOOHandler(feature_domain=feature_domain)
        return handler

    def _get_new_numerical_handler(self, feature_domain):
        """
        Build the handler for one numerical feature.

        :type feature_domain: NumericalFeatureDomain
        :rtype: BaseNumericalFeatureHandler
        """
        # NOTE(review): `distance_name` is not set in this file; presumably
        # provided by BaseGenerator — confirm.
        distance = self.distance_name
        handler = NumericalNormalOOHandler(feature_domain=feature_domain,
                                           distance_name=distance)
        return handler


class DiverseEvolutionaryOOGenerator(EvolutionaryOOGenerator):
    """
    Evolutionary outcome optimization generator focused on the diversity of the results.
    """
    def _get_new_algorithm(self):
        """
        Build the search algorithm for this generator.

        The optimizer is fed with the feature handlers, the loss derived from
        the target and the model's preprocessing.
        (Presumably stored as ``self.algorithm`` by BaseGenerator — confirm.)

        :return: the algorithm that will generate the points
        :rtype: DiverseEvolutionaryOutcomeOptimizer
        """
        loss = self._get_loss_function()
        return DiverseEvolutionaryOutcomeOptimizer(handlers=self.handlers,
                                                   loss=loss,
                                                   preprocess=self.model.preprocess)


class EfficientEvolutionaryOOGenerator(EvolutionaryOOGenerator):
    """
    Evolutionary outcome optimization generator focused on the computation
    time and the score of the best sample.
    """
    def _get_new_algorithm(self):
        """
        Build the search algorithm for this generator.

        The optimizer is fed with the feature handlers, the loss derived from
        the target and the model's preprocessing.
        (Presumably stored as ``self.algorithm`` by BaseGenerator — confirm.)

        :return: the algorithm that will generate the points
        :rtype: EfficientEvolutionaryOutcomeOptimizer
        """
        loss = self._get_loss_function()
        return EfficientEvolutionaryOutcomeOptimizer(handlers=self.handlers,
                                                     loss=loss,
                                                     preprocess=self.model.preprocess)
