# Keep on top, to handle missing libraries for MxNet (SC-98195)
from dataiku.doctor.timeseries.models.gluonts.base_estimator import DkuGluonTSEstimator

from dataiku.doctor.timeseries.utils.gluonts_compat import instantiate_identity_predictor


class DkuTrivialIdentityEstimator(DkuGluonTSEstimator):
    """Estimator whose predictor is created by ``instantiate_identity_predictor``.

    The constructor adds no state of its own: every argument is handed over,
    by keyword, to ``DkuGluonTSEstimator``.
    """

    def __init__(self, frequency, prediction_length, time_variable, target_variable, timeseries_identifiers, monthly_day_alignment=None):
        """Forward all settings unchanged to the parent estimator."""
        parent_settings = dict(
            frequency=frequency,
            prediction_length=prediction_length,
            time_variable=time_variable,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            monthly_day_alignment=monthly_day_alignment,
        )
        super(DkuTrivialIdentityEstimator, self).__init__(**parent_settings)

    def fit(self, train_df, external_features=None, shift_map=None):
        """Create the identity predictor and store it on ``self.predictor``.

        ``train_df``, ``external_features`` and ``shift_map`` are accepted but
        not read by this method: the predictor is built only from
        ``self.frequency`` and ``self.prediction_length``, with ``num_samples``
        fixed to 1.

        :return: this estimator instance, so calls can be chained.
        """
        # NOTE(review): frequency / prediction_length are presumably set by the
        # parent constructor -- confirm in DkuGluonTSEstimator.
        predictor_settings = dict(
            frequency=self.frequency,
            prediction_length=self.prediction_length,
            num_samples=1,
        )
        self.predictor = instantiate_identity_predictor(**predictor_settings)
        return self
