from dataiku.doctor.timeseries.models.neuralforecast.base_estimator import DKUNeuralforecastEstimator
from dataiku.doctor.utils.gpu_execution import get_default_gpu_config


class DKUNHITSEstimator(DKUNeuralforecastEstimator):
    """Dataiku time series estimator backed by the neuralforecast ``NHITS`` model.

    Common time series settings are forwarded to ``DKUNeuralforecastEstimator``;
    NHITS-specific hyperparameters are stored on the instance and turned into an
    ``NHITS`` model by :meth:`get_model`.
    """

    def __init__(
            self,
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            monthly_day_alignment=None,
            quantiles=None,
            random_state=1337,
            learning_rate=0.001,
            context_length=1,
            patience=-1,
            max_steps=1000,
            batch_size=32,
            gpu_config=None
    ):
        """Store the estimator settings.

        :param frequency: forwarded to the base estimator.
        :param prediction_length: forwarded to the base estimator; used as the
            NHITS forecast horizon ``h`` in :meth:`get_model`.
        :param time_variable: forwarded to the base estimator.
        :param target_variable: forwarded to the base estimator.
        :param timeseries_identifiers: forwarded to the base estimator.
        :param monthly_day_alignment: forwarded to the base estimator.
        :param quantiles: forwarded to the base estimator.
        :param random_state: forwarded to the base estimator; used as the NHITS
            ``random_seed`` in :meth:`get_model`.
        :param learning_rate: NHITS ``learning_rate``.
        :param context_length: NHITS ``input_size`` (number of past steps the
            model looks at).
        :param patience: stored as ``self.patience``; not passed to NHITS by this
            class (see the note in :meth:`get_model`).
        :param max_steps: NHITS ``max_steps``.
        :param batch_size: NHITS ``batch_size``.
        :param gpu_config: GPU execution settings forwarded to the base estimator.
            When None (the default), ``get_default_gpu_config()`` is called here,
            so each instance gets a freshly built config instead of one object
            created at import time and shared by every instance.
        """
        # Resolve the default per call: a call in the signature would run only
        # once, when the class is defined, and share its result across instances.
        if gpu_config is None:
            gpu_config = get_default_gpu_config()

        super(DKUNHITSEstimator, self).__init__(
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            monthly_day_alignment,
            quantiles,
            random_state,
            gpu_config
        )

        self.learning_rate = learning_rate
        self.context_length = context_length
        self.patience = patience
        self.max_steps = max_steps
        self.batch_size = batch_size

    @staticmethod
    def get_name():
        """Return the display/identifier name of this model: ``"NHITS"``."""
        return "NHITS"

    def get_model(self):
        """Build an untrained neuralforecast ``NHITS`` model from this estimator's settings.

        :return: a ``neuralforecast.models.NHITS`` instance.
        """
        # Local import: neuralforecast is only loaded when get_model() runs.
        from neuralforecast.models import NHITS

        hist_exog_list = None
        futr_exog_list = None
        if self.external_features:
            # Compute the feature split once rather than once per exogenous list.
            unique_features = DKUNHITSEstimator.get_unique_external_features(self.external_features)
            # neuralforecast: hist_exog_list = covariates available for the history
            # only; futr_exog_list = covariates also provided over the horizon.
            hist_exog_list = unique_features["INPUT_PAST_ONLY"]
            futr_exog_list = unique_features["INPUT"]

        # Copy so the mapping returned by get_trainer_kwargs() is never mutated.
        extra_kwargs = dict(self.get_trainer_kwargs())
        # batch_size used to be stored but never forwarded, so user-set values were
        # ignored. setdefault avoids a duplicate-keyword TypeError if the base
        # trainer kwargs already provide it.
        extra_kwargs.setdefault("batch_size", self.batch_size)
        # NOTE(review): self.patience is still not forwarded (NHITS calls it
        # early_stop_patience_steps). Enabling early stopping in neuralforecast
        # presumably needs a validation split at fit time — confirm how the base
        # class fits the model before wiring it here.

        return NHITS(
            h=self.prediction_length,
            input_size=self.context_length,
            max_steps=self.max_steps,
            learning_rate=self.learning_rate,
            hist_exog_list=hist_exog_list,
            futr_exog_list=futr_exog_list,
            random_seed=self.random_state,
            **extra_kwargs
        )