# WARNING: Not to be imported directly in exposed files (e.g. commands, prediction_entrypoints...), because this module
# imports the xgboost library, which users might not want to install (e.g. when using deep learning).
# This module should be imported within function or class definitions, only when required.

import logging
import os.path as osp

import numpy as np
import xgboost
from xgboost import XGBClassifier, XGBRegressor

from dataiku.core import dkujson
from dataiku.base.utils import package_is_at_least
from dataiku.doctor.utils.gpu_execution import XGBOOSTGpuCapability
from dataiku.doctor.utils.model_io import XGBOOST_BOOSTER_FILENAME
from dataiku.doctor.utils.model_io import XGBOOST_CLF_ATTRIBUTES_FILENAME

from sklearn.preprocessing import LabelEncoder

logger = logging.getLogger(__name__)


def load_xgboost_model(clf_attributes_path, booster_path):
    """
    Rebuild a DkuXGBClassifier / DkuXGBRegressor from its two persisted files.

    :param clf_attributes_path: path of the JSON file holding the wrapper attributes (constructor parameters,
                                classes, objective...)
    :param booster_path: path of the file holding the xgboost Booster, read through `load_model`
    :return: the re-instantiated estimator with its booster loaded
    :raises RuntimeError: if the stored "clf_class" is neither "DkuXGBClassifier" nor "DkuXGBRegressor"
    """
    with open(clf_attributes_path, "r") as f:
        classifier_attributes = dkujson.load(f)

    # A `missing` value of NaN is stored as None in the attributes file, restore it
    if classifier_attributes["missing"] is None:
        classifier_attributes["missing"] = np.nan

    logger.info("Instantiating xgboost classifier with attributes: %s", classifier_attributes)

    clf_class = classifier_attributes["clf_class"]
    if clf_class == "DkuXGBClassifier":
        # "class_weight" may be absent from attribute files that do not carry it: read it leniently, like
        # "tweedie_variance_power" for the regressor.
        class_weight = classifier_attributes.get("class_weight")
        if class_weight:
            # JSON object keys are strings, so the integer class labels written at dump time come back as
            # strings: convert them back so that lookups by integer label keep working.
            class_weight = {int(k): float(v) for k, v in class_weight.items()}

        new_clf = instantiate_xgb_classifier(
            classifier_attributes["n_estimators"],
            classifier_attributes["verbosity"],
            classifier_attributes["n_jobs"],
            classifier_attributes["scale_pos_weight"],
            classifier_attributes["base_score"],
            classifier_attributes["random_state"],
            classifier_attributes["missing"],
            classifier_attributes["tree_method"],
            class_weight,
        )
        # XGBoost 2.0 made classes_ a property without setter, its value is set as `np.arange(self.n_classes_)`
        if not package_is_at_least(xgboost, "2.0"):
            new_clf.classes_ = np.array(classifier_attributes["classes"])
        new_clf.n_classes_ = classifier_attributes["n_classes"]

        # For xgboost<1.0.0 the required _le attribute is not handled by {XGBClassifier,XGBRegressor}.load_model
        # so _le has to be stored and retrieved in classifier_attributes
        # For xgboost>=1.0.0 the _le attribute is handled by XGBModel.load_model
        if not package_is_at_least(xgboost, "1.0.0"):
            new_clf._le = _LabelEncoderJsonSerializer.from_json(classifier_attributes["_le"])

    elif clf_class == "DkuXGBRegressor":
        new_clf = instantiate_xgb_regressor(
            classifier_attributes["n_estimators"],
            classifier_attributes["verbosity"],
            classifier_attributes["n_jobs"],
            classifier_attributes["scale_pos_weight"],
            classifier_attributes["base_score"],
            classifier_attributes["random_state"],
            classifier_attributes["missing"],
            classifier_attributes["tree_method"],
            classifier_attributes.get("tweedie_variance_power", None),
        )
    else:
        raise RuntimeError("Unknown classifier type for XGBoost: {}".format(clf_class))

    # Set additional attributes to the classifier
    if classifier_attributes["objective"]:
        new_clf.objective = classifier_attributes["objective"]

    logger.info("Loading xgboost Booster...")
    new_clf.load_model(booster_path)

    # A model trained with the GPU histogram method is also asked to predict on GPU
    if classifier_attributes["tree_method"] == "gpu_hist":
        new_clf.set_params(**{"predictor": "gpu_predictor"})
    return new_clf


def dump_xgboost_model(folder_context, clf):
    """
    Persist an XGBoost estimator as two files: the Booster itself and a JSON file of wrapper attributes.

    :param folder_context: object whose `get_file_path_to_write(filename)` is used as a context manager
                           yielding the path to write to
    :param clf: the fitted estimator to persist; classifier-only attributes are added when it is a
                DkuXGBClassifier
    """
    # First save the booster into its own file
    logger.info("Saving xgboost Booster...")
    with folder_context.get_file_path_to_write(XGBOOST_BOOSTER_FILENAME) as booster_file_path:
        clf.save_model(booster_file_path)

    # Dump the classifier attributes

    # `silent` argument was deprecated and replaced with `verbosity` after version 1.3.0
    verbosity = clf.verbosity if hasattr(clf, "verbosity") else clf.silent

    # `tree_method` became a classifier attribute after version 1.0.0
    tree_method = clf.tree_method if hasattr(clf, "tree_method") else clf.kwargs["tree_method"]

    # NaN is stored as None in the attributes file (load_xgboost_model turns None back into NaN).
    # `missing` may itself be None, which np.isnan() cannot handle, so check it first.
    missing = None if clf.missing is None or np.isnan(clf.missing) else clf.missing

    classifier_attributes = {
        "clf_class": clf.__class__.__name__,
        "n_estimators": clf.n_estimators,
        "verbosity": verbosity,
        "n_jobs": clf.n_jobs,
        "scale_pos_weight": clf.scale_pos_weight,
        "base_score": clf.base_score,
        "random_state": clf.random_state,
        "missing": missing,
        "tree_method": tree_method,
        "objective": clf.objective,
    }

    # Extra constructor keyword arguments; the attribute may be absent (or empty) on the estimator
    extra_kwargs = getattr(clf, "kwargs", None) or {}

    if "tweedie_variance_power" in extra_kwargs:
        classifier_attributes["tweedie_variance_power"] = extra_kwargs["tweedie_variance_power"]

    if isinstance(clf, DkuXGBClassifier):
        # Normalize keys/values to plain int/float so that they are JSON serializable;
        # an absent or empty class_weight is stored as an empty dict
        raw_class_weight = extra_kwargs.get("class_weight")
        if raw_class_weight:
            class_weight = {int(k): float(v) for k, v in raw_class_weight.items()}
        else:
            class_weight = {}
        classifier_attributes["class_weight"] = class_weight
        classifier_attributes["classes"] = clf.classes_.tolist()
        classifier_attributes["n_classes"] = clf.n_classes_
        # For xgboost<1.0.0 the required _le attribute is not handled by {XGBClassifier,XGBRegressor}.save_model
        # so _le has to be stored and retrieved in classifier_attributes
        # For xgboost>=1.0.0 the _le attribute is handled by XGBModel.save_model
        if not package_is_at_least(xgboost, "1.0.0"):
            classifier_attributes["_le"] = _LabelEncoderJsonSerializer.to_json(clf._le)

    logger.info("Saving xgboost classifier attributes: %s", classifier_attributes)
    with folder_context.get_file_path_to_write(XGBOOST_CLF_ATTRIBUTES_FILENAME) as attributes_file_path:
        with open(attributes_file_path, "w") as attributes_file:
            dkujson.dump(attributes_file, classifier_attributes)


def instantiate_xgb_classifier(n_estimators, silent, n_jobs, scale_pos_weight, base_score, random_state, missing, tree_method, class_weight):
    """
    Build a DkuXGBClassifier from its constructor parameters.

    The value of `silent` is forwarded under the keyword `verbosity` when the installed xgboost is at least
    1.1.0, and under the keyword `silent` otherwise. All other arguments are forwarded under their own name.
    """
    # `silent` argument was deprecated and replaced with `verbosity`
    verbosity_keyword = "verbosity" if package_is_at_least(xgboost, "1.1.0") else "silent"

    constructor_params = {
        "n_estimators": n_estimators,
        "n_jobs": n_jobs,
        "scale_pos_weight": scale_pos_weight,
        "base_score": base_score,
        "random_state": random_state,
        "missing": missing,
        "tree_method": tree_method,
        "class_weight": class_weight,
        verbosity_keyword: silent,
    }
    return DkuXGBClassifier(**constructor_params)


def instantiate_xgb_regressor(n_estimators, silent, n_jobs, scale_pos_weight, base_score, random_state, missing, tree_method, tweedie_variance_power):
    """
    Build a DkuXGBRegressor from its constructor parameters.

    The value of `silent` is forwarded under the keyword `verbosity` when the installed xgboost is at least
    1.1.0, and under the keyword `silent` otherwise. All other arguments are forwarded under their own name.
    """
    # `silent` argument was deprecated and replaced with `verbosity`
    verbosity_keyword = "verbosity" if package_is_at_least(xgboost, "1.1.0") else "silent"

    constructor_params = {
        "n_estimators": n_estimators,
        "n_jobs": n_jobs,
        "scale_pos_weight": scale_pos_weight,
        "base_score": base_score,
        "random_state": random_state,
        "missing": missing,
        "tree_method": tree_method,
        "tweedie_variance_power": tweedie_variance_power,
        verbosity_keyword: silent,
    }
    return DkuXGBRegressor(**constructor_params)


def build_objective_string(objective):
    """
    Turn an underscore-separated objective name into its colon-separated form (e.g. "a_b" -> "a:b").

    "reg_linear" is first replaced by "reg_squarederror" when the installed xgboost is at least 0.90.
    """
    # The version check is only evaluated for "reg_linear"
    use_squared_error = objective == "reg_linear" and package_is_at_least(xgboost, "0.90")
    effective_objective = "reg_squarederror" if use_squared_error else objective
    return effective_objective.replace("_", ":")


def expand_tree_method_for_xgboost(input_hp_space, gpu_config):
    """
    Return the tree method to use, given the requested one and the GPU configuration.

    :param input_hp_space: mapping holding the requested method under the 'tree_method' key
    :param gpu_config: configuration handed to XGBOOSTGpuCapability.should_use_gpu
    :return: the requested method unchanged when the GPU is not used; otherwise "gpu_hist", or "gpu_exact"
             for "exact" with xgboost < 1.0.0
    :raises Exception: when running on GPU with the "auto" or "approx" method
    """
    # GPU execution only supports 'Histogram' tree mode, as well as 'Exact' before xgboost 1.0.0
    # There is a pre-train check for this in XGBoostMeta.java, but this is for any api based usage
    requested_method = input_hp_space['tree_method']

    if not XGBOOSTGpuCapability.should_use_gpu(gpu_config):
        return requested_method

    # On GPU: methods that cannot run there, mapped to the label used in the error message
    gpu_unsupported_labels = {
        "auto": "Automatic",
        "approx": "Approximate"
    }

    if requested_method in gpu_unsupported_labels:
        if package_is_at_least(xgboost, "1.0.0"):
            suggested_methods = "Histogram"
        else:
            suggested_methods = "Histogram, or Exact"
        raise Exception(
            """XGBoost Tree Method hyperparameter is set to {}, which does not support GPU execution.
               Please select {}, or disable GPU execution""".format(
                gpu_unsupported_labels[requested_method],
                suggested_methods,
        ))

    if requested_method != "exact":
        return "gpu_hist"

    if not package_is_at_least(xgboost, "1.0.0"):
        return "gpu_exact"

    # Starting from version 1.0.0, XGBoost removed the gpu_exact tree method, we thus fall back to gpu_hist instead.
    logger.info("Starting from version 1.0.0, XGBoost removed support for the gpu_exact tree method. Falling back to gpu_hist.")
    return "gpu_hist"


class DkuXGBClassifier(XGBClassifier):
    """
    XGBClassifier whose `predict_proba` and `fit` adapt their arguments to the installed xgboost version,
    and whose `fit` applies an optional `class_weight` estimator parameter as per-row sample weights.
    """

    def predict_proba(self, X, ntree_limit=None, validate_features=None, base_margin=None, iteration_range=None):
        """
        XGBoost version < 1.3.2 implements `def predict_proba(self, data, ...)`
        instead of `def predict_proba(self, X, ...)`. This causes a bug when
        using calibration and sklearn 0.24, because predict_proba is then
        called with the X as a named parameter: `predict_proba(X=X)`. The
        solution is to override predict_proba to make sure that the parameter's
        name is X.

        Only the keyword arguments matching the installed xgboost version are forwarded.
        """
        if validate_features is None:
            # Default value has changed after 1.2.0: True before, False after
            validate_features = not package_is_at_least(xgboost, "1.2.0")
        if package_is_at_least(xgboost, "1.4.0"):
            kwargs = {"base_margin": base_margin, "iteration_range": iteration_range}
        elif package_is_at_least(xgboost, "1.0.0"):
            kwargs = {"base_margin": base_margin}
        else:
            kwargs = {}

        # XGBoost 2.0 removed ntree_limit from predict_proba()
        if not package_is_at_least(xgboost, "2.0"):
            kwargs["ntree_limit"] = ntree_limit
        elif kwargs.get("iteration_range") is None and ntree_limit is not None and ntree_limit > 0:
            # Express the requested tree limit through iteration_range instead
            kwargs["iteration_range"] = (0, ntree_limit)

        return super(DkuXGBClassifier, self).predict_proba(X, validate_features=validate_features, **kwargs)

    def fit(self, X, y, eval_set=None, eval_metric=None, early_stopping_rounds=None, verbose=True, sample_weight=None, xgb_model=None):
        """
        Fit the classifier, combining the `class_weight` parameter (if any) with `sample_weight`.

        `eval_metric` and `early_stopping_rounds` are passed to the parent `fit` before xgboost 2.1 and set as
        estimator parameters from 2.1 on. `xgb_model` is forwarded to the parent `fit` when provided.

        :raises Exception: with installation guidance when the underlying error message mentions "GPU support";
                           any other error is re-raised unchanged
        """
        class_weight = self.get_params().get("class_weight")
        # An empty dict means "no class weights" (this is what gets persisted when none were given):
        # looking labels up in it would only produce None weights.
        if class_weight:
            class_weight_arr = np.vectorize(class_weight.get)(y)
            if sample_weight is None:
                sample_weight = class_weight_arr
            else:
                # Not in place: the caller's array must not be modified, and an in-place multiply
                # fails when the caller's weights have an integer dtype
                sample_weight = sample_weight * class_weight_arr

        fit_kwargs = {"eval_set": eval_set, "verbose": verbose, "sample_weight": sample_weight}
        if xgb_model is not None:
            # Only forwarded when explicitly provided, so that the default call is unchanged
            fit_kwargs["xgb_model"] = xgb_model

        try:
            # XGBoost 2.1 removed parameters from fit()
            if package_is_at_least(xgboost, "2.1"):
                self.set_params(eval_metric=eval_metric, early_stopping_rounds=early_stopping_rounds)
            else:
                fit_kwargs["eval_metric"] = eval_metric
                fit_kwargs["early_stopping_rounds"] = early_stopping_rounds
            return super(DkuXGBClassifier, self).fit(X, y, **fit_kwargs)
        except Exception as e:
            message = str(e)
            if "GPU support" in message:
                logger.error(message)
                raise Exception("""Your code environment has an installation of XGBoost that does not support computations on GPUs. 
                                   To install XGBoost with GPU support, please refer to http://xgboost.readthedocs.io/en/latest/build.html#building-with-gpu-support
                                   \n\n""" + message)
            else:
                raise


class DkuXGBRegressor(XGBRegressor):
    """XGBRegressor whose `fit` adapts its arguments to the installed xgboost version."""

    def fit(self, X, y, eval_set=None, eval_metric=None, early_stopping_rounds=None, verbose=True, sample_weight=None, xgb_model=None):
        """
        Fit the regressor.

        `eval_metric` and `early_stopping_rounds` are passed to the parent `fit` before xgboost 2.1 and set as
        estimator parameters from 2.1 on. `xgb_model` is forwarded to the parent `fit` when provided.

        :raises Exception: with installation guidance when the underlying error message mentions "GPU support";
                           any other error is re-raised unchanged
        """
        fit_kwargs = {"eval_set": eval_set, "verbose": verbose, "sample_weight": sample_weight}
        if xgb_model is not None:
            # Only forwarded when explicitly provided, so that the default call is unchanged
            fit_kwargs["xgb_model"] = xgb_model

        try:
            # XGBoost 2.1 removed parameters from fit()
            if package_is_at_least(xgboost, "2.1"):
                self.set_params(eval_metric=eval_metric, early_stopping_rounds=early_stopping_rounds)
            else:
                fit_kwargs["eval_metric"] = eval_metric
                fit_kwargs["early_stopping_rounds"] = early_stopping_rounds
            return super(DkuXGBRegressor, self).fit(X, y, **fit_kwargs)
        except Exception as e:
            message = str(e)
            if "GPU support" in message:
                logger.error(message)
                raise Exception("""Your code environment has an installation of XGBoost that does not support computations on GPUs. 
                                   To install XGBoost with GPU support, please refer to http://xgboost.readthedocs.io/en/latest/build.html#building-with-gpu-support
                                   \n\n""" + message)
            else:
                raise


class _LabelEncoderJsonSerializer(object):
    """
    Helpers converting a label encoder to and from a JSON compatible dictionary.
    Copied from https://github.com/dmlc/xgboost/blob/v1.5.2/python-package/xgboost/compat.py#L69
    ONLY used for XGBoost version lower than 1.0.0
    """
    @staticmethod
    def to_json(label_encoder):
        """Return the encoder's attributes as a dict, with numpy arrays converted to lists"""
        return {
            name: (value.tolist() if isinstance(value, np.ndarray) else value)
            for name, value in label_encoder.__dict__.items()
        }

    @staticmethod
    def from_json(json_dict):
        """Build a LabelEncoder from a dict such as the one produced by `to_json`."""
        restored = LabelEncoder()
        # `classes_` is restored as a numpy array; every other entry is copied as is
        if 'classes_' in json_dict:
            restored.classes_ = np.array(json_dict['classes_'])
        other_attributes = {name: value for name, value in json_dict.items() if name != 'classes_'}
        restored.__dict__.update(other_attributes)
        return restored
