import json
import re

import numpy as np
import sklearn

from sklearn.calibration import CalibratedClassifierCV

from dataiku.base.utils import package_is_at_least
from dataiku.doctor.prediction.prediction_interval_model import PredictionIntervalsModel
from dataiku.doctor.prediction.scorable_model import ScorableModelBinary
from dataiku.doctor.prediction.scorable_model import ScorableModelMulticlass
from dataiku.doctor.prediction.scorable_model import ScorableModelRegression
from dataiku.doctor.prediction.scorable_model import ScorableModelRegressionWithPredictionIntervals
from dataiku.doctor.utils.model_io import dump_model_to_folder
from dataiku.doctor.utils.skcompat import extract_X_y_from_isotonic_regressor
from dataiku.doctor.utils.skcompat import get_base_estimator
from dataiku.doctor.utils.skcompat import get_calibrators
from dataiku.doctor.utils.skcompat import get_gbt_binary_classification_baseline
from dataiku.doctor.utils.skcompat import get_gbt_multiclass_classification_baseline
from dataiku.doctor.utils.skcompat import get_gbt_regression_baseline
from dataiku.doctor.utils.skcompat._logistic_regression import OVRLogisticRegression


class SerializedModel(object):
    """Holder pairing an algorithm name with the serialized data of a classifier/regressor."""

    def __init__(self, algorithm_name, serialized_clf):
        """
        :param str algorithm_name: name of the algorithm the serialized data belongs to
        :param dict serialized_clf: serialized representation of the estimator
        """
        self.serialized_clf = serialized_clf
        self.algorithm_name = algorithm_name


class ModelSerializer(object):
    """
    Base serializer of a scorable model.

    ``serialize`` always dumps the pickled estimator; when ``get_serialized`` returns something
    (subclasses override it, the base implementation returns None), it additionally writes the
    metadata and the serialized estimator as JSON into the model folder context.
    """

    @staticmethod
    def build(model, run_folder_context, columns, calibrate_proba=False):
        """
        Instantiate the serializer matching the type of the scorable model.

        :type model: dataiku.doctor.prediction.scorable_model.SerializableMixin
        :type run_folder_context: dataiku.base.folder_context.FolderContext
        :type columns: list
        :type calibrate_proba: bool
        :rtype: ModelSerializer
        :raises ValueError: when the model is none of the supported scorable model types
        """
        # The order of the checks is significant and kept as is: the prediction-intervals
        # regression type is tested before the plain regression type.
        if isinstance(model, ScorableModelRegressionWithPredictionIntervals):
            serializer = RegressionModelSerializerWithPredictionIntervals(columns, model, run_folder_context)
        elif isinstance(model, ScorableModelRegression):
            serializer = RegressionModelSerializer(columns, model, run_folder_context)
        elif isinstance(model, ScorableModelMulticlass):
            serializer = MulticlassModelSerializer(columns, model, run_folder_context, calibrate_proba)
        elif isinstance(model, ScorableModelBinary):
            serializer = BinaryModelSerializer(columns, model, run_folder_context, calibrate_proba)
        else:
            raise ValueError("Scorable model type not supported: " + model.__class__.__name__)
        return serializer

    def __init__(self, columns, model, model_folder_context):
        """
        :type columns: list
        :type model: dataiku.doctor.prediction.scorable_model.SerializableMixin
        :type model_folder_context: dataiku.base.folder_context.FolderContext
        """
        self.model_folder_context = model_folder_context
        self.columns = columns
        self.model = model
        # Shortcuts to the two model attributes used by the serializers
        self.algorithm = model.algorithm
        self.clf = model.clf

    def get_serialized(self):
        """
        Return the serialized form of the model (algorithm name + serialized estimator data).
        The base implementation has nothing to serialize and returns None.

        :rtype: SerializedModel or None
        """
        return None

    def _get_meta(self, algorithm_name):
        """
        Build the metadata dict written next to the serialized model.

        :param str algorithm_name: algorithm name of the serialized model
        :rtype: dict
        """
        backend = "PY_MEMORY"
        if self.algorithm == "KERAS_CODE":
            backend = "KERAS"
        return {
            "backend": backend,
            "algorithm_name": algorithm_name,
            "columns": self.columns
        }

    def serialize(self, pkl_filename="clf.pkl", gz_filename="dss_pipeline_model.gz", meta_filename="dss_pipeline_meta.json"):
        """
        Dump the model-related information to the model folder:
            - the pickled estimator (always)
            - when a serialized form exists: the metadata returned by ``_get_meta`` and the serialized estimator

        :type pkl_filename: str
        :type gz_filename: str
        :type meta_filename: str
        :rtype: None
        """
        dump_model_to_folder(self.clf, self.model_folder_context, pkl_filename=pkl_filename)
        serialized = self.get_serialized()
        if serialized is None:
            # Nothing more than the pickle can be written for this model
            return
        folder = self.model_folder_context
        folder.write_json(meta_filename, self._get_meta(serialized.algorithm_name))
        folder.write_json(gz_filename, serialized.serialized_clf)


def _listify(o):
    """ recursively listify arrays or lists """
    if type(o) is np.ndarray or type(o) is list:
        return [_listify(x) for x in o]
    else:
        return o


# this is not super efficient, in particular with large forests. May have to vectorize if turns out to be slow
def _serialize_sklearn_tree(tree, is_regression):
    extract = tree.tree_
    left = extract.children_left.tolist()
    right = extract.children_right.tolist()
    extract_thresholds = extract.threshold.tolist()
    extract_labels = extract.value.tolist()
    extract_features = extract.feature.tolist()
    # missing_go_to_left attribute only available from version 1.3 of scikit-learn
    extract_missing_goes_left = extract.missing_go_to_left.tolist() if hasattr(extract, "missing_go_to_left") else None
    node_ids = []
    leaf_ids = []
    labels = []
    features = []
    thresholds = []
    missing = []

    def process(index, id):
        # not a leaf
        if left[index] >= 0:
            node_ids.append(id)
            features.append(extract_features[index])
            thresholds.append(extract_thresholds[index])
            if extract_missing_goes_left is not None:
                missing.append("l" if extract_missing_goes_left[index] else "r")
            # id of current node children are 2*id + {1, 2} (NB: root id is 0, not 1)
            process(left[index], 2 * id + 1)
            process(right[index], 2 * id + 2)
        # no child => this is a leaf
        else:
            leaf_ids.append(id)
            if is_regression:
                labels.append(extract_labels[index][0][0])
            else:
                tab = extract_labels[index][0]
                norm = 1.0 / sum(tab)
                labels.append([x * norm for x in tab])

    process(0, 0)

    return {
        "node_id": node_ids,
        "feature": features,
        "threshold": thresholds,
        "leaf_id": leaf_ids,
        "label": labels,
        "missing": missing,
    }


def _serialize_decision_forest(forest, is_regression):
    return {"trees": [_serialize_sklearn_tree(t, is_regression) for t in forest.estimators_]}


def _serialize_regression_gbm(gbm, is_regression):
    """
    Serialize a gradient boosting regressor: its trees, its learning rate (as "shrinkage")
    and the baseline returned by ``get_gbt_regression_baseline``.
    """
    # each entry of estimators_ wraps its tree in another array: the tree is the first element
    trees = [_serialize_sklearn_tree(wrapped[0], is_regression) for wrapped in gbm.estimators_]
    shrinkage = gbm.learning_rate
    baseline = get_gbt_regression_baseline(gbm)
    return {
        "trees": trees,
        "shrinkage": shrinkage,
        "baseline": baseline
    }


def _serialize_classification_gbm(gbm, is_binary):
    """
    Serialize a gradient boosting classifier as a "GRADIENT_BOOSTING_CLASSIFIER" model.
    Every tree is serialized as a regression tree; each boosting stage is a list of trees
    (a single one in the binary case).

    :rtype: SerializedModel
    """
    shrinkage = gbm.learning_rate

    if not is_binary:
        stages = [[_serialize_sklearn_tree(tree, True) for tree in stage] for stage in gbm.estimators_]
        return SerializedModel("GRADIENT_BOOSTING_CLASSIFIER", {
            "trees": stages,
            "shrinkage": shrinkage,
            "baseline": get_gbt_multiclass_classification_baseline(gbm)
        })

    baseline = get_gbt_binary_classification_baseline(gbm)
    if gbm.loss == "exponential":
        # With the "exponential" loss, sklearn predicts sigmoid(2 * score) whereas it predicts
        # sigmoid(score) with the default "deviance" loss. DSS always computes sigmoid(score), so
        # both the baseline and the shrinkage are doubled to compensate, which is valid because
        # 2*(baseline + shrinkage * trees_preds) == 2*baseline + 2*shrinkage * trees_preds
        (single_baseline,) = baseline
        baseline = [2*single_baseline]
        shrinkage *= 2.0
    stages = [[_serialize_sklearn_tree(stage[0], True)] for stage in gbm.estimators_]
    return SerializedModel("GRADIENT_BOOSTING_CLASSIFIER", {
        "trees": stages,
        "shrinkage": shrinkage,
        "baseline": baseline
    })


def _serialize_xgboost_tree_dict(tree_dict):
    # Get the leaf data
    leaf_ids = [index for index, children in enumerate(tree_dict["left_children"]) if children == -1]
    # Ignore the orphan leaves that xgboost sometimes add as children of leaves
    leaf_ids = [leaf_id for leaf_id in leaf_ids if tree_dict["parents"][leaf_id] not in leaf_ids]
    labels = [tree_dict["split_conditions"][i] for i in leaf_ids]
    # We'll reconstruct the tree from the deepest leaves, sort them by id desc
    leaf_ids.reverse()
    labels.reverse()
    # Get the nodes data
    node_ids = [index for index, children in enumerate(tree_dict["left_children"]) if children != -1]
    thresholds = [tree_dict["split_conditions"][i] for i in node_ids]
    features = [tree_dict["split_indices"][i] for i in node_ids]
    missing = ['l' if tree_dict["default_left"][i] == 1 else 'r' for i in node_ids]

    # Make sure that a leaf parent is always at floor(leaf_id -1 /2) for easy reconstruction
    # Also make sure to construct the remapping layer by layer (since each layer affects initial condition
    # of the next one), which conveniently corresponds to the array order used by xgboost
    index_remapping = {0: 0}
    for index, children in enumerate(tree_dict["left_children"]):
        index_remapping[children] = 2*index_remapping.get(index, index)+1
        index_remapping[tree_dict["right_children"][index]] = 2 * index_remapping.get(index, index) + 2
    node_ids = [index_remapping[id] for id in node_ids]
    leaf_ids = [index_remapping[id] for id in leaf_ids]
    return {
        "node_id": node_ids,
        "feature": features,
        "threshold": thresholds,
        "missing": missing,
        "leaf_id": leaf_ids,
        "label": labels,
        "xgboost": True
    }

def _serialize_xgboost_tree(dump):
    all_nodes = [node.strip() for node in dump.split("\n") if node.strip() != ""]
    leaf_id = []
    label = []
    node_id = []
    feature = []
    missing = []
    threshold = []
    index_remapping = {0: 0}
    for node in all_nodes:
        tmp = node.split(":")
        xgb_index = int(tmp[0])
        index = index_remapping[xgb_index]
        is_leaf = tmp[1][:4] == "leaf"
        if is_leaf:
            leaf_id.append(index)
            label.append(float(tmp[1].split("=")[1]))
        else:
            if "<" in tmp[1]:
                # for cases FeatureMap::kInteger, FeatureMap::kFloat, FeatureMap::kQuantitative (default)
                f, t, il, ir, m = re.search(r"\[(.*)<(.*)\] yes=(.*),no=(.*),missing=(.*)", tmp[1]).groups()
                index_remapping[int(il)] = 2*index + 1
                index_remapping[int(ir)] = 2*index + 2
                if m == il:
                    m = 'l'
                else:
                    m = 'r'
            else:
                # boolean condition (FeatureMap::kIndicator) has no missing field
                f, il, ir, _ = re.search(r"\[(.*)\] yes=(.*),no=(.*)", tmp[1]).groups()
                t = 1.
                m = ''
                # NB: need to swap left and right as for boolean f as f = not float(f) < 1.
                index_remapping[int(il)] = 2*index + 2
                index_remapping[int(ir)] = 2*index + 1
            feature.append(int(f[1:]))
            node_id.append(index)
            missing.append(m)
            threshold.append(float(t))
    return {
        "node_id": node_id,
        "feature": feature,
        "threshold": threshold,
        "missing": missing,
        "leaf_id": leaf_id,
        "label": label,
        "xgboost": True
    }


def _serialize_regression_xgb(xgb_model):
    """Serialize an xgboost regressor with the implementation matching the installed xgboost version."""
    import xgboost
    # Starting with version 1.6, xgboost can produce an exact json dump of the trees; it is preferred
    # over parsing the raw text dump, in which thresholds and labels are already slightly inexact
    if package_is_at_least(xgboost, "1.6"):
        serialize = _serialize_regression_xgb_post16
    else:
        serialize = _serialize_regression_xgb_pre16
    return serialize(xgb_model)

def _serialize_regression_xgb_pre16(xgb_model):
    trees_as_dump = xgb_model.get_booster().get_dump()
    # If trained with early stopped, don't use the trees after best_ntree_limit (changed in xgboost==0.80)
    if hasattr(xgb_model, "best_ntree_limit"):
        trees_as_dump = trees_as_dump[:xgb_model.best_ntree_limit]
    # TODO: take bias and weight into account for gblinear models when added to DSS
    gamma_regression = xgb_model.get_params().get("objective") in ["reg:gamma", "reg:tweedie", "count:poisson"]
    logistic_regression = xgb_model.get_params().get("objective") == "reg:logistic"
    return {
        "trees": [_serialize_xgboost_tree(t) for t in trees_as_dump],
        "shrinkage": 1.,
        "baseline": xgb_model.base_score,
        "gamma_regression": gamma_regression,
        "logistic_regression": logistic_regression
    }


def _serialize_regression_xgb_post16(xgb_model):
    model_json = json.loads(xgb_model.get_booster().save_raw("json").decode())
    if model_json["learner"]["gradient_booster"]["name"] == "dart":
        trees_json = model_json["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]
    else:
        trees_json = model_json["learner"]["gradient_booster"]["model"]["trees"]
    if "best_ntree_limit" in model_json["learner"]["attributes"]:
        trees_json = trees_json[:int(model_json["learner"]["attributes"]["best_ntree_limit"])]
    # TODO: take bias and weight into account for gblinear models when added to DSS
    gamma_regression = model_json["learner"]["objective"]["name"] in ["reg:gamma", "reg:tweedie", "count:poisson"]
    logistic_regression = model_json["learner"]["objective"]["name"] == "reg:logistic"
    return {
        "trees":  [_serialize_xgboost_tree_dict(tree) for tree in trees_json],
        "shrinkage": 1.,
        "baseline": model_json["learner"]["learner_model_param"]["base_score"],
        "gamma_regression": gamma_regression,
        "logistic_regression": logistic_regression
    }

def _serialize_classification_xgb(xgb_model, is_binary):
    """Serialize an xgboost classifier with the implementation matching the installed xgboost version."""
    import xgboost
    # Starting with version 1.6, xgboost can produce an exact json dump of the trees; it is preferred
    # over parsing the raw text dump, in which thresholds and labels are already slightly inexact
    if package_is_at_least(xgboost, "1.6"):
        serialize = _serialize_classification_xgb_post16
    else:
        serialize = _serialize_classification_xgb_pre16
    return serialize(xgb_model, is_binary)


def _serialize_classification_xgb_pre16(xgb_model, is_binary):
    trees_as_dump = xgb_model.get_booster().get_dump()
    # TODO: take bias and weight into account for gblinear models when added to DSS
    if is_binary:
        # If trained with early stopped, don't use the trees after best_ntree_limit (changed in xgboost==0.80)
        if hasattr(xgb_model, "best_ntree_limit"):
            trees_as_dump = trees_as_dump[:xgb_model.best_ntree_limit]
        return {
            "trees": [[_serialize_xgboost_tree(t)] for t in trees_as_dump],
            "shrinkage": 1.,
            "baseline": [0.],
        }
    else:
        n_classes = xgb_model.n_classes_
        # If trained with early stopped, don't use the trees after best_ntree_limit (changed in xgboost==0.80)
        if hasattr(xgb_model, "best_ntree_limit"):
            trees_as_dump = trees_as_dump[:xgb_model.best_ntree_limit*n_classes]
        estimators = [[] for _ in range(n_classes)]
        for i, t in enumerate(trees_as_dump):
            estimators[i % n_classes].append(_serialize_xgboost_tree(t))

        return {
            "trees": np.array(estimators).T.tolist(),
            "shrinkage": 1.,
            "baseline": [0.] * n_classes
        }


def _serialize_classification_xgb_post16(xgb_model, is_binary):
    model_json = json.loads(xgb_model.get_booster().save_raw("json").decode())
    if model_json["learner"]["gradient_booster"]["name"] == "dart":
        trees_json = model_json["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]
    else:
        trees_json = model_json["learner"]["gradient_booster"]["model"]["trees"]
    # TODO: take bias and weight into account for gblinear models when added to DSS
    if is_binary:
        if "best_ntree_limit" in model_json["learner"]["attributes"]:
            trees_json = trees_json[:int(model_json["learner"]["attributes"]["best_ntree_limit"])]
        return {
            "trees": [[_serialize_xgboost_tree_dict(tree)] for tree in trees_json],
            "shrinkage": 1.,
            "baseline": [0.],
        }
    else:
        n_classes = xgb_model.n_classes_
        if "best_ntree_limit" in model_json["learner"]["attributes"]:
            trees_json = trees_json[:int(model_json["learner"]["attributes"]["best_ntree_limit"]) * n_classes]
        estimators = [[] for _ in range(n_classes)]
        for i, tree in enumerate(trees_json):
            estimators[i % n_classes].append(_serialize_xgboost_tree_dict(tree))
        return {
            "trees": np.array(estimators).T.tolist(),
            "shrinkage": 1.,
            "baseline": [0.] * n_classes
        }


def _serialize_lightgbm_tree(dumped_tree):
    node_ids = []
    features = []
    thresholds = []
    leaf_ids = []
    labels = []
    missings = []

    # We walk the tree in breadth-first order and encode the nodes iteratively.
    root_node = dumped_tree["tree_structure"]
    nodes_to_process = [(root_node, 0)]

    while len(nodes_to_process) > 0:
        node, remapped_index = nodes_to_process.pop(0)

        if "leaf_value" in node:
            leaf_ids.append(remapped_index)
            labels.append(node["leaf_value"])
            continue

        decision_type = node["decision_type"]

        if decision_type != "<=":
            # There is also a "==" decision type, but only when the LightGBM
            # native categorical features support is enabled, which is not
            # the case for now.
            raise NotImplementedError("Cannot serialize LightGBM tree because of unimplemented decision_type='{}'".format(decision_type))

        node_ids.append(remapped_index)
        features.append(node["split_feature"])
        thresholds.append(node["threshold"])
        missings.append("l" if node["default_left"] else "r")

        # Add the current node children to the list of nodes to explore.
        left_child_to_explore = (node["left_child"], 2 * remapped_index + 1)
        nodes_to_process.append(left_child_to_explore)

        right_child_to_explore = (node["right_child"], 2 * remapped_index + 2)
        nodes_to_process.append(right_child_to_explore)

    return {
        "node_id": node_ids,
        "feature": features,
        "threshold": thresholds,
        "missing": missings,
        "leaf_id": leaf_ids,
        "label": labels,
        "lightgbm": True
    }


def _serialize_lightgbm_regressor(estimator):
    """
    Serialize a LightGBM regressor, from the model dump of its booster, as a
    "GRADIENT_BOOSTING_REGRESSOR" model with a shrinkage of 1. and a baseline of 0.

    :rtype: SerializedModel
    """
    model_dump = estimator.booster_.dump_model()
    payload = {
        "trees": [_serialize_lightgbm_tree(tree) for tree in model_dump["tree_info"]],
        "shrinkage": 1.,
        "baseline": 0.,
        "gamma_regression": estimator.objective_ == "gamma"
    }
    return SerializedModel("GRADIENT_BOOSTING_REGRESSOR", payload)


def _serialize_lightgbm_classifier(estimator, is_binary):
    """
    Serialize a LightGBM classifier, from the model dump of its booster, as a
    "GRADIENT_BOOSTING_CLASSIFIER" model with a shrinkage of 1. and a baseline of 0. per output.
    "trees" holds one list of trees per boosting round: a single tree in the binary case,
    one tree per class otherwise.

    :rtype: SerializedModel
    """
    dumped_trees = estimator.booster_.dump_model()["tree_info"]

    if is_binary:
        baseline = [0.]
        rounds = [[_serialize_lightgbm_tree(tree)] for tree in dumped_trees]
    else:
        n_classes = estimator.n_classes_
        baseline = [0.] * n_classes
        # The dumped trees are interleaved by class: tree number i belongs to class i % n_classes
        trees_per_class = [[] for _ in range(n_classes)]
        for position, tree in enumerate(dumped_trees):
            trees_per_class[position % n_classes].append(_serialize_lightgbm_tree(tree))
        # Reordered to fit the expected format: outer list over rounds, inner list over classes
        rounds = np.array(trees_per_class).T.tolist()

    return SerializedModel("GRADIENT_BOOSTING_CLASSIFIER", {
        "trees": rounds,
        "shrinkage": 1.,
        "baseline": baseline,
        "gamma_regression": estimator.objective_ == "gamma"
    })


def _serialize_mlp(clf):
    """
    Serialize a multi-layer perceptron as a "MULTI_LAYER_PERCEPTRON" model: upper-cased activation
    name, intercepts (as "biases") and the transpose of each coefficient matrix (as "weights").

    :rtype: SerializedModel
    """
    activation = clf.activation.upper()
    biases = _listify(clf.intercepts_)
    weights = [_listify(np.transpose(layer_coefs)) for layer_coefs in clf.coefs_]
    return SerializedModel("MULTI_LAYER_PERCEPTRON", {
        "activation": activation,
        "biases": biases,
        "weights": weights
    })


class RegressionModelSerializer(ModelSerializer):
    """Serializer of regression models: picks the serialized form from the estimator and the algorithm."""

    def get_serialized(self):
        """
        Return the serialized form of the regression model, or None when the algorithm has none.

        Any estimator exposing both ``coef_`` and ``intercept_`` is serialized as "LINEAR", except
        for the "SVM_REGRESSION" algorithm; otherwise the serialization depends on the algorithm name.

        :rtype: SerializedModel or None
        """
        algo = self.algorithm
        # Ridge, Lasso, OLS, SGD ...
        if hasattr(self.clf, 'coef_') and hasattr(self.clf, 'intercept_') and algo != "SVM_REGRESSION":
            # for SGDRegressor, intercept_ comes as a (1,) ndarray, so it must be converted to a float.
            # The single element is extracted with .item() because calling float() directly on an
            # array with ndim > 0 is deprecated since NumPy 1.25.
            return SerializedModel("LINEAR", {
                "coefficients": self.clf.coef_,
                "intercept": float(np.asarray(self.clf.intercept_).item())
            })

        if algo == "DECISION_TREE_REGRESSION":
            return SerializedModel("DECISION_TREE", _serialize_sklearn_tree(self.clf, True))

        if algo == "RANDOM_FOREST_REGRESSION" or algo == "EXTRA_TREES":
            return SerializedModel("FOREST_REGRESSOR", _serialize_decision_forest(self.clf, True))

        if algo == "GBT_REGRESSION":
            return SerializedModel("GRADIENT_BOOSTING_REGRESSOR", _serialize_regression_gbm(self.clf, True))

        if algo == "NEURAL_NETWORK":
            return _serialize_mlp(self.clf)

        if algo == "XGBOOST_REGRESSION":
            return SerializedModel("GRADIENT_BOOSTING_REGRESSOR", _serialize_regression_xgb(self.clf))

        if algo == "LIGHTGBM_REGRESSION":
            return _serialize_lightgbm_regressor(self.clf)

        # No serialized form for the other algorithms
        return None


class PredictionIntervalsModelSerializer(RegressionModelSerializer):
    """Serializer of a prediction intervals model: a regression serializer without columns, carrying ``q``."""

    def __init__(self, model, model_folder_context):
        """
        :type model: PredictionIntervalsModel
        :type model_folder_context: dataiku.base.folder_context.FolderContext
        """
        no_columns = []
        super(PredictionIntervalsModelSerializer, self).__init__(no_columns, model, model_folder_context)
        self.q = model.q

    def _get_meta(self, algorithm_name):
        """
        Overridden so that the metadata written for this model holds the ``q`` parameter of the
        prediction intervals model along with the algorithm name.

        :param str algorithm_name: algo name used for java deserialization
        :rtype: dict
        """
        meta = {"q": self.q}
        meta["algorithm_name"] = algorithm_name
        return meta


class RegressionModelSerializerWithPredictionIntervals(RegressionModelSerializer):
    """Regression serializer that also serializes the prediction intervals model attached to the model."""

    def __init__(self, columns, model, run_folder):
        """
        :type columns: list
        :type model: dataiku.doctor.prediction.scorable_model.ScorableModelRegressionWithPredictionIntervals
        :type run_folder: dataiku.base.folder_context.FolderContext
        """
        super(RegressionModelSerializerWithPredictionIntervals, self).__init__(columns, model, run_folder)
        intervals_model = model.prediction_intervals_model
        self.intervals_model_serializer = PredictionIntervalsModelSerializer(intervals_model, run_folder)

    def serialize(self, pkl_filename="clf.pkl", gz_filename="dss_pipeline_model.gz", meta_filename="dss_pipeline_meta.json"):
        """
        Serialize the main regression model under the given file names, then the prediction
        intervals model under the file names defined on ``PredictionIntervalsModel``.

        :type pkl_filename: str
        :type gz_filename: str
        :type meta_filename: str
        :rtype: None
        """
        main_serializer = super(RegressionModelSerializerWithPredictionIntervals, self)
        main_serializer.serialize(pkl_filename, gz_filename, meta_filename)
        intervals_filenames = {
            "pkl_filename": PredictionIntervalsModel.PKL_FILENAME,
            "gz_filename": PredictionIntervalsModel.GZ_FILENAME,
            "meta_filename": PredictionIntervalsModel.PARAMS_FILENAME,
        }
        self.intervals_model_serializer.serialize(**intervals_filenames)


def _common_classif_serialization(algo, clf):
    # Logistic Regression, SGD ...
    if algo == "DECISION_TREE_CLASSIFICATION":
        return SerializedModel("DECISION_TREE", _serialize_sklearn_tree(clf, False))
    elif algo == "RANDOM_FOREST_CLASSIFICATION" or algo == "EXTRA_TREES":
        return SerializedModel("FOREST_CLASSIFIER", _serialize_decision_forest(clf, False))
    elif algo == "NEURAL_NETWORK":
        return _serialize_mlp(clf)
    else:
        return None


def _serialize_binary_logit(clf):
    """Serialize a binary logistic regression with the "MULTINOMIAL" policy."""
    policy = "MULTINOMIAL"
    return _serialize_binary_logistic(clf, policy)

def _serialize_binary_sgd(clf):
    if clf.loss in ['log', 'log_loss'] :
        return _serialize_binary_logistic(clf, 'MULTINOMIAL')
    elif clf.loss == 'modified_huber':
        return _serialize_binary_logistic(clf, 'MODIFIED_HUBER')
    else:
        return None

def _serialize_binary_logistic(clf, policy):
    """
    Serialize a binary linear classifier as a "LOGISTIC" model with the given policy.

    To be compatible with the multinomial format, dummy coefficients and a dummy intercept, all
    equal to zero, are prepended for the 0 class, ahead of the coefficients and intercept of the model.

    :rtype: SerializedModel
    """
    coefficients = clf.coef_.tolist()
    zero_class_coefficients = [0.0] * len(coefficients[0])
    intercepts = [0.0] + clf.intercept_.tolist()
    return SerializedModel("LOGISTIC", {
        "policy": policy,
        "coefficients": [zero_class_coefficients] + coefficients,
        "intercept": intercepts
    })


class ClassificationModelSerializer(ModelSerializer):
    """Base serializer of classification models: adds the class mapping and the probability calibrator."""

    def __init__(self, columns, model, run_folder_context, calibrate_proba=False):
        """
        :type columns: list
        :type model: dataiku.doctor.prediction.scorable_model.ScorableModelClassification
        :type run_folder_context: dataiku.base.folder_context.FolderContext
        :type calibrate_proba: bool
        """
        super(ClassificationModelSerializer, self).__init__(columns, model, run_folder_context)
        self.calibrate_proba = calibrate_proba
        self.target_mapping = model.target_map

    def add_calibrator(self, model):
        """
        Store the calibration parameters (an empty dict when probabilities are not calibrated)
        under the "calibrator" key of the serialized model. Does nothing when ``model`` is None.

        :type model: SerializedModel or None
        """
        if model is None:
            return
        if self.calibrate_proba:
            calibrator = self._get_calibrator()
        else:
            calibrator = {}
        model.serialized_clf["calibrator"] = calibrator

    def _get_meta(self, name):
        """
        Extend the base metadata with the "classes" entry: the original class of each
        entry of ``clf.classes_``, in the same order.
        """
        meta = super(ClassificationModelSerializer, self)._get_meta(name)
        # because scikit does its own class mapping, we have to remap here. So the final classes will be
        # different from the target_mapping if some were missing from the training set
        original_class_by_index = {index: original for original, index in self.target_mapping.items()}
        meta["classes"] = [original_class_by_index[index] for index in self.clf.classes_]
        return meta

    def _get_calibrator(self):
        """
        Returns a serializable dict containing the calibration parameters.
        Implicitly returns None when the calibration method is neither "sigmoid" nor "isotonic".

        :raises ValueError: if calibration is disabled or the estimator is not a CalibratedClassifierCV
        """
        if not (self.calibrate_proba and isinstance(self.clf, CalibratedClassifierCV)):
            raise ValueError("Cannot get calibrator of model that has not been calibrated")
        # flag set when the base estimator exposes no decision_function
        from_proba = not hasattr(get_base_estimator(self.clf), "decision_function")
        calibrators = get_calibrators(self.clf)
        method = self.clf.method
        if method == "sigmoid":
            a_arr, b_arr = zip(*[(calibrator.a_, calibrator.b_) for calibrator in calibrators])
            return {
                "method": "SIGMOID",
                "from_proba": from_proba,
                "a_array": a_arr,
                "b_array": b_arr
            }
        if method == "isotonic":
            x_arr, y_arr = zip(*[extract_X_y_from_isotonic_regressor(calibrator) for calibrator in calibrators])
            return {
                "method": "ISOTONIC",
                "from_proba": from_proba,
                "x_array": x_arr,
                "y_array": y_arr
            }


class BinaryModelSerializer(ClassificationModelSerializer):
    """Serializer of binary classification models."""

    def get_serialized(self):
        """
        Return the serialized form of the binary classifier with its calibrator attached,
        or None when the algorithm has no serialized form.

        :rtype: SerializedModel or None
        """
        # When probabilities are calibrated, the estimator to serialize is the base estimator
        estimator = get_base_estimator(self.clf) if self.calibrate_proba else self.clf
        algo = self.algorithm
        if algo == "LOGISTIC_REGRESSION":
            serialized = _serialize_binary_logit(estimator)
        elif algo == "SGD_CLASSIFICATION":
            serialized = _serialize_binary_sgd(estimator)
        elif algo == "GBT_CLASSIFICATION":
            serialized = _serialize_classification_gbm(estimator, True)
        elif algo == "XGBOOST_CLASSIFICATION":
            xgb_payload = _serialize_classification_xgb(estimator, True)
            serialized = SerializedModel("GRADIENT_BOOSTING_CLASSIFIER", xgb_payload)
        elif algo == "LIGHTGBM_CLASSIFICATION":
            serialized = _serialize_lightgbm_classifier(estimator, True)
        else:
            serialized = _common_classif_serialization(algo, estimator)
        self.add_calibrator(serialized)
        return serialized


def _serialize_multicass_logit(clf):
    """
    Serialize a multiclass logistic regression, with the "ONE_VERSUS_ALL" or "MULTINOMIAL" policy.

    :rtype: SerializedModel
    """
    if package_is_at_least(sklearn, "1.5"):
        # starting with scikit 1.5 multinomial is used automatically if n_classes >= 3,
        # so one-versus-all is recognized from the estimator class
        one_versus_all = isinstance(clf, OVRLogisticRegression)
    else:
        one_versus_all = clf.multi_class != "multinomial"
    policy = "ONE_VERSUS_ALL" if one_versus_all else "MULTINOMIAL"
    return _serialize_multicass_logistic(clf, policy)

def _serialize_multicass_sgd(clf):
    if clf.loss in ['log', 'log_loss']:
        return _serialize_multicass_logistic(clf, "ONE_VERSUS_ALL")
    elif clf.loss == 'modified_huber':
        return _serialize_multicass_logistic(clf, "MODIFIED_HUBER")
    else:
        return None

def _serialize_multicass_logistic(clf, policy):
    """
    Serialize a multiclass linear classifier as a "LOGISTIC" model holding its coefficients,
    its intercept and the given policy.

    :rtype: SerializedModel
    """
    coefficients = clf.coef_.tolist()
    intercepts = clf.intercept_.tolist()
    return SerializedModel("LOGISTIC", {
        "coefficients": coefficients,
        "intercept": intercepts,
        "policy": policy
    })


class MulticlassModelSerializer(ClassificationModelSerializer):
    """Serializer of multiclass classification models."""

    def get_serialized(self):
        """
        Return the serialized form of the multiclass classifier with its calibrator attached,
        or None when the algorithm has no serialized form.

        :rtype: SerializedModel or None
        """
        # When probabilities are calibrated, the estimator to serialize is the base estimator
        estimator = get_base_estimator(self.clf) if self.calibrate_proba else self.clf
        algo = self.algorithm
        if algo == "LOGISTIC_REGRESSION":
            serialized = _serialize_multicass_logit(estimator)
        elif algo == "SGD_CLASSIFICATION":
            serialized = _serialize_multicass_sgd(estimator)
        elif algo == "GBT_CLASSIFICATION":
            serialized = _serialize_classification_gbm(estimator, False)
        elif algo == "XGBOOST_CLASSIFICATION":
            xgb_payload = _serialize_classification_xgb(estimator, False)
            serialized = SerializedModel("GRADIENT_BOOSTING_CLASSIFIER", xgb_payload)
        elif algo == "LIGHTGBM_CLASSIFICATION":
            serialized = _serialize_lightgbm_classifier(estimator, False)
        else:
            serialized = _common_classif_serialization(algo, estimator)
        self.add_calibrator(serialized)
        return serialized
