import numpy as np
import logging
from statsmodels.tsa.api import STLForecast
from statsmodels.tsa.exponential_smoothing.ets import ETSModel

from dataiku.doctor.timeseries.models.statistical.base_estimator import DkuStatisticalEstimator, \
    INFORMATION_CRITERION_TO_DISPLAY_NAME
from dataiku.doctor.timeseries.models.statistical.stats import sanitized_stats_value, build_coefficient_dict

# TODO @timeseries check from statsmodels.tsa.statespace.exponential_smoothing import ExponentialSmoothing

logger = logging.getLogger(__name__)


class DkuSeasonalTrendLoessEstimator(DkuStatisticalEstimator):
    """Seasonal-Trend decomposition (STL / Loess) forecaster built on statsmodels ``STLForecast``.

    Each time series is fitted with ``STLForecast`` using ``ETSModel`` (additive trend) as the
    underlying forecasting model. The STL smoother settings (``period``, ``seasonal``, ``trend``,
    ``low_pass`` and their ``*_deg`` / ``*_jump`` companions) are searchable hyperparameters.
    """

    def __init__(
        self,
        frequency,
        time_variable,
        prediction_length,
        target_variable,
        timeseries_identifier_columns,
        auto_trend,
        auto_low_pass,
        period=2,
        seasonal=7,
        trend=3,
        low_pass=3,
        seasonal_deg=1,
        trend_deg=1,
        low_pass_deg=1,
        seasonal_jump=1,
        trend_jump=1,
        low_pass_jump=1,
        monthly_day_alignment=None,
    ):
        """
        :param auto_trend: when truthy, ``set_params`` resets ``trend`` to None so that
            ``STLForecast`` picks the trend smoother length itself.
        :param auto_low_pass: when truthy, ``set_params`` resets ``low_pass`` to None so that
            ``STLForecast`` picks the low-pass smoother length itself.
        The remaining STL keyword arguments are forwarded as-is to ``STLForecast`` in ``_fit_single``.
        """
        super(DkuSeasonalTrendLoessEstimator, self).__init__(
            frequency=frequency,
            time_variable=time_variable,
            prediction_length=prediction_length,
            target_variable=target_variable,
            timeseries_identifier_columns=timeseries_identifier_columns,
            monthly_day_alignment=monthly_day_alignment,
        )

        self.auto_trend = auto_trend
        self.auto_low_pass = auto_low_pass

        # searchable parameters
        self.period = period
        self.seasonal = seasonal
        self.trend = trend
        self.low_pass = low_pass
        self.seasonal_deg = seasonal_deg
        self.trend_deg = trend_deg
        self.low_pass_deg = low_pass_deg
        self.seasonal_jump = seasonal_jump
        self.trend_jump = trend_jump
        self.low_pass_jump = low_pass_jump

    def initialize(self, core_params, modeling_params):
        """Load the STL hyperparameters from ``modeling_params["seasonal_loess_timeseries_params"]``.

        NOTE(review): unlike ``set_params``, this does not honour ``auto_trend`` / ``auto_low_pass``;
        presumably ``set_params`` is always called before fitting — confirm.
        """
        super(DkuSeasonalTrendLoessEstimator, self).initialize(core_params, modeling_params)
        algo_params = modeling_params["seasonal_loess_timeseries_params"]

        self.period = algo_params["period"]
        self.seasonal = algo_params["seasonal"]
        self.trend = algo_params["trend"]
        self.low_pass = algo_params["low_pass"]
        self.seasonal_deg = algo_params["seasonal_deg"]
        self.trend_deg = algo_params["trend_deg"]
        self.low_pass_deg = algo_params["low_pass_deg"]
        self.seasonal_jump = algo_params["seasonal_jump"]
        self.trend_jump = algo_params["trend_jump"]
        self.low_pass_jump = algo_params["low_pass_jump"]

    def set_params(self, **params):
        """Set parameters via the base class, then clear the smoother lengths marked as automatic.

        :return: self (fluent API)
        """
        super(DkuSeasonalTrendLoessEstimator, self).set_params(**params)

        # None lets STLForecast choose the smoother length instead of the searched value
        if self.auto_trend:
            self.trend = None
        if self.auto_low_pass:
            self.low_pass = None

        return self

    def _fit_single(self, target_values, date_values=None, external_features_values=None):
        """Fit one time series.

        ``date_values`` and ``external_features_values`` are accepted for interface compatibility
        but are not used by this model.

        :return: the fitted ``STLForecast`` results object
        """
        trained_model = STLForecast(
            endog=target_values,
            model=ETSModel,
            model_kwargs={"trend": "add"},
            period=self.period,
            seasonal=self.seasonal,
            trend=self.trend,
            low_pass=self.low_pass,
            # degrees are cast to int because the stored hyperparameter values may not be ints
            seasonal_deg=int(self.seasonal_deg),
            trend_deg=int(self.trend_deg),
            low_pass_deg=int(self.low_pass_deg),
            seasonal_jump=self.seasonal_jump,
            trend_jump=self.trend_jump,
            low_pass_jump=self.low_pass_jump,
        ).fit()

        return trained_model

    def _forecast_single_timeseries(
        self,
        trained_model,
        past_target_values,
        past_date_values,
        quantiles,
        past_external_features_values,
        future_external_features_values,
        fit_before_predict,
        prediction_length
    ):
        """Forecast ``prediction_length`` steps right after the end of the fitted series.

        When ``fit_before_predict`` is truthy, a fresh model with the same hyperparameters is fitted
        on ``past_target_values`` first; otherwise ``trained_model`` is reused as-is, and a warning is
        logged if it was trained on a different target.
        External features and dates are not used by this model.

        :return: the forecasts dict built by ``_build_forecasts_dict`` for the requested quantiles
        """
        if fit_before_predict:
            # instantiate and fit an STL model with the same hyperparameters as the trained model used during training
            trained_model = self._fit_single(past_target_values)
        elif not np.array_equal(trained_model._endog, past_target_values):
            logger.warning(
                "Predicting a Seasonal Trend model with a different target than the one used during training"
            )

        # out-of-sample window: the first step after the last observation, inclusive end
        prediction_results = trained_model.get_prediction(
            start=trained_model.model.nobs,
            end=trained_model.model.nobs + prediction_length - 1,
        )

        return self._build_forecasts_dict(prediction_results, quantiles)

    def get_coefficients_map_and_names(self):
        """Collect the ETS smoothing / initial-state coefficients of every trained series.

        :return: tuple ``(coefficients_map, coefficient_names, None, None)``.
            ``coefficients_map[coeff_name]`` is a dict created by ``build_coefficient_dict`` whose
            ``"values"``, ``"pvalues"``, ``"tvalues"`` and ``"stderrs"`` entries are filled per
            timeseries identifier. Statistics that cannot be retrieved are skipped with a warning.
        """
        coefficients_map = {}
        fixed_coefficients = ["smoothing_level", "smoothing_trend", "initial_level", "initial_trend"]
        for timeseries_identifier, trained_model in self.trained_models.items():
            model_result = trained_model.model_result
            for coeff_name in fixed_coefficients:
                if coeff_name not in coefficients_map:
                    coefficients_map[coeff_name] = build_coefficient_dict()
                coefficient = coefficients_map[coeff_name]
                coefficient["values"][timeseries_identifier] = getattr(model_result, coeff_name)
                try:
                    param_position = model_result.param_names.index(coeff_name)
                    # Index through np.asarray so the lookup is positional whether the statistics are
                    # numpy arrays or label-indexed pandas Series (an int key on such a Series is a
                    # deprecated positional fallback in pandas).
                    coefficient["pvalues"][timeseries_identifier] = sanitized_stats_value(
                        np.asarray(model_result.pvalues)[param_position]
                    )
                    coefficient["tvalues"][timeseries_identifier] = sanitized_stats_value(
                        np.asarray(model_result.tvalues)[param_position]
                    )
                    coefficient["stderrs"][timeseries_identifier] = sanitized_stats_value(
                        np.asarray(model_result.bse)[param_position]
                    )
                except (ValueError, KeyError, IndexError) as exc:
                    # best-effort: a missing statistic must not prevent reporting the others
                    logger.warning(
                        "failed to retrieve statistical value for coefficient %s of timeseries %s: %s",
                        coeff_name,
                        timeseries_identifier,
                        exc,
                    )

        return coefficients_map, fixed_coefficients, None, None

    def get_information_criteria(self):
        """Return AIC, BIC, HQIC and log-likelihood of each trained series' inner model result.

        :return: list of ``{"values": {identifier: value}, "displayName": str}`` dicts, one per criterion
        """
        information_criteria = []
        for criterion_name in ["aic", "bic", "hqic", "llf"]:
            criterion = {"values": {}, "displayName": INFORMATION_CRITERION_TO_DISPLAY_NAME[criterion_name]}
            for timeseries_identifier, trained_model in self.trained_models.items():
                item = getattr(trained_model.model_result, criterion_name)
                criterion["values"][timeseries_identifier] = self.prepare_information_criteria(item)
            information_criteria.append(criterion)
        return information_criteria

    def get_fitted_values_and_residuals(self, identifier, df_of_identifier, min_scoring_size):
        """
        For SeasonalTrend models, fitted_values and residuals are computed/stored at train time.

        ``df_of_identifier`` and ``min_scoring_size`` are not used.

        NOTE(review): these come from ``model_result``, the ETS model fitted inside STLForecast;
        it presumably models the seasonally-adjusted series, so the fitted values may not include
        the seasonal component — confirm this is intended.
        """
        model_result = self.trained_models[identifier].model_result
        return model_result.fittedvalues, model_result.resid