/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import java.util.Collections;
import java.util.List;

/**
 * Algorithm metadata for the "NHITS" prediction algorithm.
 *
 * <p>The search space handled here has two grid-searched dimensions (learning rate and
 * context length, see {@link #getGridLength}) and several single-valued settings
 * (max steps, patience, batch size, random seed).
 */
public class NHITSMeta
extends PyMemoryAlgorithmMeta {
    /**
     * Returns the display name of the algorithm.
     *
     * @param usedToTrain unused
     * @return the constant string {@code "NHITS"}
     */
    @Override
    public String generateName(PreTrainPredictionModelingParams usedToTrain) {
        return "NHITS";
    }

    /**
     * Builds the pre-search description from the NHITS grid stored in the pre-train params.
     *
     * @param rpmp pre-train params; its {@code nhits_timeseries_grid} is read without a null check
     * @return a description carrying the search size and the NHITS hyperparameters
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.NHITSSpace grid = rpmp.nhits_timeseries_grid;
        int searchSize = this.getSearchSize(rpmp.grid_search_params, grid);
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(searchSize)
                .withMVParam("Learning rate", grid.learning_rate)
                .withSVParam("Max steps", grid.max_steps)
                .withSVParam("Patience steps", grid.patience)
                .withSVParam("Batch size", grid.batch_size)
                .withSVParam("Random seed", grid.random_state)
                .withMVParam("Context length", grid.context_length);
    }

    /**
     * Builds the post-search description from the NHITS parameters selected by training.
     *
     * @param descBefore unused
     * @param before unused
     * @param after post-train params; its {@code nhits_params} is read without a null check
     * @return a description listing each selected NHITS parameter as a single value
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PostTrainPredictionModelingParams.NHITSParams selected = after.nhits_params;
        return new ModelTrainInfo.PostSearchDescription()
                .withSVParam("Learning rate", Float.valueOf(selected.learning_rate))
                .withSVParam("Max steps", selected.max_steps)
                .withSVParam("Patience steps", selected.patience)
                .withSVParam("Batch size", selected.batch_size)
                .withSVParam("Random seed", selected.random_state)
                .withSVParam("Context length", selected.context_length);
    }

    /**
     * Runs the NHITS parameter checks. Does nothing when the NHITS space is absent or disabled.
     *
     * @param pmp modeling params holding the {@code nhits_timeseries} space
     * @param task unused
     * @param checks collector on which the individual checks are invoked
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.NHITSSpace space = pmp.nhits_timeseries;
        if (space == null || !space.enabled) {
            return;
        }
        checks.checkPositive(space.max_steps, "Maximum steps number must be positive (NHITS)");
        checks.checkPositive(space.batch_size, "Batch size must be positive (NHITS)");
        checks.checkNumericalDimension(space.learning_rate, "Learning rate (NHITS)");
        checks.checkNumericalDimension(space.context_length, "Context length (NHITS)");
    }

    /**
     * Expands the modeling params into the modeling sets to train for NHITS.
     *
     * @param pmp modeling params holding the {@code nhits_timeseries} space
     * @param task unused
     * @param gsFolds unused
     * @return an empty list when the NHITS space is absent or disabled, otherwise a
     *         single-element list wrapping pre-train params built for NHITS
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.NHITSSpace space = pmp.nhits_timeseries;
        if (space == null || !space.enabled) {
            return Collections.emptyList();
        }
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.NHITS, pmp);
        // The space object is shared with pmp, not copied.
        preTrainParams.nhits_timeseries_grid = space;
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, space);
        return Collections.singletonList(new WorkSet.ModelingSet(preTrainParams));
    }

    /**
     * Produces pre-train params whose grid-searched dimensions are pinned to the values
     * selected by training.
     *
     * @param optimized post-train params; its {@code nhits_params} supplies the selected values
     * @param usedToTrain pre-train params passed to {@code getCopyWithGridStrategy}
     * @return the params returned by {@code getCopyWithGridStrategy}, with learning rate and
     *         context length each set to a single-value grid
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PostTrainPredictionModelingParams.NHITSParams optimizedParameters = optimized.nhits_params;
        PreTrainPredictionModelingParams preTrainParams = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.NHITSSpace space = preTrainParams.nhits_timeseries_grid;
        space.learning_rate.setToSingleValueGrid(Double.valueOf(optimizedParameters.learning_rate));
        space.context_length.setToSingleValueGrid(Long.valueOf(optimizedParameters.context_length));
        // max_steps, batch_size and patience keep the values already present in the grid.
        // The previous code assigned each of these fields to itself (space is the very same
        // object as preTrainParams.nhits_timeseries_grid), which had no effect, so those
        // statements were dropped.
        // NOTE(review): if the intent was to carry these over from optimized.nhits_params,
        // that needs a deliberate change -- confirm with the algorithm owner.
        return preTrainParams;
    }

    /**
     * Writes the regridified NHITS space into the target modeling params and enables it.
     *
     * @param target modeling params whose {@code nhits_timeseries} is replaced
     * @param optimized forwarded to {@link #regridifyToPreTrain}
     * @param usedToTrain forwarded to {@link #regridifyToPreTrain}
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PredictionModelingParams.NHITSSpace regridified = this.regridifyToPreTrain(optimized, usedToTrain).nhits_timeseries_grid;
        target.nhits_timeseries = regridified;
        target.nhits_timeseries.enabled = true;
    }

    /**
     * Copies the NHITS space reference used for training into the target and enables it.
     *
     * <p>The grid object is assigned by reference, so setting {@code enabled} here also
     * changes the object still referenced by {@code usedToTrain.nhits_timeseries_grid}.
     *
     * @param target modeling params whose {@code nhits_timeseries} is replaced
     * @param usedToTrain pre-train params whose {@code nhits_timeseries_grid} is reused
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.nhits_timeseries = usedToTrain.nhits_timeseries_grid;
        target.nhits_timeseries.enabled = true;
    }

    /**
     * Always returns {@code true}.
     *
     * @param usedToTrain unused
     * @return {@code true}
     */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams usedToTrain) {
        return true;
    }

    /**
     * Computes the number of grid points as the product of the learning-rate and
     * context-length dimension lengths.
     *
     * @param space must be a {@code PredictionModelingParams.NHITSSpace}; it is cast unconditionally
     * @return learning-rate length multiplied by context-length length
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.NHITSSpace nhitsSpace = (PredictionModelingParams.NHITSSpace)space;
        int learningRateCount = nhitsSpace.learning_rate.getLength();
        int contextLengthCount = nhitsSpace.context_length.getLength();
        return learningRateCount * contextLengthCount;
    }
}

