import torch
import pytorch_lightning as pl

from gluonts.torch import SimpleFeedForwardEstimator

from dataiku.doctor.timeseries.models.gluonts.torch.torch_base_estimator import DkuGluonTSTorchDeepLearningEstimator
from dataiku.doctor.utils.gpu_execution import get_default_gpu_config


class DkuSimpleFeedForwardEstimator(DkuGluonTSTorchDeepLearningEstimator):
    """Dataiku wrapper around the GluonTS torch ``SimpleFeedForwardEstimator``.

    Most constructor arguments are handed over unchanged to the parent class;
    only ``num_hidden_dimensions`` and ``batch_normalization`` are stored by
    this class itself. The actual GluonTS estimator is created lazily by
    ``_get_estimator``.
    """

    def __init__(
            self,
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            full_context,
            batch_size,
            epochs,
            auto_num_batches_per_epoch,
            num_batches_per_epoch,
            num_hidden_dimensions,
            weight_decay,
            seed,
            learning_rate=.001,
            context_length=1,
            distr_output="StudentTOutput",
            batch_normalization=False,
            gpu_config=get_default_gpu_config(),
            monthly_day_alignment=None,
    ):
        """Store the algorithm-specific settings and delegate the rest to the parent.

        ``num_hidden_dimensions`` and ``batch_normalization`` are kept on the
        instance; every other argument is forwarded by keyword to the parent
        constructor.
        """
        # NOTE(review): the `gpu_config` default is computed once, when this
        # method is defined, so all calls relying on the default receive the
        # same object -- confirm that nothing downstream mutates it.
        parent_arguments = dict(
            frequency=frequency,
            prediction_length=prediction_length,
            time_variable=time_variable,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            context_length=context_length,
            full_context=full_context,
            batch_size=batch_size,
            epochs=epochs,
            auto_num_batches_per_epoch=auto_num_batches_per_epoch,
            num_batches_per_epoch=num_batches_per_epoch,
            learning_rate=learning_rate,
            gpu_config=gpu_config,
            seed=seed,
            weight_decay=weight_decay,
            distr_output=distr_output,
            monthly_day_alignment=monthly_day_alignment,
        )
        super(DkuSimpleFeedForwardEstimator, self).__init__(**parent_arguments)

        # Passed as `hidden_dimensions` to the GluonTS estimator.
        self.num_hidden_dimensions = num_hidden_dimensions

        # Searchable parameter owned by this class. The other searchable ones
        # (learning rate, context length, weight_decay and distr_output) are
        # in the parent class.
        self.batch_normalization = batch_normalization

    def initialize(self, core_params, modeling_params):
        """Run the parent initialization, then read this algorithm's own settings.

        The settings are looked up under the
        ``"gluonts_torch_simple_feedforward_timeseries_params"`` key of
        ``modeling_params``. ``"seed"`` is mandatory there and replaces the
        current ``self.seed``; ``"use_timeseries_identifiers_as_features"`` is
        optional and falls back to ``False``.
        """
        super(DkuSimpleFeedForwardEstimator, self).initialize(core_params, modeling_params)
        feedforward_settings = modeling_params["gluonts_torch_simple_feedforward_timeseries_params"]
        self.use_timeseries_identifiers_as_features = feedforward_settings.get(
            "use_timeseries_identifiers_as_features", False
        )
        self.seed = feedforward_settings["seed"]

    def _get_estimator(self, train_data, identifier_cardinalities):
        """Seed the random generators and build the GluonTS estimator.

        When ``auto_num_batches_per_epoch`` is truthy, ``num_batches_per_epoch``
        is first recomputed from ``train_data`` and stored on the instance.
        ``identifier_cardinalities`` is accepted but not used here.

        Returns:
            A ``SimpleFeedForwardEstimator`` configured from the instance attributes.
        """
        if self.auto_num_batches_per_epoch:
            self.num_batches_per_epoch = self._compute_auto_num_batches_per_epoch(train_data)

        # Seeding happens before the estimator is instantiated.
        torch.manual_seed(self.seed)
        pl.seed_everything(self.seed, workers=True)

        output_distribution = self.get_distr_output_class(self.distr_output)
        lightning_trainer_options = {
            "max_epochs": self.epochs,
            "accelerator": self.accelerator,
            "devices": self.devices,
        }

        estimator = SimpleFeedForwardEstimator(
            prediction_length=self.prediction_length,
            context_length=self.context_length,
            hidden_dimensions=self.num_hidden_dimensions,
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
            distr_output=output_distribution,
            batch_norm=self.batch_normalization,
            batch_size=self.batch_size,
            num_batches_per_epoch=self.num_batches_per_epoch,
            trainer_kwargs=lightning_trainer_options,
        )
        return estimator

