# Keep on top, to handle missing libraries for MxNet (SC-98195)
from dataiku.doctor.timeseries.models.gluonts.mxnet.mxnet_base_estimator import DkuGluonTSMXNetDeepLearningEstimator
from dataiku.doctor.timeseries.utils.gluonts_compat import instantiate_simple_feed_forward_estimator
from dataiku.doctor.utils.gpu_execution import get_default_gpu_config


class DkuSimpleFeedForwardEstimator(DkuGluonTSMXNetDeepLearningEstimator):
    """Simple feed-forward time series estimator built on the GluonTS MXNet base estimator.

    Shared training settings (frequency, horizon, batching, learning rate,
    GPU configuration, seed, ...) are forwarded to the parent class; this
    class only stores the parameters specific to the feed-forward model and
    knows how to build the underlying estimator from them.
    """

    def __init__(
        self,
        frequency,
        prediction_length,
        time_variable,
        target_variable,
        timeseries_identifiers,
        full_context,
        batch_size,
        epochs,
        auto_num_batches_per_epoch,
        num_batches_per_epoch,
        num_hidden_dimensions,
        num_parallel_samples,
        seed,
        learning_rate=.001,
        context_length=1,
        distr_output="StudentTOutput",
        batch_normalization=False,
        mean_scaling=True,
        # NOTE: this default is evaluated once, when the function is defined.
        gpu_config=get_default_gpu_config(),
        monthly_day_alignment=None,
    ):
        """Store the model-specific parameters and delegate the rest to the parent.

        Parameters forwarded unchanged to the parent constructor: frequency,
        prediction_length, time_variable, target_variable,
        timeseries_identifiers, context_length, full_context, batch_size,
        epochs, auto_num_batches_per_epoch, num_batches_per_epoch,
        learning_rate, gpu_config, seed and monthly_day_alignment.

        Parameters kept on this instance: num_hidden_dimensions,
        num_parallel_samples, distr_output, batch_normalization and
        mean_scaling.
        """
        parent_kwargs = dict(
            frequency=frequency,
            prediction_length=prediction_length,
            time_variable=time_variable,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            context_length=context_length,
            full_context=full_context,
            batch_size=batch_size,
            epochs=epochs,
            auto_num_batches_per_epoch=auto_num_batches_per_epoch,
            num_batches_per_epoch=num_batches_per_epoch,
            learning_rate=learning_rate,
            gpu_config=gpu_config,
            seed=seed,
            monthly_day_alignment=monthly_day_alignment,
        )
        super(DkuSimpleFeedForwardEstimator, self).__init__(**parent_kwargs)

        # Fixed (non-searchable) architecture / sampling settings.
        self.num_hidden_dimensions = num_hidden_dimensions
        self.num_parallel_samples = num_parallel_samples

        # Searchable parameters.
        # Learning rate & context length are also searchable, but they are
        # handled by the parent class.
        self.distr_output = distr_output
        self.batch_normalization = batch_normalization
        self.mean_scaling = mean_scaling

    def initialize(self, core_params, modeling_params):
        """Run the parent initialization, then read the feed-forward specific settings.

        Reads ``modeling_params["gluonts_simple_feedforward_timeseries_params"]``:
        ``use_timeseries_identifiers_as_features`` is optional (``False`` when
        absent) while ``seed`` is required (``KeyError`` when missing).
        """
        super(DkuSimpleFeedForwardEstimator, self).initialize(core_params, modeling_params)
        feed_forward_params = modeling_params["gluonts_simple_feedforward_timeseries_params"]
        self.use_timeseries_identifiers_as_features = feed_forward_params.get(
            "use_timeseries_identifiers_as_features", False
        )
        self.seed = feed_forward_params["seed"]

    def _get_estimator(self, train_data, identifier_cardinalities):
        """Build the simple feed-forward estimator for the given training data.

        The trainer is obtained from ``self.get_trainer(train_data)``;
        ``identifier_cardinalities`` is accepted but not used here.
        """
        # The dict literal is evaluated top to bottom: the trainer is created
        # first and the distribution output class is resolved last.
        estimator_kwargs = {
            "trainer": self.get_trainer(train_data),
            "frequency": self.frequency,
            "prediction_length": self.prediction_length,
            "batch_size": self.batch_size,
            "context_length": self.context_length,
            "num_hidden_dimensions": self.num_hidden_dimensions,
            "num_parallel_samples": self.num_parallel_samples,
            "batch_normalization": self.batch_normalization,
            "mean_scaling": self.mean_scaling,
            "distr_output": DkuGluonTSMXNetDeepLearningEstimator.get_distr_output_class(self.distr_output),
        }
        return instantiate_simple_feed_forward_estimator(**estimator_kwargs)
