import numpy as np
import sklearn
from numpy.lib import recfunctions
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier, ExtraTreeClassifier, ExtraTreeRegressor
from sklearn.tree._tree import Tree
from sklearn.utils.validation import check_is_fitted

from dataiku.base.utils import package_is_at_least

# Key written (with value True) into an unpickled state by sklearn_1_4__add_new_tree_params
# when that state has no "monotonic_cst" entry. The predict_proba overrides in this module
# test for it with hasattr() to decide whether to use the scikit-learn 1.3.2 implementation.
DKU_SKLEARN_PRE_14_FLAG = '__dku_is_sklearn_pre_14'

def sklearn_1_4__add_new_tree_params(d):
    """Back-fill, in place, the entries that an older pickled state is missing.

    Nothing happens unless the installed scikit-learn is at least 1.4. Otherwise:

    * ``ccp_alpha`` is set to ``0`` when absent;
    * ``monotonic_cst`` is set to ``None`` when absent, and in that case the state
      is also tagged with ``DKU_SKLEARN_PRE_14_FLAG`` set to ``True``.

    Parameters
    ----------
    d : dict
        State dictionary being unpickled. It is mutated in place.
    """
    # Background on the "ccp_alpha" / "monotonic_cst" parameters and their defaults (0 and None):
    # https://github.com/scikit-learn/scikit-learn/pull/13649/files#diff-29904759e842b861caa6453a852d5f4eb958b8ca478ce6a1a55612c2832f10adR1282
    if not package_is_at_least(sklearn, "1.4"):
        return

    d.setdefault('ccp_alpha', 0)

    if "monotonic_cst" in d:
        return

    # Only states lacking "monotonic_cst" get the marker.
    d['monotonic_cst'] = None
    d[DKU_SKLEARN_PRE_14_FLAG] = True


# Version of DecisionTreeClassifier.predict_proba used in sklearn 1.3.2
# In 1.4.0 the tree_.value attribute in tree.DecisionTreeClassifier, tree.DecisionTreeRegressor, tree.ExtraTreeClassifier and tree.ExtraTreeRegressor
# changed from a weighted absolute count of number of samples to a weighted fraction of the total number of samples.
# This causes prediction mismatches for models trained before 1.4.0. As a workaround, we override the method to use the previous version when unpickling such a model.
# https://scikit-learn.org/stable/whats_new/v1.4.html#changed-models
# https://github.com/scikit-learn/scikit-learn/pull/27639/files?diff=split&w=0
def sklearn_1_3_2__DecisionTreeClassifier__predict_proba(decision_tree_object, X, check_input=True):
    """Compute class probabilities the way scikit-learn 1.3.2 did.

    The values returned by ``tree_.predict`` are normalised per sample so that each
    row sums to 1 (rows whose sum is 0 are left untouched).

    Parameters
    ----------
    decision_tree_object : fitted tree classifier
        Estimator whose ``tree_``, ``n_outputs_`` and ``n_classes_`` are used.
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input samples, validated through ``_validate_X_predict``.
    check_input : bool, default=True
        Forwarded to ``_validate_X_predict``; allows bypassing input checks.
        Don't use this parameter unless you know what you're doing.

    Returns
    -------
    proba : ndarray of shape (n_samples, n_classes), or a list of ``n_outputs_``
        such arrays when ``n_outputs_`` is not 1.
        The class probabilities of the input samples, in the order of the
        estimator's ``classes_`` attribute.
    """

    def _normalize_rows(class_values):
        # Divide every row by its sum, in place; a zero sum is replaced by 1
        # so that the division is a no-op instead of a division by zero.
        row_totals = class_values.sum(axis=1)[:, np.newaxis]
        row_totals[row_totals == 0.0] = 1.0
        class_values /= row_totals
        return class_values

    check_is_fitted(decision_tree_object)
    X = decision_tree_object._validate_X_predict(X, check_input)
    raw_values = decision_tree_object.tree_.predict(X)

    if decision_tree_object.n_outputs_ != 1:
        # Multi-output: one probability array per output, each trimmed to that
        # output's own number of classes.
        return [
            _normalize_rows(raw_values[:, output_idx, : decision_tree_object.n_classes_[output_idx]])
            for output_idx in range(decision_tree_object.n_outputs_)
        ]

    return _normalize_rows(raw_values[:, : decision_tree_object.n_classes_])


class UnpicklableTree(Tree, object):
    """``Tree`` subclass that patches an older pickled state before restoring it."""

    def __setstate__(self, d):
        """Upgrade the state dictionary ``d`` in place, then hand it to ``Tree.__setstate__``.

        Parameters
        ----------
        d : dict
            Pickled tree state; its ``"nodes"`` entry is a structured array.
        """
        if package_is_at_least(sklearn, "1.3"):
            node_records = d["nodes"]
            if "missing_go_to_left" not in node_records.dtype.names:
                # From scikit-learn 1.3 on, node records carry a "missing_go_to_left" field.
                # A state without it was produced by an earlier version and is meant to keep
                # the old behaviour, so the field is added with 1 forced on every node.
                go_left_column = np.ones(len(node_records), np.uint8)
                d["nodes"] = recfunctions.append_fields(
                    node_records, 'missing_go_to_left', go_left_column, usemask=False
                )

        sklearn_1_4__add_new_tree_params(d)

        super().__setstate__(d)


class UnpicklableDecisionTreeRegressor(DecisionTreeRegressor, object):
    """``DecisionTreeRegressor`` that back-fills missing state entries when unpickled."""

    def __setstate__(self, d):
        """Patch the pickled state ``d`` in place, then restore it through the parent classes.

        Parameters
        ----------
        d : dict
            Pickled estimator state; mutated by ``sklearn_1_4__add_new_tree_params``.
        """
        sklearn_1_4__add_new_tree_params(d)
        # Zero-argument super() starts the MRO lookup right after this class, so a
        # __setstate__ defined on DecisionTreeRegressor itself would not be skipped
        # (super(DecisionTreeRegressor, self) starts the lookup one class too late).
        super().__setstate__(d)


class UnpicklableDecisionTreeClassifier(DecisionTreeClassifier, object):
    """``DecisionTreeClassifier`` that back-fills missing state entries when unpickled
    and keeps the scikit-learn 1.3.2 probability computation for flagged models."""

    def __setstate__(self, d):
        """Patch the pickled state ``d`` in place, then restore it through the parent classes.

        Parameters
        ----------
        d : dict
            Pickled estimator state; mutated by ``sklearn_1_4__add_new_tree_params``.
        """
        sklearn_1_4__add_new_tree_params(d)
        # Zero-argument super() starts the MRO lookup right after this class, so a
        # __setstate__ defined on DecisionTreeClassifier itself would not be skipped
        # (super(DecisionTreeClassifier, self) starts the lookup one class too late).
        super().__setstate__(d)

    def predict_proba(self, X, check_input=True):
        """Predict class probabilities for ``X``.

        When scikit-learn is at least 1.4 and this instance carries the
        ``DKU_SKLEARN_PRE_14_FLAG`` attribute, the scikit-learn 1.3.2 implementation
        is used; otherwise the call is delegated to the parent class.

        Parameters
        ----------
        X : input samples, forwarded unchanged.
        check_input : bool, default=True
            Forwarded unchanged.
        """
        if package_is_at_least(sklearn, "1.4") and hasattr(self, DKU_SKLEARN_PRE_14_FLAG):
            return sklearn_1_3_2__DecisionTreeClassifier__predict_proba(self, X, check_input)
        return super().predict_proba(X, check_input)


class UnpicklableRandomForestClassifier(RandomForestClassifier, object):
    """``RandomForestClassifier`` that back-fills missing state entries when unpickled."""

    def __setstate__(self, d):
        """Patch the pickled state ``d`` in place, then restore it through the parent classes.

        Parameters
        ----------
        d : dict
            Pickled estimator state; mutated by ``sklearn_1_4__add_new_tree_params``.
        """
        sklearn_1_4__add_new_tree_params(d)
        # Zero-argument super() starts the MRO lookup right after this class, so a
        # __setstate__ defined on RandomForestClassifier itself would not be skipped
        # (super(RandomForestClassifier, self) starts the lookup one class too late).
        super().__setstate__(d)


class UnpicklableExtraTreeClassifier(ExtraTreeClassifier, object):
    """``ExtraTreeClassifier`` that back-fills missing state entries when unpickled
    and keeps the scikit-learn 1.3.2 probability computation for flagged models."""

    def __setstate__(self, d):
        """Patch the pickled state ``d`` in place, then restore it through the parent classes.

        Parameters
        ----------
        d : dict
            Pickled estimator state; mutated by ``sklearn_1_4__add_new_tree_params``.
        """
        sklearn_1_4__add_new_tree_params(d)
        # Zero-argument super() starts the MRO lookup right after this class, so a
        # __setstate__ defined on ExtraTreeClassifier itself would not be skipped
        # (super(ExtraTreeClassifier, self) starts the lookup one class too late).
        super().__setstate__(d)

    def predict_proba(self, X, check_input=True):
        """Predict class probabilities for ``X``.

        When scikit-learn is at least 1.4 and this instance carries the
        ``DKU_SKLEARN_PRE_14_FLAG`` attribute, the scikit-learn 1.3.2 implementation
        is used; otherwise the call is delegated to the parent class.

        Parameters
        ----------
        X : input samples, forwarded unchanged.
        check_input : bool, default=True
            Forwarded unchanged.
        """
        if package_is_at_least(sklearn, "1.4") and hasattr(self, DKU_SKLEARN_PRE_14_FLAG):
            return sklearn_1_3_2__DecisionTreeClassifier__predict_proba(self, X, check_input)
        return super().predict_proba(X, check_input)


class UnpicklableExtraTreeRegressor(ExtraTreeRegressor, object):
    """``ExtraTreeRegressor`` that back-fills missing state entries when unpickled."""

    def __setstate__(self, d):
        """Patch the pickled state ``d`` in place, then restore it through the parent classes.

        Parameters
        ----------
        d : dict
            Pickled estimator state; mutated by ``sklearn_1_4__add_new_tree_params``.
        """
        sklearn_1_4__add_new_tree_params(d)
        # Zero-argument super() starts the MRO lookup right after this class, so a
        # __setstate__ defined on ExtraTreeRegressor itself would not be skipped
        # (super(ExtraTreeRegressor, self) starts the lookup one class too late).
        super().__setstate__(d)
