from dataiku.eda.types import Literal

from dataiku.eda.computations.timeseries.time_series_computation import TimeSeriesComputation
from dataiku.eda.exceptions import NotEnoughDataError, InvalidParams
from dataiku.vendor.arch.unitroot import ZivotAndrews


class UnitRootTestZA(TimeSeriesComputation):
    """Time series computation running a Zivot-Andrews unit root test.

    The test itself is delegated to ``ZivotAndrews`` from the vendored
    ``arch`` package; this class validates the parameters, checks that the
    series is long enough, and reshapes the outcome into a plain dict.

    Attributes set here:
        regression_mode: one of ``"CONSTANT_ONLY"``, ``"TREND_ONLY"`` or
            ``"CONSTANT_WITH_TREND"`` (anything else is rejected when the
            computation is applied).
        n_lags: number of lags handed to the test, or ``None``. ``None`` is
            forwarded unchanged as ``lags=None`` (per the arch documentation
            the lag length is then selected automatically using ``method``).
    """

    def __init__(self, series_column, time_column, regression_mode, n_lags):
        super().__init__(series_column, time_column)
        self.regression_mode = regression_mode
        self.n_lags = n_lags

    @staticmethod
    def get_type() -> Literal["unit_root_za"]:
        """Return the type identifier of this computation."""
        return "unit_root_za"

    def describe(self):
        """Return a one-line textual summary of the computation parameters."""
        template = "{}(series_column={}, time_column={}, regression_mode={}, n_lags={})"
        fields = (
            self.__class__.__name__,
            self.series_column,
            self.time_column,
            self.regression_mode,
            self.n_lags,
        )
        return template.format(*fields)

    @staticmethod
    def build(params):
        """Create an instance from a parameter mapping.

        ``seriesColumn``, ``timeColumn`` and ``regressionMode`` are required
        keys (a missing one raises ``KeyError``); ``nLags`` is optional and
        defaults to ``None``.
        """
        series_column = params["seriesColumn"]
        time_column = params["timeColumn"]
        regression_mode = params["regressionMode"]
        n_lags = params.get("nLags")

        return UnitRootTestZA(
            series_column=series_column,
            time_column=time_column,
            regression_mode=regression_mode,
            n_lags=n_lags,
        )

    def _get_regression_param(self):
        """Translate ``regression_mode`` into test settings.

        Returns:
            A ``(trend, min_series_size)`` tuple: the trend code passed to
            ``ZivotAndrews`` and the minimum number of values the series
            must contain for that mode.

        Raises:
            InvalidParams: if ``regression_mode`` is not a known mode.
        """
        # regression mode -> (trend code, minimum series size)
        settings_by_mode = {
            "CONSTANT_ONLY": ("c", 4),
            "TREND_ONLY": ("t", 4),
            "CONSTANT_WITH_TREND": ("ct", 6),
        }

        settings = settings_by_mode.get(self.regression_mode)
        if settings is None:
            raise InvalidParams("Unknown regression mode")

        return settings

    def apply(self, idf, ctx):
        """Run the test on the series extracted from ``idf``.

        Returns:
            A dict with the keys ``type``, ``statistic``, ``pValue``,
            ``usedLag``, ``nObservations`` and ``criticalValues`` (the latter
            holding the ``"1%"``, ``"5%"`` and ``"10%"`` levels).

        Raises:
            InvalidParams: unknown regression mode, or ``n_lags`` is set but
                not in the range ``0 < n_lags < series size``.
            NotEnoughDataError: the series is shorter than the minimum size
                required by the regression mode.
        """
        # The regression mode is validated before touching the data.
        trend, required_size = self._get_regression_param()

        # Only the values are needed here; the timestamps are ignored.
        series, _ = self._get_time_series(idf)
        series_size = len(series)

        if series_size < required_size:
            message = (
                "At least {} values are required in the series for "
                "the requested regression mode (current size: {})"
            )
            raise NotEnoughDataError(message.format(required_size, series_size))

        lags = self.n_lags
        has_explicit_lags = lags is not None

        if has_explicit_lags and lags <= 0:
            raise InvalidParams("n_lags must be greater than 0")

        if has_explicit_lags and lags >= series_size:
            raise InvalidParams("n_lags must be lower than the series size")

        test = ZivotAndrews(series, trend=trend, method="AIC", lags=lags)

        return {
            "type": self.get_type(),
            "statistic": test.stat,
            "pValue": test.pvalue,
            "usedLag": test.lags,
            "nObservations": test.nobs,
            "criticalValues": {
                level: test.critical_values[level] for level in ("1%", "5%", "10%")
            },
        }
