from dataiku.doctor.timeseries.models.neuralforecast.base_estimator import DKUNeuralforecastEstimator
from dataiku.doctor.utils.gpu_execution import get_default_gpu_config


class DKUTFTEstimator(DKUNeuralforecastEstimator):
    """Estimator that builds a neuralforecast ``TFT`` model.

    The constructor forwards the generic forecasting settings to
    ``DKUNeuralforecastEstimator`` and keeps the TFT-specific hyperparameters
    as instance attributes; ``get_model`` turns them into a ``TFT`` instance.
    """

    def __init__(
            self,
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            monthly_day_alignment=None,
            quantiles=None,
            random_state=1337,
            learning_rate=0.001,
            context_length=1,
            patience=-1,
            max_steps=1000,
            batch_size=32,
            hidden_size=128,
            hidden_size_factor=32,
            n_rnn_layers=1,
            n_head=4,
            # NOTE(review): this default is evaluated once, when the class body
            # is executed, so every instance relying on it receives the same
            # object -- confirm that it is never mutated in place.
            gpu_config=get_default_gpu_config()
    ):
        """Store the TFT hyperparameters and initialise the base estimator.

        ``frequency``, ``prediction_length``, ``time_variable``,
        ``target_variable``, ``timeseries_identifiers``,
        ``monthly_day_alignment``, ``quantiles``, ``random_state`` and
        ``gpu_config`` are passed unchanged to the parent constructor.

        The remaining arguments are kept as attributes of the same name.
        ``get_model`` hands them to ``TFT`` as follows:

        :param learning_rate: passed as ``learning_rate``.
        :param context_length: passed as ``input_size``.
        :param max_steps: passed as ``max_steps``.
        :param hidden_size: passed as ``hidden_size``.
        :param n_rnn_layers: passed as ``n_rnn_layers``.
        :param n_head: passed as ``n_head``.
        :param patience: stored only; not passed to ``TFT`` by this class.
        :param batch_size: stored only; not passed to ``TFT`` by this class.
        :param hidden_size_factor: stored only; not passed to ``TFT`` by this
            class.
        """
        super(DKUTFTEstimator, self).__init__(
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            monthly_day_alignment,
            quantiles,
            random_state,
            gpu_config=gpu_config
        )

        # Hyperparameters forwarded to TFT by get_model().
        self.learning_rate = learning_rate
        self.context_length = context_length
        self.max_steps = max_steps
        self.hidden_size = hidden_size
        self.n_rnn_layers = n_rnn_layers
        self.n_head = n_head

        # NOTE(review): these three are not referenced anywhere else in this
        # class; presumably they are consumed by the base class (for example
        # through get_trainer_kwargs()) or by external callers -- confirm.
        self.patience = patience
        self.batch_size = batch_size
        self.hidden_size_factor = hidden_size_factor

    @staticmethod
    def get_name():
        """Return the name of this algorithm, ``"TFT"``."""
        return "TFT"

    def get_model(self):
        """Build the neuralforecast ``TFT`` model from the stored settings.

        When ``self.external_features`` is truthy, the ``"INPUT_PAST_ONLY"``
        and ``"INPUT"`` entries returned by ``get_unique_external_features``
        are used as ``hist_exog_list`` and ``futr_exog_list`` respectively;
        otherwise both are ``None``. The entries of
        ``self.get_trainer_kwargs()`` are appended as extra keyword arguments.

        :return: a new ``neuralforecast.models.TFT`` instance.
        """
        # Imported here rather than at module level, so that importing this
        # module does not itself import neuralforecast.
        from neuralforecast.models import TFT

        if self.external_features:
            # Computed once and indexed twice, instead of recomputing the
            # same mapping for each of the two keyword arguments.
            unique_features = DKUTFTEstimator.get_unique_external_features(self.external_features)
            hist_exog_list = unique_features["INPUT_PAST_ONLY"]
            futr_exog_list = unique_features["INPUT"]
        else:
            hist_exog_list = None
            futr_exog_list = None

        return TFT(
            h=self.prediction_length,
            input_size=self.context_length,
            max_steps=self.max_steps,
            hidden_size=self.hidden_size,
            n_rnn_layers=self.n_rnn_layers,
            n_head=self.n_head,
            learning_rate=self.learning_rate,
            hist_exog_list=hist_exog_list,
            futr_exog_list=futr_exog_list,
            random_seed=self.random_state,
            **self.get_trainer_kwargs()
        )
