/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import java.util.Collections;
import java.util.List;

/**
 * Algorithm meta for the TFT time-series model.
 *
 * <p>Translates between three representations of the TFT hyperparameters:
 * <ul>
 *   <li>the ML task's search space ({@code PredictionModelingParams.tft_timeseries}),</li>
 *   <li>the pre-train grid ({@code PreTrainPredictionModelingParams.tft_timeseries_grid}), and</li>
 *   <li>the optimized post-train values ({@code PostTrainPredictionModelingParams.tft_params}).</li>
 * </ul>
 */
public class TFTMeta
extends PyMemoryAlgorithmMeta {
    /**
     * Returns the display name of the algorithm.
     *
     * @param usedToTrain ignored; the name does not depend on the training parameters
     * @return always {@code "TFT"}
     */
    @Override
    public String generateName(PreTrainPredictionModelingParams usedToTrain) {
        return "TFT";
    }

    /**
     * Describes the TFT search grid before the hyperparameter search runs.
     *
     * <p>Searched dimensions are reported with {@code withMVParam}. Fixed settings are
     * reported with {@code withSVParam}.
     *
     * @param rpmp pre-train parameters whose {@code tft_timeseries_grid} is described
     * @return the pre-search description, including the grid length
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.TFTSpace space = rpmp.tft_timeseries_grid;
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, space))
                .withMVParam("Learning rate", space.learning_rate)
                .withSVParam("Max steps", space.max_steps)
                .withSVParam("Patience steps", space.patience)
                .withSVParam("Batch size", space.batch_size)
                .withSVParam("Random state", space.random_state)
                .withMVParam("Hidden size factor", space.hidden_size_factor)
                .withSVParam("Max hidden size", space.max_hidden_size)
                .withSVParam("Limit hidden size", space.limit_hidden_size)
                .withMVParam("Context length", space.context_length)
                .withMVParam("Number of LSTM layers", space.n_rnn_layers)
                .withMVParam("Number of attention heads", space.n_head);
    }

    /**
     * Describes the TFT hyperparameters that were actually used after the search.
     *
     * @param descBefore unused here
     * @param before unused here
     * @param after post-train parameters whose {@code tft_params} are described
     * @return a post-search description listing every optimized value
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PostTrainPredictionModelingParams.TFTParams params = after.tft_params;
        return new ModelTrainInfo.PostSearchDescription()
                .withSVParam("Learning rate", Float.valueOf(params.learning_rate))
                .withSVParam("Max steps", params.max_steps)
                .withSVParam("Patience steps", params.patience)
                .withSVParam("Batch size", params.batch_size)
                .withSVParam("Random state", params.random_state)
                .withSVParam("Hidden size", params.hidden_size)
                .withSVParam("Hidden size factor", params.hidden_size_factor)
                .withSVParam("Context length", params.context_length)
                .withSVParam("Number of LSTM layers", params.n_rnn_layers)
                .withSVParam("Number of attention heads", params.n_head);
    }

    /**
     * Validates the TFT search space of the ML task.
     *
     * <p>Does nothing when TFT is absent or disabled.
     *
     * @param pmp ML-task modeling parameters holding {@code tft_timeseries}
     * @param task the ML task (unused here)
     * @param checks collector that receives validation results
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.TFTSpace space = pmp.tft_timeseries;
        if (space == null || !space.enabled) {
            return;
        }
        checks.checkPositive(space.max_steps, "Maximum steps number must be positive (TFT)");
        checks.checkPositive(space.batch_size, "Batch size must be positive (TFT)");
        checks.checkNumericalDimension(space.learning_rate, "Learning rate (TFT)");
        checks.checkNumericalDimension(space.context_length, "Context length (TFT)");
        checks.checkNumericalDimension(space.hidden_size_factor, "Hidden size factor (TFT)");
        checks.checkNumericalDimension(space.n_rnn_layers, "Number of RNN layers (TFT)");
        checks.checkNumericalDimension(space.n_head, "Number of attention heads (TFT)");
    }

    /**
     * Expands the ML task into modeling sets for TFT.
     *
     * @param pmp ML-task modeling parameters holding {@code tft_timeseries}
     * @param task the ML task (unused here)
     * @param gsFolds number of grid-search folds (unused here)
     * @return an empty list when TFT is absent or disabled; otherwise a single modeling set
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.TFTSpace space = pmp.tft_timeseries;
        if (space == null || !space.enabled) {
            return Collections.emptyList();
        }
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.TFT, pmp);
        // The ML task's space object is shared by reference, not copied.
        preTrainParams.tft_timeseries_grid = space;
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, space);
        return Collections.singletonList(new WorkSet.ModelingSet(preTrainParams));
    }

    /**
     * Builds pre-train parameters whose searched dimensions are pinned to the optimized values.
     *
     * @param optimized post-train parameters holding the chosen {@code tft_params}
     * @param usedToTrain pre-train parameters that produced {@code optimized}
     * @return a copy of {@code usedToTrain} (via {@code getCopyWithGridStrategy}) with
     *     single-value grids for the searched dimensions
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PostTrainPredictionModelingParams.TFTParams optimizedParameters = optimized.tft_params;
        PreTrainPredictionModelingParams preTrainParams = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.TFTSpace space = preTrainParams.tft_timeseries_grid;
        space.learning_rate.setToSingleValueGrid(Double.valueOf(optimizedParameters.learning_rate));
        space.context_length.setToSingleValueGrid(Long.valueOf(optimizedParameters.context_length));
        space.hidden_size_factor.setToSingleValueGrid(Long.valueOf(optimizedParameters.hidden_size_factor));
        space.n_rnn_layers.setToSingleValueGrid(Long.valueOf(optimizedParameters.n_rnn_layers));
        space.n_head.setToSingleValueGrid(Long.valueOf(optimizedParameters.n_head));
        // max_steps, batch_size, patience, max_hidden_size and limit_hidden_size keep the
        // values they have on the copied grid. The previous code reassigned each of them
        // from the same object, which had no effect.
        return preTrainParams;
    }

    /**
     * Writes the regridified (optimized) TFT space back into an ML task and enables TFT.
     *
     * @param target ML-task modeling parameters to update
     * @param optimized post-train parameters holding the chosen {@code tft_params}
     * @param usedToTrain pre-train parameters that produced {@code optimized}
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrainParams = this.regridifyToPreTrain(optimized, usedToTrain);
        target.tft_timeseries = preTrainParams.tft_timeseries_grid;
        target.tft_timeseries.enabled = true;
    }

    /**
     * Restores the TFT space that was used for training into an ML task and enables TFT.
     *
     * <p>NOTE(review): the grid object is assigned by reference. Setting {@code enabled}
     * therefore also modifies {@code usedToTrain.tft_timeseries_grid}. Confirm that this
     * sharing is intended.
     *
     * @param target ML-task modeling parameters to update
     * @param usedToTrain pre-train parameters whose grid is restored
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.tft_timeseries = usedToTrain.tft_timeseries_grid;
        target.tft_timeseries.enabled = true;
    }

    /**
     * Reports whether this algorithm produces probabilities.
     *
     * @param usedToTrain ignored
     * @return always {@code true}
     */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams usedToTrain) {
        return true;
    }

    /**
     * Counts the hyperparameter combinations in a TFT grid.
     *
     * <p>The count is the number of valid (n_head, hidden_size_factor) pairs multiplied by
     * the lengths of the learning_rate, context_length and n_rnn_layers dimensions. When
     * {@code limit_hidden_size} is set, pairs whose product exceeds {@code max_hidden_size}
     * are not counted.
     *
     * <p>NOTE(review): the limited branch iterates the raw {@code values} arrays, while the
     * unlimited branch uses {@code getLength()}. If these dimensions support a
     * non-explicit mode, verify that both branches agree.
     *
     * @param space the grid; must be a {@code PredictionModelingParams.TFTSpace}
     * @return the number of grid points
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.TFTSpace tftSpace = (PredictionModelingParams.TFTSpace)space;
        int validHiddenSizeCount;
        if (tftSpace.limit_hidden_size) {
            validHiddenSizeCount = 0;
            Long[] nHeads = (Long[])tftSpace.n_head.values;
            // Loop-invariant: read once instead of once per n_head value.
            Long[] hiddenSizeFactors = (Long[])tftSpace.hidden_size_factor.values;
            long maxHiddenSize = (long)tftSpace.max_hidden_size;
            for (long nHead : nHeads) {
                for (long hiddenSizeFactor : hiddenSizeFactors) {
                    if (nHead * hiddenSizeFactor <= maxHiddenSize) {
                        ++validHiddenSizeCount;
                    }
                }
            }
        } else {
            validHiddenSizeCount = tftSpace.hidden_size_factor.getLength() * tftSpace.n_head.getLength();
        }
        return validHiddenSizeCount
                * tftSpace.learning_rate.getLength()
                * tftSpace.context_length.getLength()
                * tftSpace.n_rnn_layers.getLength();
    }
}

