import numpy as np
import sklearn
from scipy.stats import gmean

from dataiku.base.utils import package_is_at_least
from dataiku.doctor.utils.skcompat.utils import _replace_value
from dataiku.doctor.utils.skcompat.utils import _swap_variables

if package_is_at_least(sklearn, '0.23'):
    from sklearn.dummy import DummyClassifier, DummyRegressor


    # We're going to recreate sklearn objects compatible with sklearn 0.24+ from pickled
    # 0.20 objects. That means we have access to the `state` dictionary of the old objects,
    # and can deduce from it the correct values for the new classes.
    # Luckily, sklearn mostly shuffled values around but preserved its logic, so it's mostly a matter of
    # reading both versions' source code and mapping things correctly.

    # A quick guide on how to do such a mapping:
    # 1. Notice that, for example, you can't unpickle an object of type "BinomialDeviance", because the attribute "LogOddsEstimator" it relies on is missing
    # 2. Check in the 0.20.4 scikit code what this is: https://github.com/scikit-learn/scikit-learn/blob/0.20.4/sklearn/ensemble/gradient_boosting.py#L781
    # 3. Check in the newer scikit code (e.g. 1.0.2) what it now is: https://github.com/scikit-learn/scikit-learn/blob/1.0.2/sklearn/ensemble/_gb_losses.py#L658
    # 4. We understand that instead of unpickling towards a LogOddsEstimator, we must provide an instance of DummyClassifier
    # 5. We plug into the __setstate__ of our class to build such a classifier. Through careful reading of the code of DummyClassifier, and trial and error, we
    # understand what the attributes of this class are, what they mean, and how to deduce them from what we have.
    # For example, DummyClassifier needs proba_0 and proba_1, and we find that the state contains 'prior', computed here: https://github.com/scikit-learn/scikit-learn/blob/0.20.4/sklearn/ensemble/gradient_boosting.py#L187
    # From this formula, we deduce how to get proba_1 (aka "pos" in the referred-to code).
    # For the other fields, the simplest way is to put a breakpoint in the scikit code and find out what the default values are from there.

    # The two estimators we support for binary classification
    class LogOddsEstimatorDummyClassifier(DummyClassifier, object):
        """
        Stand-in for sklearn 0.20's LogOddsEstimator, the init estimator of the
        binomial deviance loss.

        The old estimator stored `prior`, the log odds of the positive class; it is
        converted here into the class probabilities a DummyClassifier works with.
        """

        def __setstate__(self, state):
            # prior = scale * log(proba_1 / (1 - proba_1)). Reversing the equality gives
            # proba_1 = exp(prior / scale) / (1 + exp(prior / scale)), and of course
            # proba_0 = 1 - proba_1.
            # Here, scale = 1.
            proba_1 = np.exp(state['prior']) / (1 + np.exp(state['prior']))
            self.classes_ = ['0', '1']
            self.n_classes_ = 2
            self.class_prior_ = [1 - proba_1, proba_1]
            self.n_outputs_ = 1
            # Set the public `strategy` parameter too (as PriorProbabilityEstimatorDummyClassifier
            # does), not only the private `_strategy`: sklearn's get_params(), and thus repr()
            # and clone(), read every constructor parameter as an attribute and raise an
            # AttributeError when one is missing.
            self.strategy = 'prior'
            self._strategy = 'prior'
            self.random_state = None
            self.constant = None
            self.sparse_output_ = False


    class ScaledLogOddsEstimatorDummyClassifier(DummyClassifier, object):
        """
        Stand-in for sklearn 0.20's ScaledLogOddsEstimator, the init estimator of the
        exponential loss.

        The old estimator stored `prior` = scale * log odds of the positive class, with
        scale = 1/2; it is converted here into class probabilities.
        """

        def __setstate__(self, state):
            # Same equation as for LogOddsEstimatorDummyClassifier, with scale = 1/2:
            # proba_1 = exp(2 * prior) / (1 + exp(2 * prior)) and proba_0 = 1 - proba_1.
            proba_1 = np.exp(2 * state['prior']) / (1 + np.exp(2 * state['prior']))
            self.classes_ = ['0', '1']
            self.n_classes_ = 2
            self.class_prior_ = [1 - proba_1, proba_1]
            self.n_outputs_ = 1
            # Set the public `strategy` parameter too (as PriorProbabilityEstimatorDummyClassifier
            # does), not only the private `_strategy`: sklearn's get_params(), and thus repr()
            # and clone(), read every constructor parameter as an attribute and raise an
            # AttributeError when one is missing.
            self.strategy = 'prior'
            self._strategy = 'prior'
            self.random_state = None
            self.constant = None
            self.sparse_output_ = False


    # The only estimator we support for multiclass
    class PriorProbabilityEstimatorDummyClassifier(DummyClassifier, object):
        """
        Stand-in for sklearn 0.20's PriorProbabilityEstimator (multiclass).

        Nothing special: the stored priors are used directly. We never use fancy
        sklearn features like multioutput, so everything else gets dummy default values.
        """

        def __setstate__(self, state):
            priors = state['priors']
            # Classes are labelled 0 .. K-1, in the order of the stored priors.
            self.classes_ = list(range(len(priors)))
            self.n_classes_ = len(priors)
            self.class_prior_ = list(priors)
            self.n_outputs_ = 1
            self.strategy = 'prior'
            self._strategy = 'prior'
            self.random_state = None
            self.constant = None
            self.sparse_output_ = False


    # The two estimators we support for regression (which are used by 3 loss functions afterwards)
    class MeanEstimatorDummyRegressor(DummyRegressor, object):
        """
        Stand-in for sklearn 0.20's MeanEstimator: a DummyRegressor using the 'mean'
        strategy, whose constant_ is the mean stored by the old estimator.
        """

        def __setstate__(self, state):
            stored_mean = np.reshape(state['mean'], (1, -1))
            self.strategy = self._strategy = 'mean'
            # constant_ and constant share the same reshaped array
            self.constant_ = self.constant = stored_mean
            self.quantile = None
            self.n_outputs_ = 1


    class QuantileEstimatorDummyRegressor(DummyRegressor, object):
        """
        Stand-in for sklearn 0.20's QuantileEstimator: a DummyRegressor using the
        'quantile' strategy, whose constant_ is the quantile stored by the old estimator.
        """

        def __setstate__(self, state):
            self.strategy = 'quantile'
            self._strategy = 'quantile'
            self.constant_ = np.reshape(state['quantile'], (1, -1))
            self.constant = self.constant_
            # sklearn 0.20's QuantileEstimator keeps the quantile level it was fitted
            # with in `alpha`; report that level instead of assuming the median.
            # Fall back to the median (.5) when the state has no `alpha`.
            self.quantile = state.get('alpha', .5)
            self.n_outputs_ = 1


if package_is_at_least(sklearn, '0.23') and not package_is_at_least(sklearn, '1.4'):

    from sklearn.ensemble._gb_losses import MultinomialDeviance


    class Sk020MultinomialDeviance(MultinomialDeviance, object):
        '''
        MultinomialDeviance keeping the sklearn 0.20 behaviour of get_init_raw_predictions.

        A bug has been fixed in MultinomialDeviance::get_init_raw_predictions since
        sklearn 0.20. For models migrated from sklearn 0.20 to 0.23+, we deliberately
        keep the old (buggy) behaviour so that scoring stays identical after unpickling.
        '''

        def get_init_raw_predictions(self, X, estimator):
            # Return the init estimator's class probabilities directly
            init_probas = estimator.predict_proba(X)
            return init_probas


def get_gbt_regression_baseline(gbt):
    """
    Return the initial raw prediction (baseline) of a gradient boosting regressor.

    :param gbt: a fitted sklearn GradientBoostingRegressor
    :return: the constant predicted by its init estimator (the mean or the
        quantile of the training target, depending on the loss), as a Python float
    """
    if package_is_at_least(sklearn, "0.21"):
        # In newer versions of sklearn, we can use DummyRegressor's `constant_`
        # attribute that will automatically contain `mean` or `quantile`
        # depending on the chosen strategy.
        # constant_ may be an array with ndim > 0 (e.g. shape (1, 1)); extract its
        # single element explicitly, since float() on such an array is deprecated
        # since NumPy 1.25.
        baseline = np.asarray(gbt.init_.constant_).item()
    else:
        # In previous versions of sklearn, depending on the type of the init_
        # regressor, we must use either the `mean` or the `quantile` attribute.
        baseline = gbt.init_.mean if hasattr(gbt.init_, "mean") else gbt.init_.quantile
    return float(baseline)


def get_gbt_binary_classification_baseline(gbt):
    """
    Return the initial raw prediction (baseline) of a binary gradient boosting
    classifier, as a one-element list.

    The baseline is the minimizer of the loss given the class priors:
      - the log odds ratio for the binomial deviance loss ("deviance", renamed
        "log_loss" in sklearn 1.3),
      - half the log odds ratio for the "exponential" loss.

    :param gbt: a fitted sklearn GradientBoostingClassifier with two classes
    :return: [baseline]
    """
    if not package_is_at_least(sklearn, "0.21"):
        # Before sklearn 0.21, the init estimator's `prior` attribute already holds
        # the log odds ratio (halved for the "exponential" loss).
        return [gbt.init_.prior]

    # Since sklearn 0.21, only the class distribution (class_prior_) is available,
    # so the log odds ratio has to be recomputed from it.
    log_odds = _get_log_odds_ratio(gbt.init_.class_prior_)
    if gbt.loss != "exponential":
        return [log_odds]
    # The minimizer of the exponential loss is .5 * log odds ratio.
    return [log_odds * .5]


def get_gbt_multiclass_classification_baseline(gbt):
    """
    Return the initial raw predictions (baseline) of a multiclass gradient boosting
    classifier, one value per class.

    How the baseline is derived from the weighted class priors depends on sklearn:
      - sklearn >= 1.4: log(priors / geometric_mean(priors)), i.e. the log priors
        minus the log of their geometric mean;
      - 0.21 <= sklearn < 1.4: log(priors);
      - sklearn < 0.21: the priors themselves.

    :param gbt: a fitted sklearn GradientBoostingClassifier
    :return: a list with one baseline value per class
    """
    if not package_is_at_least(sklearn, "0.21"):
        # Old init estimators expose the weighted priors directly.
        return list(gbt.init_.priors)

    class_priors = gbt.init_.class_prior_
    if not package_is_at_least(sklearn, "1.4"):
        return list(np.log(class_priors))

    # Since sklearn 1.4, the log priors are centered by the log of their geometric mean.
    return list(np.log(class_priors / gmean(class_priors)))


def _get_log_odds_ratio(priors):
    """
    Return the log odds ratio log(p / (1 - p)) of a binary class distribution.

    :param priors: the weighted prior probabilities of classes 0 and 1 (e.g. a
        DummyClassifier's class_prior_). priors[1] plays the role of p, the
        probability of class 1, and priors[0] the role of (1 - p).
    :return: log(priors[1] / priors[0])
    """
    p_negative, p_positive = priors[0], priors[1]
    return np.log(p_positive / p_negative)


if package_is_at_least(sklearn, '0.22'):
    from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
else:
    from sklearn.ensemble.gradient_boosting import GradientBoostingRegressor, GradientBoostingClassifier


def update_gradient_boosting_model_state(state):
    """
    Migrate, in place, the pickled state of a gradient boosting model to the names
    expected by the installed sklearn version.

    Handles renaming of:
    - attribute n_features_ to n_features_in_ in scikit 1
    - attribute loss_ to _loss in scikit 1.1
    - loss name lad to absolute_error in scikit 1.2
    - loss name ls to squared_error in scikit 1.2
    - loss name deviance to log_loss in scikit 1.3

    :param dict state: the state dictionary given to __setstate__; modified in place
    """
    # A state without `_sklearn_version` is treated as recent ('999'), so its
    # attributes are not renamed.
    # NOTE(review): versions are compared as plain strings; this orders correctly for
    # the '1' and '1.1' thresholds below, but would not for e.g. '1.10' vs '1.9'.
    pickled_version = state.get('_sklearn_version', '999')

    attribute_renames = (
        ('1', 'n_features_', 'n_features_in_'),
        ('1.1', 'loss_', '_loss'),
    )
    for threshold, old_name, new_name in attribute_renames:
        if pickled_version < threshold and package_is_at_least(sklearn, threshold):
            _swap_variables(state, old_name, new_name)

    # Loss names are renamed whatever version the state was pickled with.
    if package_is_at_least(sklearn, "1.3"):
        _replace_value(state, "loss", "deviance", "log_loss")
    if package_is_at_least(sklearn, "1.2"):
        for old_loss, new_loss in (("lad", "absolute_error"), ("ls", "squared_error")):
            _replace_value(state, "loss", old_loss, new_loss)


class UnpicklableGradientBoostingRegressor(GradientBoostingRegressor, object):
    """
    GradientBoostingRegressor whose pickled state, possibly written by an older
    sklearn version, is migrated to the installed sklearn version when unpickled.
    """

    def __setstate__(self, d):
        update_gradient_boosting_model_state(d)
        super(UnpicklableGradientBoostingRegressor, self).__setstate__(d)

        # sklearn 1.4 removes module sklearn.ensemble._gb_losses and all its losses:
        # on older versions the unpickled loss object is kept as is.
        if not package_is_at_least(sklearn, "1.4"):
            return
        if isinstance(self._loss, RemovedGradientBoostingLoss):
            # Rebuild the _loss object from the "loss" parameter with the new classes.
            # Sample weights are not available at this stage and set to None, but they
            # are only needed for training.
            self._loss = self._get_loss(None)


class UnpicklableGradientBoostingClassifier(GradientBoostingClassifier, object):
    """
    GradientBoostingClassifier whose pickled state, possibly written by an older
    sklearn version, is migrated to the installed sklearn version when unpickled.
    """

    def __setstate__(self, d):
        update_gradient_boosting_model_state(d)
        super(UnpicklableGradientBoostingClassifier, self).__setstate__(d)

        # sklearn 1.4 removes module sklearn.ensemble._gb_losses and all its losses:
        # on older versions the unpickled loss object is kept as is.
        if not package_is_at_least(sklearn, "1.4"):
            return
        if isinstance(self._loss, RemovedGradientBoostingLoss):
            # Rebuild the _loss object from the "loss" parameter with the new classes.
            # Sample weights are not available at this stage and set to None, but they
            # are only needed for training.
            self._loss = self._get_loss(None)


# Empty class used to allow unpickling of removed sklearn<1.4 loss classes
class RemovedGradientBoostingLoss(object):
    """
    Placeholder allowing loss objects of removed sklearn<1.4 loss classes to be
    unpickled.

    The pickled state is deliberately discarded: the gradient boosting models'
    __setstate__ rebuilds the actual loss from their `loss` parameter.
    """

    def __setstate__(self, d):
        # Ignore the pickled state on purpose; nothing from the old loss is reused.
        return None
