import sklearn

from sklearn.linear_model import Lasso
from sklearn.linear_model import LassoLars
from sklearn.linear_model import LassoLarsCV
from sklearn.linear_model import Ridge
from sklearn.linear_model import RidgeCV
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler

from dataiku.base.utils import package_is_at_least


def dku_fit(estimator, train_X, train_y, **kwargs):
    """
    Fit ``estimator`` on (train_X, train_y), normalizing the features for linear models.

    We used to pass ``normalize=True`` to linear models, but that parameter was deprecated
    and then removed in scikit-learn 1.2. On scikit-learn >= 1.2 we instead scale the
    features with a ``StandardScaler`` before fitting, and map the learned coefficients back
    to the original feature space afterwards
    (see https://scikit-learn.org/stable/whats_new/v1.0.html).
    On older scikit-learn versions we keep setting ``normalize = True`` on the estimator.

    Note that this isn't exactly the same behaviour as before, but:
    - the objective is the same, i.e. making the models converge faster;
    - the end model is considered functionally equivalent, since this is only applied to
      linear models, on which there is always a single set of equivalent solutions.

    Estimators that are not one of the handled linear models are simply fitted as-is.

    :param estimator: scikit-learn estimator to fit; it is fitted (and possibly modified) in place
    :param train_X: training features
    :param train_y: training target
    :param kwargs: extra keyword arguments forwarded to ``estimator.fit``
    :return: the fitted estimator (the same object that was passed in)
    """
    if not isinstance(estimator, (Ridge, RidgeCV, Lasso, LassoLars, LassoLarsCV, LinearRegression)):
        estimator.fit(train_X, train_y, **kwargs)
        return estimator

    if package_is_at_least(sklearn, "1.2"):
        # Divide by the standard deviation only, without centering: the intercept is then
        # unaffected, so only coef_ needs to be mapped back after fitting. with_mean=False
        # also keeps sparse inputs supported by the scaler.
        scaler = StandardScaler(with_mean=False, with_std=True)
        estimator.fit(scaler.fit_transform(train_X), train_y, **kwargs)
        # The model was fit on X / scale_, so its coefficients apply to scaled features.
        # Dividing by scale_ makes them apply to the raw features, so that predict() can be
        # called directly on unscaled data. StandardScaler uses a scale of 1 for constant
        # features, so this never divides by zero.
        estimator.coef_ = estimator.coef_ / scaler.scale_
    else:
        # Before scikit-learn 1.2 these linear models still accept the normalize parameter.
        estimator.normalize = True
        estimator.fit(train_X, train_y, **kwargs)
    return estimator
