import torch
import pytorch_lightning as pl
from gluonts.torch import DeepAREstimator

from dataiku.doctor.timeseries.models.gluonts.torch.torch_base_estimator import DkuGluonTSTorchDeepLearningEstimator
from dataiku.doctor.timeseries.utils import FULL_TIMESERIES_DF_IDENTIFIER
from dataiku.doctor.utils.gpu_execution import get_default_gpu_config, GluonTSTorchGPUCapability


class DkuDeepAREstimator(DkuGluonTSTorchDeepLearningEstimator):
    """DSS wrapper around the GluonTS (PyTorch) ``DeepAREstimator``.

    Hyperparameters are stored on the instance and translated into a
    ``gluonts.torch.DeepAREstimator`` by ``_get_estimator``. Parameters shared by
    the GluonTS torch models (context length, learning rate, GPU config, seed,
    distribution output, weight decay, ...) are forwarded to the parent class.
    """

    def __init__(
            self,
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            use_timeseries_identifiers_as_features,
            full_context,
            batch_size,
            epochs,
            auto_num_batches_per_epoch,
            num_batches_per_epoch,
            scaling,
            num_parallel_samples,
            weight_decay,
            patience,
            seed,
            learning_rate=.001,
            context_length=1,
            num_layers=2,
            num_cells=40,
            dropout_rate=0.1,
            distr_output="StudentTOutput",
            gpu_config=None,
            monthly_day_alignment=None,
    ):
        """
        :param scaling: forwarded to ``DeepAREstimator(scaling=...)``.
        :param num_parallel_samples: forwarded to ``DeepAREstimator(num_parallel_samples=...)``.
        :param patience: forwarded to ``DeepAREstimator(patience=...)``.
        :param num_layers: forwarded to ``DeepAREstimator(num_layers=...)``.
        :param num_cells: forwarded to ``DeepAREstimator`` as ``hidden_size``.
        :param dropout_rate: forwarded to ``DeepAREstimator(dropout_rate=...)``.
        :param distr_output: name of the output distribution, resolved through
            ``DkuGluonTSTorchDeepLearningEstimator.get_distr_output_class``.
        :param gpu_config: GPU configuration passed to the parent class. When
            ``None`` (the default), ``get_default_gpu_config()`` is called here.
        All remaining parameters are passed unchanged to the parent class.
        """
        if gpu_config is None:
            # Resolve the default per instance. A call in the signature would be
            # evaluated once at import time, and that single object would be
            # shared by every estimator created without an explicit config.
            gpu_config = get_default_gpu_config()

        super(DkuDeepAREstimator, self).__init__(
            frequency=frequency,
            prediction_length=prediction_length,
            time_variable=time_variable,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            use_timeseries_identifiers_as_features=use_timeseries_identifiers_as_features,
            context_length=context_length,
            full_context=full_context,
            batch_size=batch_size,
            epochs=epochs,
            auto_num_batches_per_epoch=auto_num_batches_per_epoch,
            num_batches_per_epoch=num_batches_per_epoch,
            learning_rate=learning_rate,
            gpu_config=gpu_config,
            seed=seed,
            distr_output=distr_output,
            weight_decay=weight_decay,
            monthly_day_alignment=monthly_day_alignment,
        )

        self.scaling = scaling
        self.num_parallel_samples = num_parallel_samples
        self.patience = patience

        # Searchable parameters
        # Learning rate & context length are in the parent class
        self.num_layers = num_layers
        self.num_cells = num_cells
        self.dropout_rate = dropout_rate

    def initialize(self, core_params, modeling_params):
        """Initialize from DSS params, then apply the DeepAR-specific settings.

        Reads ``modeling_params["gluonts_torch_deepar_timeseries_params"]``.
        ``use_timeseries_identifiers_as_features`` defaults to False when absent.
        ``seed`` is required; a missing key raises ``KeyError``.
        """
        super(DkuDeepAREstimator, self).initialize(core_params, modeling_params)
        algo_params = modeling_params["gluonts_torch_deepar_timeseries_params"]
        self.use_timeseries_identifiers_as_features = algo_params.get("use_timeseries_identifiers_as_features", False)
        self.seed = algo_params["seed"]

    def _get_num_feat_dynamic_real(self):
        """Return the number of external input features, or 0 when there are none."""
        if not self.external_features:
            return 0
        full_df_features = self.external_features[FULL_TIMESERIES_DF_IDENTIFIER]
        if not full_df_features:
            return 0
        return len(full_df_features["INPUT"])

    def _get_estimator(self, train_data, identifier_cardinalities):
        """Build a configured (not yet trained) GluonTS torch ``DeepAREstimator``.

        Side effects:
        - When ``auto_num_batches_per_epoch`` is set, ``self.num_batches_per_epoch``
          is overwritten with a value computed from ``train_data``.
        - The global torch / Lightning RNGs are seeded with ``self.seed``.

        :param train_data: training data; here it is only used to compute the
            number of batches per epoch in auto mode.
        :param identifier_cardinalities: currently unused. Static categorical
            features are not wired into DeepAR (see the commented-out arguments).
        :return: a ``gluonts.torch.DeepAREstimator`` instance.
        """
        time_parameters = self.get_time_based_parameters(parameters_to_set=["lags_seq", "time_features"])
        if self.auto_num_batches_per_epoch:
            self.num_batches_per_epoch = self._compute_auto_num_batches_per_epoch(train_data)

        # Seed before the estimator/network is created so weight initialisation
        # is reproducible. seed_everything also seeds torch; the explicit
        # torch.manual_seed call is kept to preserve existing behaviour.
        torch.manual_seed(self.seed)
        pl.seed_everything(self.seed, workers=True)

        return DeepAREstimator(
            freq=self.frequency,
            prediction_length=self.prediction_length,
            context_length=self.context_length,
            num_layers=self.num_layers,
            hidden_size=self.num_cells,
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
            dropout_rate=self.dropout_rate,
            patience=self.patience,
            num_feat_dynamic_real=self._get_num_feat_dynamic_real(),
            # num_feat_static_real - available in lib but not in use in dss
            # num_feat_static_cat=len(identifier_cardinalities) if identifier_cardinalities else 0,
            # cardinality=identifier_cardinalities,
            # embedding_dimension - available in lib but not in use in dss
            distr_output=DkuGluonTSTorchDeepLearningEstimator.get_distr_output_class(self.distr_output),
            scaling=self.scaling,
            # default_scale - available in lib but not in use in dss
            # A missing key yields None, which lets GluonTS fall back to its own
            # frequency-based defaults for lags / time features.
            lags_seq=time_parameters.get("lags_seq"),
            time_features=time_parameters.get("time_features"),
            num_parallel_samples=self.num_parallel_samples,
            batch_size=self.batch_size,
            num_batches_per_epoch=self.num_batches_per_epoch,
            trainer_kwargs={
                "max_epochs": self.epochs,
                "accelerator": self.accelerator,
                "devices": self.devices,
            },
        )
