import warnings

from statsmodels.tools.sm_exceptions import InterpolationWarning
from statsmodels.tsa.stattools import kpss
from dataiku.eda.types import Literal

from dataiku.eda.computations.timeseries.time_series_computation import TimeSeriesComputation
from dataiku.eda.exceptions import NotEnoughDataError, InvalidParams


class UnitRootTestKPSS(TimeSeriesComputation):
    """KPSS unit-root test on a time series, computed with statsmodels' ``kpss``.

    ``regression_mode`` selects the deterministic terms passed to ``kpss``
    ("CONSTANT" -> "c", "CONSTANT_WITH_TREND" -> "ct"). ``n_lags`` fixes the
    number of lags; ``None`` delegates the choice to statsmodels ("auto").
    """

    def __init__(self, series_column, time_column, regression_mode, n_lags):
        super().__init__(series_column, time_column)
        self.regression_mode = regression_mode
        self.n_lags = n_lags

    @staticmethod
    def get_type() -> Literal["unit_root_kpss"]:
        return "unit_root_kpss"

    def describe(self):
        return "{}(series_column={}, time_column={}, regression_mode={}, n_lags={})".format(
            self.__class__.__name__,
            self.series_column,
            self.time_column,
            self.regression_mode,
            self.n_lags,
        )

    @staticmethod
    def build(params):
        """Build the computation from its serialized parameters ("nLags" is optional)."""
        return UnitRootTestKPSS(
            params["seriesColumn"],
            params["timeColumn"],
            params["regressionMode"],
            params.get("nLags"),
        )

    def _get_regression_param(self):
        """Map ``regression_mode`` to the ``regression`` argument of ``kpss``.

        Raises:
            InvalidParams: if the regression mode is not supported.
        """
        params = {
            "CONSTANT": "c",
            "CONSTANT_WITH_TREND": "ct",
        }

        if self.regression_mode not in params:
            raise InvalidParams("Unknown regression mode")

        return params[self.regression_mode]

    def apply(self, idf, ctx):
        """Run the KPSS test on the series extracted from ``idf``.

        Returns:
            A dict with the test statistic, p-value, lag actually used, number
            of observations, critical values (1%, 2.5%, 5%, 10%) and the
            messages of any ``InterpolationWarning`` emitted by statsmodels.

        Raises:
            InvalidParams: unknown regression mode, or ``n_lags`` outside
                ``[1, n_obs)``.
            NotEnoughDataError: the series has fewer than 2 values.
        """
        regression = self._get_regression_param()
        series, _ = self._get_time_series(idf)
        n_obs = len(series)

        if n_obs < 2:
            raise NotEnoughDataError("At least 2 values are required in the series (current size: {})".format(n_obs))

        n_lags = self.n_lags
        if n_lags is None:
            n_lags = "auto"
        else:
            if n_lags <= 0:
                raise InvalidParams("n_lags must be greater than 0")

            if n_lags >= n_obs:
                raise InvalidParams("n_lags must be lower than the series size")

        with warnings.catch_warnings(record=True) as caught_warnings:
            # Force interpolation warnings to be recorded regardless of the
            # process-wide filters: "ignore" would lose the message we forward
            # to the user, "error" would make kpss raise instead of returning.
            # The filter change is undone when the context manager exits.
            warnings.simplefilter("always", InterpolationWarning)

            # statsmodels >= 0.11 in all Python 3 builtin envs since DSS 12
            # See https://github.com/statsmodels/statsmodels/commit/33cfbedcc63721ce8cac19a417342f3e644b902e
            statistic, p_value, used_lag, critical_values = kpss(series, regression=regression, nlags=n_lags)

        # Such warnings are raised when the actual p value is smaller or
        # greater than the displayed one. In this case, we just forward them
        # to the user. Other recorded warnings are discarded, as before.
        warning_messages = [
            str(caught_warning.message)
            for caught_warning in caught_warnings
            if issubclass(caught_warning.category, InterpolationWarning)
        ]

        return {
            "type": self.get_type(),
            "statistic": statistic,
            "pValue": p_value,
            "usedLag": used_lag,
            "nObservations": n_obs,
            "criticalValues": {
                "1%": critical_values["1%"],
                "2.5%": critical_values["2.5%"],
                "5%": critical_values["5%"],
                "10%": critical_values["10%"],
            },
            "warnings": warning_messages,
        }
