import logging
import numpy as np
import pandas as pd
import sklearn

from sklearn.ensemble import RandomForestRegressor

from dataiku.base.utils import package_is_at_least

from dataiku.doctor.timeseries.models.base_estimator import BaseTimeseriesEstimator
from dataiku.doctor.timeseries.preparation.preprocessing import MultiHorizonShiftExpanderMixin
from dataiku.doctor.timeseries.utils import future_date_range
from dataiku.doctor.timeseries.utils import get_dataframe_of_timeseries_identifier
from dataiku.doctor.timeseries.utils import timeseries_iterator
from dataiku.doctor.timeseries.utils import ModelForecast

logger = logging.getLogger(__name__)


class DkuTimeSeriesMLEstimator(BaseTimeseriesEstimator, MultiHorizonShiftExpanderMixin):
    """Base class of the time series estimators backed by a tabular regressor.

    For each time series identifier, ``prediction_length`` regressors are trained: one per
    horizon step ``t+1 .. t+prediction_length``. The regressor itself is created by
    :meth:`init_clf`, which concrete subclasses must implement.

    The training / prediction feature matrices are produced by ``expand_shifts_for_training``
    and ``expand_shifts_for_prediction``, which are not defined in this class (presumably
    inherited from ``MultiHorizonShiftExpanderMixin``).
    """

    # Characters replaced in feature names before they are handed to the underlying model
    # (see `sanitize_feature_names`). Subclasses may override this mapping.
    FEATURE_NAME_CHARACTER_REPLACEMENTS = {
        # Forbidden by XGBoost
        '[': '$$open_bracket$$',
        ']': '$$close_bracket$$',
        '<': '$$smaller_than$$',
    }

    def __init__(
            self,
            frequency,
            time_variable,
            prediction_length,
            target_variable,
            timeseries_identifiers,
            monthly_day_alignment=None,
    ):
        """Store the forecasting settings; no model exists until `fit` is called."""
        # params not passed directly to the external library model
        # NOTE(review): the parent constructor is called positionally with prediction_length
        # *before* time_variable, unlike this constructor's own signature - confirm this matches
        # BaseTimeseriesEstimator.__init__.
        super(DkuTimeSeriesMLEstimator, self).__init__(
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            monthly_day_alignment
        )

        self.shift_map = None
        # Whether the underlying regressor accepts NaN in its inputs; subclasses override it.
        self.supports_nan = False

        # dict of trained models by timeseries identifier
        self.trained_models = None

    def initialize(self, core_params, modeling_params):
        """No-op: neither ``core_params`` nor ``modeling_params`` is used by this estimator."""
        pass

    def fit(self, train_df, external_features=None, shift_map=None):
        """Train the per-horizon models of every time series found in ``train_df``.

        :param train_df: pd.DataFrame containing all time series; it is split per identifier
            with ``timeseries_iterator``
        :param external_features: optional mapping of time series identifier to the list of
            external feature columns of that series. When None, the current value of
            ``self.external_features`` is kept.
        :param shift_map: stored as-is on ``self.shift_map`` (presumably read by the shift
            expansion helpers - TODO confirm)
        """
        if external_features is not None:
            self.external_features = external_features

        self.shift_map = shift_map

        self.trained_models = {}
        for timeseries_identifier, df_of_timeseries_identifier in timeseries_iterator(
                train_df, self.timeseries_identifiers
        ):
            target_values = df_of_timeseries_identifier[self.target_variable].reset_index(drop=True)
            # External features are optional and configured per time series
            external_features_values = (
                df_of_timeseries_identifier[self.external_features[timeseries_identifier]].reset_index(drop=True)
                if self.external_features and self.external_features.get(timeseries_identifier) else None
            )

            logger.info("Training models for time series {} with dataframe of raw shape {}".format(timeseries_identifier, df_of_timeseries_identifier.shape))

            self.trained_models[timeseries_identifier] = self._fit_single(
                target_values, external_features_values, timeseries_identifier
            )

    def _fit_single(self, target_values, external_features_values, timeseries_identifier):
        """Train one model per horizon step for a single time series.

        :param target_values: pd.Series target are sorted by dates of increasing order
        :param external_features_values: pd.DataFrame of external features, or None
        :param timeseries_identifier: identifier of the series, forwarded to the shift expansion
        :return: list of fitted models; the model at index ``i`` is the one for horizon ``i + 1``
        """
        clf_list = []

        for horizon in range(1, self.prediction_length + 1):
            clf = self.init_clf()
            X, y = self.expand_shifts_for_training(external_features_values, target_values, horizon, timeseries_identifier)
            logger.info("Training model for horizon t+{}. Final shape of train set after lags and windows is {}".format(horizon, X.shape))
            X = self.sanitize_feature_names(X)
            clf.fit(X, y)
            clf_list.append(clf)
        return clf_list

    def init_clf(self):
        """
        Responsible for passing hyperparameters defined in the companion class inheriting TimeseriesForecastingAlgorithm
        to all underlying clf's (one per step in the horizon).
        The mechanism is:
          - `TrainableModel.clone_estimator(parameters)` sets the `DkuTimeSeriesMLEstimator` subclass's attributes to
            values in `parameters` by calling sklearn's `BaseEstimator.set_params` method.
            Note: only attributes defined in the subclass's `__init__` method are set.
          - `init_clf` (this method) creates the underlying estimator `clf` and actually passes the attributes to `clf`.
        """
        raise NotImplementedError

    def predict(self, past_df, future_df, quantiles, fit_before_predict=False, prediction_length_override=None):
        """Forecast every time series found in ``past_df``.

        :param past_df: pd.DataFrame of past values for all time series
        :param future_df: pd.DataFrame of future values; only read for the series that have
            external features configured
        :param quantiles: requested quantiles, forwarded to `predict_single`
        :param fit_before_predict: forwarded to `predict_single`
        :param prediction_length_override: optional number of steps to forecast, forwarded to
            `predict_single`
        :return: dict of time series identifier to the forecasts dict built by `predict_single`
        """
        forecasts_by_timeseries = {}
        for timeseries_identifier, past_df_of_timeseries_identifier in timeseries_iterator(
                past_df, self.timeseries_identifiers
        ):
            future_df_of_timeseries_identifier = None
            if self.external_features and self.external_features.get(timeseries_identifier):
                future_df_of_timeseries_identifier = get_dataframe_of_timeseries_identifier(
                    future_df, timeseries_identifier
                )

            forecasts_by_timeseries[timeseries_identifier] = self.predict_single(
                past_df_of_timeseries_identifier,
                future_df_of_timeseries_identifier,
                quantiles,
                timeseries_identifier,
                fit_before_predict=fit_before_predict,
                prediction_length_override=prediction_length_override
            )

        return forecasts_by_timeseries

    def predict_single(self, past_df, future_df, quantiles, timeseries_identifier, fit_before_predict=False, prediction_length_override=None):
        """Forecast a single time series with its trained per-horizon models.

        :param past_df: pd.DataFrame of past values of the series
        :param future_df: pd.DataFrame of future external feature values, or None
        :param quantiles: requested quantiles; only their count is used because the quantile
            forecasts are filled with zeros
        :param timeseries_identifier: key of the series in ``self.trained_models``
        :param fit_before_predict: not used by this implementation
        :param prediction_length_override: optional number of steps to forecast; must not exceed
            ``self.prediction_length``
        :return: dict with the ``ModelForecast.FORECAST_VALUES``, ``ModelForecast.QUANTILES_FORECASTS``
            and ``ModelForecast.TIMESTAMPS`` entries
        :raises ValueError: if ``prediction_length_override`` is larger than ``self.prediction_length``
        """

        if prediction_length_override is not None and prediction_length_override > self.prediction_length:
            # The message has to be formatted explicitly: extra positional arguments given to an
            # exception constructor are stored in `args`, not interpolated into the message.
            raise ValueError(
                "Invalid prediction_length_override {} must be smaller or equal to prediction_length {}".format(
                    prediction_length_override, self.prediction_length
                )
            )

        prediction_length = self.prediction_length if prediction_length_override is None else prediction_length_override

        vals = []
        past_target_series = past_df[self.target_variable]

        has_external_features = self.external_features and self.external_features.get(timeseries_identifier)
        past_external_df = past_df[self.external_features[timeseries_identifier]] if has_external_features else None
        future_external_df = future_df[self.external_features[timeseries_identifier]] if (has_external_features and future_df is not None) else None

        # The model at index i was trained for horizon i + 1; only the first
        # `prediction_length` models are used.
        for i, clf in enumerate(self.trained_models[timeseries_identifier][:prediction_length]):
            horizon = i+1
            X = self.expand_shifts_for_prediction(past_target_series, past_external_df, future_external_df, horizon, timeseries_identifier)
            X = self.sanitize_feature_names(X)
            pred = clf.predict(X)
            # Only the first predicted value is kept (X presumably holds a single row - TODO confirm)
            vals.append(pred[0])

        forecasts = {
            ModelForecast.FORECAST_VALUES: np.array(vals),
            # No quantile is estimated by these models: the quantile forecasts are all zeros
            ModelForecast.QUANTILES_FORECASTS: np.zeros(shape=(len(quantiles), prediction_length)),
        }

        last_past_date = past_df[self.time_variable].iloc[-1]
        forecasts[ModelForecast.TIMESTAMPS] = future_date_range(
            last_past_date,
            prediction_length,
            self.frequency,
            self.monthly_day_alignment,
        )
        return forecasts

    def get_fitted_values_and_residuals(self, identifier, df_of_identifier, min_scoring_size):
        """Compute single-step-ahead fitted values and residuals for one time series.

        Only the model of the first horizon step is used. Every candidate position is expanded
        into a feature row; the valid rows are then predicted in one batch.

        :param identifier: key of the series in ``self.trained_models``
        :param df_of_identifier: pd.DataFrame of the series; it is not modified
        :param min_scoring_size: currently not used by this implementation
        :return: tuple ``(predicted_values, residuals)`` of pd.Series with one entry per row of
            ``df_of_identifier``; entries for which no prediction was made are NaN
        """
        clf = self.trained_models[identifier][0] # Residuals only for a single step ahead, like other models.
        n_rows = df_of_identifier.shape[0]
        n_predictions = n_rows - self.prediction_length

        predicted_values = np.full(n_rows, np.nan)
        ground_truth = np.full(n_rows, np.nan)

        # Work on a re-indexed copy rather than resetting the index in place, so that the
        # column of the caller's dataframe is left untouched (same approach as in `fit`).
        target_values = df_of_identifier[self.target_variable].reset_index(drop=True)
        has_external_features = self.external_features and self.external_features.get(identifier)
        external_values = df_of_identifier[self.external_features[identifier]].values if has_external_features else None

        # Batch collect all valid expanded rows first
        rows_to_predict = []
        valid_indices = []

        columns = None

        # NOTE(review): the history passed for position t is target_values[:t] (rows 0..t-1) while
        # the ground truth is read at row t+1 and the prediction is stored at position t - confirm
        # this alignment against the contract of expand_shifts_for_prediction.
        for t in range(n_predictions):
            try:
                target_series = target_values[:t]
                past_external_df = pd.DataFrame(external_values[:t], columns=self.external_features[identifier]) if has_external_features else None
                future_external_df = pd.DataFrame(external_values[t+1:], columns=self.external_features[identifier]) if has_external_features else None

                expanded_row_to_predict = self.expand_shifts_for_prediction(target_series, past_external_df, future_external_df, 1, identifier)
                expanded_row_to_predict = self.sanitize_feature_names(expanded_row_to_predict)
                if self.supports_nan or not pd.isnull(expanded_row_to_predict).any().any() : # Don't try to predict rows with nan values for algos that don't support them

                    if columns is None:
                        columns = expanded_row_to_predict.columns
                    else:
                        # A mismatch raises AssertionError, which is caught below: the row is skipped
                        assert columns.equals(expanded_row_to_predict.columns)

                    rows_to_predict.append(expanded_row_to_predict)
                    valid_indices.append(t)
                    ground_truth[t] = target_values[t+1]  # Assuming single-step ahead
            except Exception as e:
                # TODO: proper lags/windows bounds handling => min_scoring_size - for now the exception is a control flow for calling self.expand_shifts_for_prediction
                # for values of t where we shouldn't
                # The position is still skipped, but the reason is traced to keep real failures diagnosable.
                logger.debug("No fitted value computed at position %s of time series %s: %s", t, identifier, e)

        # Batch predict all at once
        if len(rows_to_predict) > 0:
            X_combined = np.vstack(rows_to_predict)
            X_df = pd.DataFrame(X_combined, columns=columns)
            predictions = clf.predict(X_df)
            predicted_values[valid_indices] = predictions

        predicted_values = pd.Series(predicted_values)
        residuals = pd.Series(ground_truth) - predicted_values
        return predicted_values, residuals

    @classmethod
    def sanitize_feature_names(cls, df):
        """Sanitizes DataFrame column names for compatibility with all algos.
        This function replaces special characters in column names that are forbidden by some algos.
        Each character listed in ``cls.FEATURE_NAME_CHARACTER_REPLACEMENTS`` is replaced with its
        text representation; column names are converted to strings first.

        Args:
           df (pd.DataFrame): The input DataFrame with potentially incompatible column names. It is updated by the method.

        Returns:
           pd.DataFrame: The same DataFrame object, with sanitized column names.
        """

        new_columns = df.columns.astype(str)
        for old_char, new_name in cls.FEATURE_NAME_CHARACTER_REPLACEMENTS.items():
            # Literal (non-regex) replacement: '[' and ']' must not be read as a character class
            new_columns = new_columns.str.replace(old_char, new_name, regex=False)
        df.columns = new_columns
        return df



class DkuTimeSeriesLightGBMEstimator(DkuTimeSeriesMLEstimator):
    """Time series estimator whose per-horizon models are ``lightgbm.LGBMRegressor`` instances."""

    # Characters that LightGBM refuses in feature names, with their textual substitute.
    # They come from CheckAllowedJSON in
    # https://github.com/microsoft/LightGBM/blob/master/include/LightGBM/utils/common.h
    # which rejects the char codes 34 ("), 44 (,), 58 (:), 91 ([), 93 (]), 123 ({) and 125 (}).
    FEATURE_NAME_CHARACTER_REPLACEMENTS = {
        '[': '$$open_bracket$$',
        ']': '$$close_bracket$$',
        '"': '$$double_quote$$',
        ',': '$$comma$$',
        ':': '$$colon$$',
        '{': '$$left_curly_bracket$$',
        '}': '$$right_curly_bracket$$',
    }

    def __init__(self,
                 frequency,
                 time_variable,
                 prediction_length,
                 target_variable,
                 timeseries_identifiers,
                 monthly_day_alignment=None,
                 boosting_type=None,
                 num_leaves=None,
                 max_depth=None,
                 learning_rate=None,
                 n_estimators=None,
                 subsample_for_bin=None,
                 objective=None,
                 min_split_gain=None,
                 min_child_weight=None,
                 min_child_samples=None,
                 subsample=None,
                 subsample_freq=None,
                 colsample_bytree=None,
                 reg_alpha=None,
                 reg_lambda=None,
                 random_state=1337,
                 n_jobs=4):
        """Keep the forecasting settings and the LightGBM hyperparameters.

        The hyperparameters are only stored here, each under the attribute of the same name;
        they are handed to the actual regressors in `init_clf`.
        """
        super(DkuTimeSeriesLightGBMEstimator, self).__init__(
            frequency=frequency,
            time_variable=time_variable,
            prediction_length=prediction_length,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            monthly_day_alignment=monthly_day_alignment,
        )

        # Execution settings
        self.n_jobs = n_jobs
        self.random_state = random_state

        # Boosting and tree structure
        self.boosting_type = boosting_type
        self.objective = objective
        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.num_leaves = num_leaves
        self.max_depth = max_depth
        self.min_split_gain = min_split_gain
        self.min_child_weight = min_child_weight
        self.min_child_samples = min_child_samples

        # Sampling
        self.subsample_for_bin = subsample_for_bin
        self.subsample = subsample
        self.subsample_freq = subsample_freq
        self.colsample_bytree = colsample_bytree

        # Regularization
        self.reg_alpha = reg_alpha
        self.reg_lambda = reg_lambda

    def init_clf(self):
        """Create a new ``LGBMRegressor`` configured with the hyperparameters stored on this estimator.

        :return: an unfitted ``LGBMRegressor``
        """
        from lightgbm import LGBMRegressor

        hyperparameters = dict(
            n_jobs=self.n_jobs,
            random_state=self.random_state,
            max_depth=self.max_depth,
            # first_metric_only only matters when early stopping is involved. With True,
            # LightGBM evaluates with the metric we provide only, instead of also using
            # the default metric of the prediction task.
            first_metric_only=True,
            importance_type="gain",
            boosting_type=self.boosting_type,
            num_leaves=self.num_leaves,
            learning_rate=self.learning_rate,
            n_estimators=self.n_estimators,
            subsample_for_bin=self.subsample_for_bin,
            objective=self.objective,
            min_split_gain=self.min_split_gain,
            min_child_weight=self.min_child_weight,
            min_child_samples=self.min_child_samples,
            subsample=self.subsample,
            subsample_freq=self.subsample_freq,
            colsample_bytree=self.colsample_bytree,
            reg_alpha=self.reg_alpha,
            reg_lambda=self.reg_lambda,
        )

        regressor = LGBMRegressor()
        regressor.set_params(**hyperparameters)
        return regressor

class DkuTimeSeriesRandomForestEstimator(DkuTimeSeriesMLEstimator):
    """Time series estimator whose per-horizon models are scikit-learn ``RandomForestRegressor`` instances."""

    def __init__(self,
                 frequency,
                 time_variable,
                 prediction_length,
                 target_variable,
                 timeseries_identifiers,
                 monthly_day_alignment=None,
                 min_samples_leaf=None,
                 n_estimators=None,
                 max_features=None,
                 max_depth=None,
                 min_samples_split=None,
                 random_state=1337,
                 n_jobs=4):
        """Keep the forecasting settings and the random forest hyperparameters.

        The hyperparameters are only stored here, each under the attribute of the same name;
        they are handed to the actual regressors in `init_clf`.
        """
        super(DkuTimeSeriesRandomForestEstimator, self).__init__(
            frequency=frequency,
            time_variable=time_variable,
            prediction_length=prediction_length,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            monthly_day_alignment=monthly_day_alignment,
        )

        # NaN inputs are only declared as supported when the installed scikit-learn is at least 1.4
        self.supports_nan = package_is_at_least(sklearn, "1.4")

        # Execution settings
        self.n_jobs = n_jobs
        self.random_state = random_state

        # Forest hyperparameters
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.max_features = max_features
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf

    def init_clf(self):
        """Create a new ``RandomForestRegressor`` configured with the hyperparameters stored on this estimator.

        :return: an unfitted ``RandomForestRegressor``
        """
        hyperparameters = dict(
            min_samples_leaf=self.min_samples_leaf,
            n_estimators=self.n_estimators,
            max_features=self.max_features,
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            random_state=self.random_state,
            n_jobs=self.n_jobs,
        )

        forest = RandomForestRegressor()
        forest.set_params(**hyperparameters)
        return forest


class DkuTimeSeriesXGBoostEstimator(DkuTimeSeriesMLEstimator):
    """Time series estimator whose per-horizon models are ``DkuXGBRegressor`` instances."""

    def __init__(self,
                 frequency,
                 time_variable,
                 prediction_length,
                 target_variable,
                 timeseries_identifiers,
                 monthly_day_alignment=None,
                 n_estimators=None,
                 scale_pos_weight=None,
                 base_score=None,
                 missing=None,
                 tree_method=None,
                 tweedie_variance_power=None,
                 random_state=1337,
                 n_jobs=4,
                 max_depth=None,
                 learning_rate=None,
                 gamma=None,
                 min_child_weight=None,
                 max_delta_step=None,
                 subsample=None,
                 colsample_bytree=None,
                 colsample_bylevel=None,
                 reg_alpha=None,
                 reg_lambda=None,
                 booster=None,
                 objective=None):
        """Keep the forecasting settings and the XGBoost hyperparameters.

        The hyperparameters are only stored here, each under the attribute of the same name;
        they are handed to the actual regressors in `init_clf`.
        """
        super(DkuTimeSeriesXGBoostEstimator, self).__init__(
            frequency=frequency,
            time_variable=time_variable,
            prediction_length=prediction_length,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            monthly_day_alignment=monthly_day_alignment,
        )

        # Rows containing NaN are declared as supported for this estimator
        self.supports_nan = True

        # Execution settings
        self.n_jobs = n_jobs
        self.random_state = random_state

        # General model settings
        self.booster = booster
        self.objective = objective
        self.tree_method = tree_method
        self.n_estimators = n_estimators
        self.base_score = base_score
        self.missing = missing
        self.scale_pos_weight = scale_pos_weight
        self.tweedie_variance_power = tweedie_variance_power

        # Tree growth and learning
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.min_child_weight = min_child_weight
        self.max_delta_step = max_delta_step

        # Sampling
        self.subsample = subsample
        self.colsample_bytree = colsample_bytree
        self.colsample_bylevel = colsample_bylevel

        # Regularization
        self.reg_alpha = reg_alpha
        self.reg_lambda = reg_lambda

    def init_clf(self):
        """Create a new ``DkuXGBRegressor`` configured with the hyperparameters stored on this estimator.

        :return: an unfitted ``DkuXGBRegressor``
        """
        # Imported here to avoid a top level import of xgboost
        from dataiku.doctor.prediction.dku_xgboost import DkuXGBRegressor

        hyperparameters = dict(
            n_estimators=self.n_estimators,
            n_jobs=self.n_jobs,
            scale_pos_weight=self.scale_pos_weight,
            base_score=self.base_score,
            random_state=self.random_state,
            missing=self.missing,
            tree_method=self.tree_method,
            tweedie_variance_power=self.tweedie_variance_power,
            max_depth=self.max_depth,
            learning_rate=self.learning_rate,
            gamma=self.gamma,
            min_child_weight=self.min_child_weight,
            max_delta_step=self.max_delta_step,
            subsample=self.subsample,
            colsample_bytree=self.colsample_bytree,
            colsample_bylevel=self.colsample_bylevel,
            reg_alpha=self.reg_alpha,
            reg_lambda=self.reg_lambda,
            booster=self.booster,
            objective=self.objective,
        )

        regressor = DkuXGBRegressor()
        regressor.set_params(**hyperparameters)
        return regressor


class DkuTimeSeriesRidgeEstimator(DkuTimeSeriesMLEstimator):
    """Time series estimator whose per-horizon models are scikit-learn ridge regressions."""

    def __init__(self,
                 frequency,
                 time_variable,
                 prediction_length,
                 target_variable,
                 timeseries_identifiers,
                 monthly_day_alignment=None,
                 alpha_mode=None,
                 alpha=None):
        """Keep the forecasting settings and the ridge regularization settings.

        ``alpha_mode`` selects the kind of model built by `init_clf` ("AUTO" or "MANUAL");
        ``alpha`` is only read in "MANUAL" mode.
        """
        super(DkuTimeSeriesRidgeEstimator, self).__init__(
            frequency=frequency,
            time_variable=time_variable,
            prediction_length=prediction_length,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            monthly_day_alignment=monthly_day_alignment,
        )

        # Rows containing NaN are not sent to this estimator
        self.supports_nan = False

        self.alpha = alpha
        self.alpha_mode = alpha_mode

    def init_clf(self):
        """Create the ridge model matching ``self.alpha_mode``.

        :return: an unfitted ``RidgeCV`` in "AUTO" mode, or an unfitted ``Ridge`` using
            ``self.alpha`` in "MANUAL" mode
        :raises ValueError: if ``self.alpha_mode`` is neither "AUTO" nor "MANUAL"
        """
        from sklearn.linear_model import Ridge, RidgeCV

        mode = self.alpha_mode

        if mode == "AUTO":
            return RidgeCV(fit_intercept=True)

        if mode != "MANUAL":
            raise ValueError("Invalid alpha_mode")

        ridge = Ridge(fit_intercept=True, random_state=1337)
        ridge.set_params(alpha=self.alpha)
        return ridge