/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.Collections;
import java.util.List;

public class CrostonMeta
extends PyMemoryAlgorithmMeta {
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Croston";
    }

    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.CrostonSpace space = rpmp.croston_timeseries_grid;
        ModelTrainInfo.PreSearchDescription description = new ModelTrainInfo.PreSearchDescription(rpmp).withGridLength(this.getSearchSize(rpmp.grid_search_params, space)).withMVParam("Variant", space.variant);
        if (space.variant.values.get((Object)"TSB").enabled) {
            description.withMVParam("Alpha_d", space.alpha_d).withMVParam("Alpha_p", space.alpha_p);
        }
        return description;
    }

    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PostTrainPredictionModelingParams.CrostonParams params = after.croston_timeseries_params;
        ModelTrainInfo.PostSearchDescription description = new ModelTrainInfo.PostSearchDescription().withSVParam("Variant", params.variant);
        if ("TSB".equals(params.variant)) {
            description.withSVParam("Alpha_d", params.alpha_d).withSVParam("Alpha_p", params.alpha_p);
        }
        return description;
    }

    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.CrostonSpace space = pmp.croston_timeseries;
        if (space == null || !space.enabled) {
            return;
        }
        ErrorContext.check((space.variant.getLength() > 0 ? 1 : 0) != 0, (String)"Croston requires at least one variant");
        if (space.variant.values.get((Object)"TSB").enabled) {
            checks.checkNumericalDimension(space.alpha_d, "Alpha_d parameter (Croston)");
            checks.checkNumericalDimension(space.alpha_p, "Alpha_p parameter (Croston)");
        }
    }

    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.CrostonSpace space = pmp.croston_timeseries;
        if (space == null || !space.enabled) {
            return Collections.emptyList();
        }
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.CROSTON, pmp);
        preTrainParams.croston_timeseries_grid = space;
        this.checkAndUpdateSearchStrategy(pmp, preTrainParams);
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, space);
        WorkSet.ModelingSet modelingSet = new WorkSet.ModelingSet(preTrainParams);
        if (preTrainParams.gridLength > 1) {
            modelingSet.estimatedTrains = preTrainParams.gridLength * gsFolds + 1;
        }
        return Collections.singletonList(modelingSet);
    }

    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PostTrainPredictionModelingParams.CrostonParams optimizedParams = optimized.croston_timeseries_params;
        PreTrainPredictionModelingParams preTrainParams = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.CrostonSpace space = preTrainParams.croston_timeseries_grid;
        space.variant = CategoricalHyperparameterDimension.create(optimizedParams.variant, "CLASSIC", "SBA", "TSB");
        space.alpha_d.setToSingleValueGrid(optimizedParams.alpha_d);
        space.alpha_p.setToSingleValueGrid(optimizedParams.alpha_p);
        return preTrainParams;
    }

    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrainParams = this.regridifyToPreTrain(optimized, usedToTrain);
        target.croston_timeseries = preTrainParams.croston_timeseries_grid;
        target.croston_timeseries.enabled = true;
    }

    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.croston_timeseries = usedToTrain.croston_timeseries_grid;
        target.croston_timeseries.enabled = true;
    }

    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return false;
    }

    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.CrostonSpace crostonSpace = (PredictionModelingParams.CrostonSpace)space;
        int gridLength = crostonSpace.variant.getLength();
        if (crostonSpace.variant.values.get((Object)"TSB").enabled) {
            gridLength += -1 + crostonSpace.alpha_d.getLength() * crostonSpace.alpha_p.getLength();
        }
        return gridLength;
    }
}

