import logging
from copy import deepcopy

import numpy as np

from dataiku.core import doctor_constants
from dataiku.doctor.prediction.common import TrainableModel

logger = logging.getLogger(__name__)


class XGBoostTrainableModel(TrainableModel):
    """
    Trainable model wrapping an XGBoost estimator, with optional early stopping.

    When early stopping is enabled, the early stopping fit parameters are only
    added for non-final fits. The best iteration of each split is recorded, and
    the median of those best iterations is used to set "n_estimators" in the
    model parameters.
    """

    def __init__(self, estimator, hyperparameters_space,
                 is_early_stopping_enabled, early_stopping_rounds,
                 evaluation_metric_name, prediction_type):
        """
        :param estimator: estimator forwarded to the TrainableModel constructor
        :param hyperparameters_space: hyperparameters space forwarded to the
            TrainableModel constructor
        :param is_early_stopping_enabled: whether early stopping is used for
            non-final fits
        :param early_stopping_rounds: value passed as the "early_stopping_rounds"
            fit parameter; must be strictly positive when early stopping is
            enabled, ignored otherwise
        :param evaluation_metric_name: metric name translated into an XGBoost
            evaluation metric by _get_evaluation_metric; only used when early
            stopping is enabled
        :param prediction_type: prediction type used to pick the XGBoost
            evaluation metric; only used when early stopping is enabled
        :raises Exception: if early stopping is enabled and early_stopping_rounds
            is None or not strictly positive
        """
        super(XGBoostTrainableModel, self).__init__(
            estimator=estimator,
            hyperparameters_space=hyperparameters_space,
            supports_sample_weights=True
        )

        self._is_early_stopping_enabled = is_early_stopping_enabled

        if self._is_early_stopping_enabled:
            logger.info("Early stopping is enabled")

            # XGBoost does not fail gracefully on a non-positive value.
            # None is rejected explicitly: comparing None with an int raises a
            # TypeError on Python 3, which would hide the actual problem.
            if early_stopping_rounds is None or early_stopping_rounds <= 0:
                raise Exception("The early stopping rounds must be a positive number")

            self._early_stopping_rounds = early_stopping_rounds
            self._evaluation_metric = _get_evaluation_metric(evaluation_metric_name, prediction_type)

        else:
            logger.info("Early stopping is disabled")
            self._early_stopping_rounds = None
            self._evaluation_metric = None

    @property
    def must_search(self):
        """
        Whether early stopping is enabled.

        NOTE(review): presumably forces the search phase so that a best iteration
        is collected per split - confirm against TrainableModel.
        """
        return self._is_early_stopping_enabled

    @property
    def requires_evaluation_set(self):
        """
        Whether early stopping is enabled, in which case get_fit_parameters can
        forward an evaluation set to the fit.
        """
        return self._is_early_stopping_enabled

    def get_fit_parameters(self, sample_weight=None, X_eval=None, y_eval=None,
                           is_final_fit=False):
        """
        Build the fit parameters, adding the early stopping ones when relevant.

        :param sample_weight: forwarded to the parent implementation
        :param X_eval: evaluation features, used to build "eval_set"
        :param y_eval: evaluation target, used to build "eval_set"
        :param is_final_fit: when True, no early stopping parameter is added
        :return: the fit parameters returned by the parent implementation, extended
            with "early_stopping_rounds" and "eval_metric" (and "eval_set" when
            both X_eval and y_eval are provided) if early stopping is enabled and
            this is not the final fit
        """
        fit_parameters = super(XGBoostTrainableModel, self)\
            .get_fit_parameters(sample_weight, X_eval, y_eval, is_final_fit)

        # Do not use early stopping for the final fit - only during HP search.
        if self._is_early_stopping_enabled and not is_final_fit:
            fit_parameters["early_stopping_rounds"] = self._early_stopping_rounds
            fit_parameters["eval_metric"] = self._evaluation_metric

            # The evaluation set is only added when both parts are available.
            if X_eval is not None and y_eval is not None:
                fit_parameters["eval_set"] = [(X_eval, y_eval)]

        return fit_parameters

    def get_extra_per_split_search_result_attributes(self, estimator):
        """
        Collect the extra attributes to store for a split.

        :param estimator: estimator whose "best_iteration" attribute is read when
            early stopping is enabled
        :return: dict holding "best_iteration" when early stopping is enabled,
            empty dict otherwise
        """
        extra_attributes = {}

        if self._is_early_stopping_enabled:
            extra_attributes["best_iteration"] = estimator.best_iteration

        return extra_attributes

    def compute_model_parameters(self, per_split_search_results):
        """
        Compute the model parameters from the per-split search results.

        :param per_split_search_results: non-empty sequence of dicts, each holding
            a "parameters" entry and, when early stopping is enabled, a
            "best_iteration" entry
        :return: deep copy of the first result's parameters, with "n_estimators"
            overridden from the median best iteration when early stopping is enabled
        """
        # The model parameters are the same across all splits.
        model_parameters = deepcopy(per_split_search_results[0]["parameters"])

        if self._is_early_stopping_enabled:
            # When early stopping is enabled, we compute the median best iteration
            # to reuse in the final fit. int() truncates a fractional median.
            per_split_best_iterations = [result["best_iteration"] for result in per_split_search_results]
            median_best_iteration = int(np.median(per_split_best_iterations))

            # best_iteration starts at 0 while n_estimators starts at 1
            model_parameters["n_estimators"] = median_best_iteration + 1

        return model_parameters


def _get_evaluation_metric(evaluation_metric_name, prediction_type):
    """
    Translate a metric name into the XGBoost evaluation metric to use.

    The default metric of the prediction type is returned when no name is given,
    or when the name has no XGBoost counterpart for that prediction type.
    """
    logger.info("Get XGBoost metric for {}".format(evaluation_metric_name))

    fallback_metric = _get_default_evaluation_metric(prediction_type)
    if evaluation_metric_name is None:
        return fallback_metric

    # Pick the name -> XGBoost metric table matching the prediction type.
    if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
        # TODO @ml: Check if XGBoost's "map" could work for precision
        xgboost_metrics = {"ACCURACY": "error", "LOG_LOSS": "logloss", "ROC_AUC": "auc"}
    elif prediction_type == doctor_constants.MULTICLASS:
        # TODO @ml: Check if we could use multiclass even for binary
        # Don't use auc as it does not work for multiclass
        xgboost_metrics = {"ACCURACY": "merror", "LOG_LOSS": "mlogloss"}
    else:
        xgboost_metrics = {"RMSE": "rmse", "MAE": "mae"}

    # Unknown names fall back to the default metric of the prediction type.
    if evaluation_metric_name in xgboost_metrics:
        resolved_metric = xgboost_metrics[evaluation_metric_name]
    else:
        resolved_metric = fallback_metric

    logger.info("Found evaluation metric {}".format(resolved_metric))
    return resolved_metric


def _get_default_evaluation_metric(prediction_type):
    """
    Return the default XGBoost evaluation metric for the given prediction type.

    The defaults are pinned here to stay consistent, since the library introduced
    a breaking change in https://github.com/dmlc/xgboost/pull/6183
    """
    if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
        return "error"

    if prediction_type == doctor_constants.MULTICLASS:
        return "merror"

    # Any other prediction type gets the regression default.
    return "rmse"
