import logging
import sklearn

from dataiku.base.utils import package_is_at_least
from sklearn.cluster import KMeans

logger = logging.getLogger(__name__)


class UnpickableKMeans(KMeans, object):
    """
    KMeans subclass that bridges the difference in state between sklearn 0.20 and sklearn 0.23+.

    In sklearn 0.23, ``n_jobs`` was deprecated in favor of an auto-computed ``_n_threads`` attribute.
    A model pickled with sklearn < 0.23 therefore carries no ``_n_threads``. When such a model is restored
    under sklearn >= 0.23, this class derives ``_n_threads`` from the pickled ``n_jobs`` value.
    """

    def __setstate__(self, d):
        """
        Restore the pickled state ``d``, then set ``_n_threads`` when upgrading from sklearn < 0.23.

        :param d: the pickled attribute dict of the estimator
        """
        # Capture what we need before delegating: the base-class __setstate__ may remove
        # "_sklearn_version" from the state dict. A missing version means "do not patch".
        pickled_version = d.get('_sklearn_version', '999')
        n_jobs = d.get('n_jobs')

        # Use this class in super() so that any __setstate__ defined on KMeans itself is not skipped.
        super(UnpickableKMeans, self).__setstate__(d)

        # NOTE: plain string comparison of version strings, kept from the original implementation.
        if pickled_version < '0.23' and package_is_at_least(sklearn, "0.23"):
            n_threads = self._n_threads_from_n_jobs(n_jobs)
            logger.warning(u"Unpickling KMeans from sklearn version {} to {}. Forcing _n_threads value to {}".format(pickled_version, sklearn.__version__, n_threads))
            self._n_threads = n_threads

    @staticmethod
    def _n_threads_from_n_jobs(n_jobs):
        """
        Convert a pre-0.23 ``n_jobs`` value into a thread count usable as ``_n_threads``.

        - ``None`` (or a missing ``n_jobs``) maps to 1.
        - Negative values follow the joblib convention (-1 = all CPUs, -2 = all CPUs but one, ...),
          floored at 1, because a negative thread count is not valid.
        - Any other value is returned unchanged.

        :param n_jobs: the ``n_jobs`` value found in the pickled state, or None
        :return: the thread count to store in ``_n_threads``
        """
        if n_jobs is None:
            return 1
        if n_jobs < 0:
            # Local import: only reached on the sklearn >= 0.23 path (Python 3 only).
            import os
            n_cpus = os.cpu_count() or 1
            return max(1, n_cpus + 1 + n_jobs)
        return n_jobs


def get_kmeans_estimator(**kwargs):
    """
    Build a sklearn.cluster.KMeans estimator, adapting keyword arguments whose meaning changed across sklearn versions.

    Before sklearn 0.23, ``n_jobs`` set the number of parallel loops, each processing all the data.
    It therefore had a strong impact on memory pressure.

    From sklearn 0.23 on, parallelism is handled by splitting the data, so ``n_jobs`` is no longer needed.
    On those versions it is removed from ``kwargs`` before building the estimator.

    See https://scikit-learn.org/stable/computing/parallelism.html for the difference between joblib-based
    and OpenMP-based parallelism.

    :param kwargs: keyword arguments forwarded to KMeans
    :return: a new KMeans instance built from the (possibly adjusted) kwargs
    """
    drop_n_jobs = package_is_at_least(sklearn, "0.23")
    if drop_n_jobs:
        kwargs.pop("n_jobs", None)
    return KMeans(**kwargs)


def get_kmeans_n_init_value(modeling_params):
    """
    Return the n_init value to use for KMeans.

    A non-None "n_init" entry in ``modeling_params`` is returned as-is.
    Otherwise the value mirrors sklearn's own default for the installed version, and a warning is logged:

    - sklearn < 1.4: 10 - https://scikit-learn.org/1.3/modules/generated/sklearn.cluster.KMeans.html
    - sklearn >= 1.4: the default is 'auto', which resolves to 1 given that we use the default value
      for `init` - https://scikit-learn.org/1.4/modules/generated/sklearn.cluster.KMeans.html
    """
    configured = modeling_params.get("n_init")
    if configured is not None:
        return configured

    if package_is_at_least(sklearn, "1.4"):
        fallback = 1
        message = "KMeans n_init value not set. sklearn version >= 1.4 ({}) therefore using n_init value of 1"
    else:
        fallback = 10
        message = "KMeans n_init value not set. sklearn version <1.4 ({}) therefore using n_init value of 10"
    logger.warning(message.format(sklearn.__version__))
    return fallback


def get_minibatch_n_init_value(modeling_params):
    """
    Return the n_init value to use for MiniBatchKMeans.

    A non-None "n_init" entry in ``modeling_params`` is returned as-is.
    Otherwise the value mirrors sklearn's own default for the installed version, and a warning is logged:

    - sklearn < 1.4: 3 - https://scikit-learn.org/1.3/modules/generated/sklearn.cluster.MiniBatchKMeans.html
    - sklearn >= 1.4: the default is 'auto', which resolves to 1 given that we use the default value
      for `init` - https://scikit-learn.org/1.4/modules/generated/sklearn.cluster.MiniBatchKMeans.html
    """
    configured = modeling_params.get("n_init")
    if configured is not None:
        return configured

    if package_is_at_least(sklearn, "1.4"):
        fallback = 1
        message = "MiniBatchKMeans n_init value not set. sklearn version >= 1.4 ({}) therefore using n_init value of 1"
    else:
        fallback = 3
        message = "MiniBatchKMeans n_init value not set. sklearn version <1.4 ({}) therefore using n_init value of 3"
    logger.warning(message.format(sklearn.__version__))
    return fallback
