import numpy as np
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

from dataiku.base.utils import package_is_at_least


class OVRLogisticRegression(OneVsRestClassifier):
    """ One-vs-rest wrapper around LogisticRegression that keeps DSS code backward compatible.

    It is only useful on scikit-learn>=1.5 and required on scikit-learn>=1.7.

    The hyper-parameters exposed by this class are those of the wrapped LogisticRegression: get_params and
    set_params are forwarded to it, while coef_ and intercept_ are rebuilt from the fitted binary estimators.
    """
    def __init__(self, **kwargs):
        """ Create the wrapped LogisticRegression and register it as the one-vs-rest base estimator.

        :param kwargs: keyword arguments forwarded unchanged to LogisticRegression
        :raises AssertionError: when the installed scikit-learn is older than 1.5
        """
        assert package_is_at_least(sklearn, "1.5"), self.__class__.__name__ + " requires scikit-learn>=1.5"
        estimator = LogisticRegression(**kwargs)
        super().__init__(estimator=estimator)

    def set_params(self, **params):
        """ Set parameters on the wrapped LogisticRegression.

        :param params: LogisticRegression parameters to update
        :return: this estimator, as the scikit-learn estimator API expects, so that calls can be chained
        """
        self.estimator.set_params(**params)
        return self

    def get_params(self, deep=True):
        """ Return the parameters of the wrapped LogisticRegression.

        :param deep: forwarded to LogisticRegression.get_params
        :return: dict mapping parameter names to their values
        """
        return self.estimator.get_params(deep=deep)

    def fit(self, *args, **kwargs):
        """ Fit the one-vs-rest model, routing ``sample_weight`` to the wrapped estimator when it is given.

        :param args: positional arguments forwarded to OneVsRestClassifier.fit
        :param kwargs: keyword arguments forwarded to OneVsRestClassifier.fit
        :return: whatever OneVsRestClassifier.fit returns
        """
        if "sample_weight" not in kwargs:
            return super().fit(*args, **kwargs)
        # Metadata routing is needed for the sample weights to reach the wrapped estimator. It is enabled for the
        # duration of this call only, so the process-wide scikit-learn configuration is left as it was found.
        with sklearn.config_context(enable_metadata_routing=True):
            self.estimator.set_fit_request(sample_weight=True)
            return super().fit(*args, **kwargs)

    @property
    def coef_(self):
        """ Coefficients of every fitted binary estimator, stacked row-wise (one row per estimator).
        """
        # With a single fitted estimator this holds the same values as estimators_[0].coef_
        return np.vstack([fitted.coef_ for fitted in self.estimators_])

    @property
    def intercept_(self):
        """ Intercepts of every fitted binary estimator, concatenated in estimator order.
        """
        # With a single fitted estimator this holds the same values as estimators_[0].intercept_
        return np.concatenate([np.atleast_1d(fitted.intercept_) for fitted in self.estimators_])


def dku_logistic_regression(**kwargs):
    """ Build a logistic regression estimator from ``kwargs``.

    On scikit-learn>=1.5 a non-None ``multi_class`` argument is removed from ``kwargs``, and
    multi_class='ovr' yields an OVRLogisticRegression so that the one-vs-rest behaviour is kept. In every
    other case a plain LogisticRegression is returned.
    """
    if not package_is_at_least(sklearn, "1.5"):
        # Older scikit-learn: every argument, multi_class included, goes straight to LogisticRegression
        return LogisticRegression(**kwargs)

    strategy = kwargs.get("multi_class")
    if strategy is not None:
        # Only a non-None value is dropped; an explicit multi_class=None is still forwarded
        kwargs.pop("multi_class")

    if strategy == 'ovr':
        return OVRLogisticRegression(**kwargs)
    return LogisticRegression(**kwargs)