import logging
from copy import deepcopy

import numpy as np

from dataiku.base.utils import package_is_at_least
from dataiku.core import doctor_constants
from dataiku.doctor.prediction.common import TrainableModel

logger = logging.getLogger(__name__)


class LightGBMTrainableModel(TrainableModel):
    """Trainable model wrapping a LightGBM estimator, with optional early stopping.

    When early stopping is enabled, it is applied to every fit except the
    final one. The best iteration reported by each split is recorded, and
    their median is used as ``n_estimators`` in the final model parameters.
    """

    def __init__(self, estimator, hyperparameters_space,
                 is_early_stopping_enabled, early_stopping_rounds,
                 evaluation_metric_name, prediction_type):
        """
        :param estimator: estimator forwarded to ``TrainableModel``
        :param hyperparameters_space: hyperparameter space forwarded to ``TrainableModel``
        :param is_early_stopping_enabled: whether early stopping is used for non-final fits
        :param early_stopping_rounds: number of rounds passed to LightGBM early stopping;
            only read (and validated) when early stopping is enabled
        :param evaluation_metric_name: name of the metric, translated to a LightGBM
            metric by ``_get_evaluation_metric``; only read when early stopping is enabled
        :param prediction_type: prediction type used to pick the metric translation table
        :raises ValueError: if early stopping is enabled and ``early_stopping_rounds``
            is missing or not strictly positive
        """
        super(LightGBMTrainableModel, self).__init__(
            estimator=estimator,
            hyperparameters_space=hyperparameters_space,
            supports_sample_weights=True
        )

        self._is_early_stopping_enabled = is_early_stopping_enabled

        if self._is_early_stopping_enabled:
            logger.info("Early stopping is enabled")

            # Check for None explicitly: on Python 3 `None <= 0` raises an
            # unrelated TypeError instead of the intended validation error.
            if early_stopping_rounds is None or early_stopping_rounds <= 0:
                raise ValueError("The early stopping rounds must be a positive number")

            self._early_stopping_rounds = early_stopping_rounds
            self._evaluation_metric = _get_evaluation_metric(evaluation_metric_name, prediction_type)

        else:
            logger.info("Early stopping is disabled")
            self._early_stopping_rounds = None
            self._evaluation_metric = None

    @property
    def must_search(self):
        """True exactly when early stopping is enabled."""
        return self._is_early_stopping_enabled

    @property
    def requires_evaluation_set(self):
        """True exactly when early stopping is enabled."""
        return self._is_early_stopping_enabled

    def get_fit_parameters(self, sample_weight=None, X_eval=None, y_eval=None,
                           is_final_fit=False):
        """Build the keyword arguments for the estimator's ``fit`` call.

        Starts from the parent class' fit parameters and, when early stopping
        is enabled and this is not the final fit, adds the early stopping
        configuration, the evaluation metric and (if provided) the evaluation set.

        :param sample_weight: forwarded to the parent implementation
        :param X_eval: evaluation features; used as ``eval_set`` together with ``y_eval``
        :param y_eval: evaluation target; used as ``eval_set`` together with ``X_eval``
        :param is_final_fit: when True, no early stopping parameter is added
        :return: dict of fit parameters
        """
        fit_parameters = super(LightGBMTrainableModel, self) \
            .get_fit_parameters(sample_weight, X_eval, y_eval, is_final_fit)

        # Do not use early stopping for the final fit - only during HP search.
        if not self._is_early_stopping_enabled or is_final_fit:
            return fit_parameters

        import lightgbm
        if package_is_at_least(lightgbm, "4"):
            # LightGBM>=4 removed `early_stopping_rounds` from fit(): use the callback.
            # Work on a copy so that a callbacks list owned by the parent class
            # (or shared between calls) is never mutated in place.
            callbacks = list(fit_parameters.get("callbacks") or [])
            callbacks.append(lightgbm.early_stopping(self._early_stopping_rounds))
            fit_parameters["callbacks"] = callbacks
        else:
            fit_parameters["early_stopping_rounds"] = self._early_stopping_rounds

        # May be None when no LightGBM metric matches the requested one.
        fit_parameters["eval_metric"] = self._evaluation_metric

        if X_eval is not None and y_eval is not None:
            fit_parameters["eval_set"] = [(X_eval, y_eval)]

        return fit_parameters

    def get_extra_per_split_search_result_attributes(self, estimator):
        """Return the extra attributes to record for one fitted split.

        :param estimator: fitted estimator exposing ``best_iteration_``
        :return: ``{"best_iteration_": ...}`` when early stopping is enabled,
            an empty dict otherwise
        """
        if not self._is_early_stopping_enabled:
            return {}

        return {"best_iteration_": estimator.best_iteration_}

    def compute_model_parameters(self, per_split_search_results):
        """Compute the parameters to use for the final model.

        :param per_split_search_results: non-empty list of per-split results, each
            holding a ``"parameters"`` entry and, when early stopping is enabled,
            a ``"best_iteration_"`` entry
        :return: a deep copy of the first split's parameters, with ``n_estimators``
            overridden by the median best iteration when early stopping is enabled
        """
        # The model parameters are the same across all splits.
        model_parameters = deepcopy(per_split_search_results[0]["parameters"])

        if self._is_early_stopping_enabled:
            # When early stopping is enabled, we compute the median best iteration
            # to reuse in the final fit.
            per_split_best_iterations = [result["best_iteration_"] for result in per_split_search_results]
            median_best_iteration = int(np.median(per_split_best_iterations))

            # Unlike XGBoost, LightGBM reports `best_iteration_` 1-based (its training
            # loop stores the 0-based index of the best round + 1), so it already is a
            # number of trees: adding 1 would train one extra tree.
            # Keep a floor of 1 since n_estimators must stay strictly positive.
            model_parameters["n_estimators"] = max(1, median_best_iteration)

        return model_parameters


def _get_evaluation_metric(evaluation_metric_name, prediction_type):
    """Translate a metric name into the matching LightGBM metric identifier.

    :param evaluation_metric_name: metric name, e.g. ``"ROC_AUC"`` or ``"RMSE"``
    :param prediction_type: one of the ``doctor_constants`` prediction types
        (binary classification, multiclass, regression)
    :return: the LightGBM metric string, or None when the prediction type or
        the metric name has no entry in the translation tables
    """
    logger.info("Get LightGBM metric for {}".format(evaluation_metric_name))

    binary_metrics = {
        "ACCURACY": "binary_error",
        "LOG_LOSS": "binary_logloss",
        "ROC_AUC": "auc",
    }
    multiclass_metrics = {
        "ACCURACY": "multi_error",
        "LOG_LOSS": "multi_logloss",
        "ROC_AUC": "auc_mu",
    }
    regression_metrics = {
        "MAPE": "mape",
        "MAE": "mae",
        "MSE": "mse",
        "RMSE": "rmse",
    }
    metrics_by_prediction_type = {
        doctor_constants.BINARY_CLASSIFICATION: binary_metrics,
        doctor_constants.MULTICLASS: multiclass_metrics,
        doctor_constants.REGRESSION: regression_metrics,
    }

    # Unknown prediction types fall back to an empty table, so the lookup
    # below simply yields None for them.
    candidate_metrics = metrics_by_prediction_type.get(prediction_type, {})
    lightgbm_metric = candidate_metrics.get(evaluation_metric_name)

    if lightgbm_metric is not None:
        logger.info("Found evaluation metric {}".format(lightgbm_metric))
        return lightgbm_metric

    logger.info("No metric found, default value will be used")
    return None
