import numpy as np
import sklearn
from sklearn.svm import SVR, SVC
from dataiku.base.utils import package_is_at_least


def _update_state(d):
    """Rename legacy SVM probability entries of a pickled estimator state.

    When the state was written by a scikit-learn older than 0.22 and
    ``package_is_at_least(sklearn, '0.22')`` holds for the running one, the
    ``probA_`` / ``probB_`` entries (when present) are moved to ``_probA`` /
    ``_probB``. Otherwise the state is left untouched.

    A state carrying no ``_sklearn_version`` entry is treated as legacy
    (presumably it was pickled before scikit-learn started recording its
    version -- TODO confirm against the oldest supported models).

    :param dict d: estimator state as received by ``__setstate__``;
        modified in place.
    """
    pickled_version = d.get('_sklearn_version')
    # Guard the missing-key case explicitly: `None < '0.22'` raises TypeError
    # on Python 3, whereas Python 2 silently evaluated it to True.
    # NOTE: the comparison itself is a plain (lexicographic) string comparison.
    is_legacy = pickled_version is None or pickled_version < '0.22'
    if is_legacy and package_is_at_least(sklearn, '0.22'):
        for old_key, new_key in (('probA_', '_probA'), ('probB_', '_probB')):
            if old_key in d:
                d[new_key] = d.pop(old_key)


class UnpicklableSVR(SVR, object):
    """``SVR`` subclass that adapts pre-0.22 pickled state before restoring it."""

    def __setstate__(self, d):
        """Restore the estimator from ``d``, migrating legacy attributes first.

        ``d`` is modified in place. The migration only runs when the state
        comes from a scikit-learn older than 0.22 (or carries no
        ``_sklearn_version`` at all) and
        ``package_is_at_least(sklearn, '0.22')`` holds for the running one.

        :param dict d: pickled estimator state.
        """
        _update_state(d)
        pickled_version = d.get('_sklearn_version')
        # Guard the missing-key case explicitly: `None < '0.22'` raises
        # TypeError on Python 3, whereas Python 2 evaluated it to True.
        # NOTE: the comparison itself is a plain (lexicographic) string comparison.
        is_legacy = pickled_version is None or pickled_version < '0.22'
        if is_legacy and package_is_at_least(sklearn, '0.22'):
            # Works around the successive changes made to the support-vector
            # count attribute in
            # https://github.com/scikit-learn/scikit-learn/pull/21336/files
            # and https://github.com/scikit-learn/scikit-learn/pull/15099/files
            # by rebuilding `_n_support` from the number of support vectors
            # and discarding the legacy `n_support_` entry.
            if 'support_' in d:
                d['_n_support'] = np.array([d["support_"].shape[0], 0], dtype=np.int32)
            d.pop('n_support_', None)
        super(UnpicklableSVR, self).__setstate__(d)


class UnpicklableSVC(SVC, object):
    """``SVC`` subclass that adapts pre-0.22 pickled state before restoring it."""

    def __setstate__(self, d):
        """Restore the estimator from ``d``, migrating legacy attributes first.

        ``d`` is modified in place. The migration only runs when the state
        comes from a scikit-learn older than 0.22 (or carries no
        ``_sklearn_version`` at all) and
        ``package_is_at_least(sklearn, '0.22')`` holds for the running one.

        :param dict d: pickled estimator state.
        """
        _update_state(d)
        pickled_version = d.get('_sklearn_version')
        # Guard the missing-key case explicitly: `None < '0.22'` raises
        # TypeError on Python 3, whereas Python 2 evaluated it to True.
        # NOTE: the comparison itself is a plain (lexicographic) string comparison.
        is_legacy = pickled_version is None or pickled_version < '0.22'
        if is_legacy and package_is_at_least(sklearn, '0.22'):
            # Legacy states have no `break_ties` entry; set it to False.
            d['break_ties'] = False
            # Move the legacy `n_support_` entry to `_n_support`.
            if 'n_support_' in d:
                d['_n_support'] = d.pop('n_support_')
        super(UnpicklableSVC, self).__setstate__(d)
