/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import java.util.Collections;
import java.util.List;

/**
 * Algorithm metadata for the "Seasonal trend" time-series algorithm.
 *
 * <p>Its hyperparameter search space is a {@link PredictionModelingParams.SeasonalLoessSpace}.
 * That space is stored as {@code seasonal_loess_timeseries} on the ML task's modeling params and
 * as {@code seasonal_loess_timeseries_grid} on pre-train params. Trained values live in
 * {@code seasonal_loess_timeseries_params} on post-train params.
 *
 * <p>The trend smoother length and the low-pass length are only treated as searched dimensions
 * when {@code auto_trend} / {@code auto_low_pass} are disabled. This applies to the pre/post-train
 * descriptions, the regridding and the grid-size computation.
 */
public class SeasonalLoessMeta
extends PyMemoryAlgorithmMeta {
    /**
     * Returns the display name of this algorithm.
     *
     * @param rpmp pre-train params (unused)
     * @return always {@code "Seasonal trend"}
     */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Seasonal trend";
    }

    /**
     * Describes the search space before training: the grid length plus every searched
     * hyperparameter dimension, each added with {@code withMVParam}.
     *
     * @param rpmp pre-train params whose {@code seasonal_loess_timeseries_grid} is described
     * @return the pre-search description
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.SeasonalLoessSpace space = rpmp.seasonal_loess_timeseries_grid;
        ModelTrainInfo.PreSearchDescription description = new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, space))
                .withMVParam("Season length", space.period)
                .withMVParam("Seasonal smoother length", space.seasonal);
        // Trend / low-pass lengths are only part of the search space when not chosen automatically.
        if (!space.auto_trend) {
            description = description.withMVParam("Trend smoother length", space.trend);
        }
        if (!space.auto_low_pass) {
            description = description.withMVParam("Low pass length", space.low_pass);
        }
        description = description
                .withMVParam("Degree of seasonal LOESS", space.seasonal_deg)
                .withMVParam("Degree of trend LOESS", space.trend_deg)
                .withMVParam("Degree of low pass LOESS", space.low_pass_deg)
                .withMVParam("Seasonal jump", space.seasonal_jump)
                .withMVParam("Trend jump", space.trend_jump)
                .withMVParam("Low pass jump", space.low_pass_jump);
        return description;
    }

    /**
     * Describes the hyperparameter values selected by training, each added with
     * {@code withSVParam}.
     *
     * <p>The trend / low-pass lengths are only reported when the corresponding auto flag was
     * disabled in the search space that was used to train.
     *
     * @param descBefore the pre-search description (unused)
     * @param before pre-train params whose {@code seasonal_loess_timeseries_grid} auto flags are read
     * @param after post-train params whose {@code seasonal_loess_timeseries_params} are described
     * @return the post-search description
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PostTrainPredictionModelingParams.SeasonalLoessParams params = after.seasonal_loess_timeseries_params;
        PredictionModelingParams.SeasonalLoessSpace space = before.seasonal_loess_timeseries_grid;
        ModelTrainInfo.PostSearchDescription description = new ModelTrainInfo.PostSearchDescription()
                .withSVParam("Season length", params.period)
                .withSVParam("Seasonal smoother length", params.seasonal);
        if (!space.auto_trend) {
            description = description.withSVParam("Trend smoother length", params.trend);
        }
        if (!space.auto_low_pass) {
            description = description.withSVParam("Low pass length", params.low_pass);
        }
        description = description
                .withSVParam("Degree of seasonal LOESS", params.seasonal_deg)
                .withSVParam("Degree of trend LOESS", params.trend_deg)
                .withSVParam("Degree of low pass LOESS", params.low_pass_deg)
                .withSVParam("Seasonal jump", params.seasonal_jump)
                .withSVParam("Trend jump", params.trend_jump)
                .withSVParam("Low pass jump", params.low_pass_jump);
        return description;
    }

    /**
     * Validates the seasonal LOESS search space configured on the ML task.
     *
     * <p>Does nothing when the space is absent or disabled. Otherwise it:
     * <ul>
     *   <li>requires at least one value for each LOESS degree;</li>
     *   <li>delegates the checking of every numerical dimension to {@code checks}.</li>
     * </ul>
     *
     * <p>NOTE(review): the trend and low-pass lengths are checked even when {@code auto_trend} /
     * {@code auto_low_pass} is set. Confirm this is intended.
     *
     * @param pmp ML task modeling params
     * @param task the ML task (unused)
     * @param checks collector for validation errors
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.SeasonalLoessSpace space = pmp.seasonal_loess_timeseries;
        if (space == null || !space.enabled) {
            return;
        }
        checks.checkPositive(space.seasonal_deg.getLength(), "At least one degree of seasonal LOESS must be selected (Seasonal trend)");
        checks.checkPositive(space.trend_deg.getLength(), "At least one degree of trend LOESS must be selected (Seasonal trend)");
        checks.checkPositive(space.low_pass_deg.getLength(), "At least one degree of low_pass LOESS must be selected (Seasonal trend)");
        checks.checkNumericalDimension(space.period, "Season length (Seasonal trend)");
        checks.checkNumericalDimension(space.seasonal, "Length of the seasonal smoother (Seasonal trend)");
        checks.checkNumericalDimension(space.trend, "Length of the trend smoother (Seasonal trend)");
        checks.checkNumericalDimension(space.low_pass, "Length of the low-pass filter (Seasonal trend)");
        checks.checkNumericalDimension(space.seasonal_jump, "Season linear interpolation step (Seasonal trend)");
        checks.checkNumericalDimension(space.trend_jump, "Trend linear interpolation step (Seasonal trend)");
        checks.checkNumericalDimension(space.low_pass_jump, "Low-pass linear interpolation step (Seasonal trend)");
    }

    /**
     * Expands the ML task's seasonal LOESS settings into modeling sets.
     *
     * @param pmp ML task modeling params
     * @param task the ML task (unused)
     * @param gsFolds number of folds used by the hyperparameter search
     * @return an empty list when the space is absent or disabled, otherwise a single modeling set.
     *     When its grid has more than one point, that set's {@code estimatedTrains} is
     *     {@code gridLength * gsFolds + 1}.
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.SeasonalLoessSpace space = pmp.seasonal_loess_timeseries;
        if (space == null || !space.enabled) {
            return Collections.emptyList();
        }
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.SEASONAL_LOESS, pmp);
        preTrainParams.seasonal_loess_timeseries_grid = space;
        this.checkAndUpdateSearchStrategy(pmp, preTrainParams);
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, space);
        WorkSet.ModelingSet modelingSet = new WorkSet.ModelingSet(preTrainParams);
        if (preTrainParams.gridLength > 1) {
            // NOTE(review): presumably one train per grid point per fold, plus one final refit — confirm.
            modelingSet.estimatedTrains = preTrainParams.gridLength * gsFolds + 1;
        }
        return Collections.singletonList(modelingSet);
    }

    /**
     * Builds pre-train params whose search space is pinned to the values selected by training.
     *
     * <p>Starts from {@code getCopyWithGridStrategy(usedToTrain)} and reduces every searched
     * dimension to the single optimized value. The trend and low-pass lengths are only pinned
     * when they were not chosen automatically.
     *
     * @param optimized post-train params holding the selected values
     * @param usedToTrain pre-train params that were used for the search
     * @return the regridded pre-train params
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PostTrainPredictionModelingParams.SeasonalLoessParams optimizedParams = optimized.seasonal_loess_timeseries_params;
        PreTrainPredictionModelingParams preTrainParams = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.SeasonalLoessSpace space = preTrainParams.seasonal_loess_timeseries_grid;
        space.seasonal.setToSingleValueGrid(optimizedParams.seasonal);
        space.period.setToSingleValueGrid(optimizedParams.period);
        if (!space.auto_trend) {
            space.trend.setToSingleValueGrid(optimizedParams.trend);
        }
        if (!space.auto_low_pass) {
            // The optimized low-pass length belongs in the low-pass dimension. It was previously
            // written into the trend dimension, clobbering the trend value.
            space.low_pass.setToSingleValueGrid(optimizedParams.low_pass);
        }
        // NOTE(review): presumably builds a categorical dimension over the choices "0"/"1" with only
        // the optimized degree selected — confirm against CategoricalHyperparameterDimension.create.
        space.seasonal_deg = CategoricalHyperparameterDimension.create(optimizedParams.seasonal_deg, "0", "1");
        space.trend_deg = CategoricalHyperparameterDimension.create(optimizedParams.trend_deg, "0", "1");
        space.low_pass_deg = CategoricalHyperparameterDimension.create(optimizedParams.low_pass_deg, "0", "1");
        space.seasonal_jump.setToSingleValueGrid(optimizedParams.seasonal_jump);
        space.trend_jump.setToSingleValueGrid(optimizedParams.trend_jump);
        space.low_pass_jump.setToSingleValueGrid(optimizedParams.low_pass_jump);
        return preTrainParams;
    }

    /**
     * Writes the optimized (single-valued) search space back into the ML task and enables it.
     *
     * @param target ML task modeling params to update
     * @param optimized post-train params holding the selected values
     * @param usedToTrain pre-train params that were used for the search
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrainParams = this.regridifyToPreTrain(optimized, usedToTrain);
        target.seasonal_loess_timeseries = preTrainParams.seasonal_loess_timeseries_grid;
        target.seasonal_loess_timeseries.enabled = true;
    }

    /**
     * Restores the search space that was used for training into the ML task and enables it.
     *
     * <p>NOTE(review): the space object is assigned by reference. Setting {@code enabled} therefore
     * also mutates {@code usedToTrain.seasonal_loess_timeseries_grid}.
     *
     * @param target ML task modeling params to update
     * @param usedToTrain pre-train params that were used for training
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.seasonal_loess_timeseries = usedToTrain.seasonal_loess_timeseries_grid;
        target.seasonal_loess_timeseries.enabled = true;
    }

    /**
     * Reports whether this algorithm exposes probabilities.
     *
     * @param rpmp pre-train params (unused)
     * @return always {@code true}
     */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    /**
     * Computes the number of points in the full search grid.
     *
     * <p>The result is the product of the lengths of every searched dimension. The trend and
     * low-pass lengths only contribute when they are not chosen automatically.
     *
     * @param space must be a {@link PredictionModelingParams.SeasonalLoessSpace}; it is cast
     *     unconditionally, so any other type throws {@link ClassCastException}
     * @return the grid size
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.SeasonalLoessSpace seasonLoessSpace = (PredictionModelingParams.SeasonalLoessSpace)space;
        int gridLength = seasonLoessSpace.period.getLength()
                * seasonLoessSpace.seasonal.getLength()
                * seasonLoessSpace.seasonal_deg.getLength()
                * seasonLoessSpace.trend_deg.getLength()
                * seasonLoessSpace.low_pass_deg.getLength()
                * seasonLoessSpace.seasonal_jump.getLength()
                * seasonLoessSpace.trend_jump.getLength()
                * seasonLoessSpace.low_pass_jump.getLength();
        if (!seasonLoessSpace.auto_trend) {
            gridLength *= seasonLoessSpace.trend.getLength();
        }
        if (!seasonLoessSpace.auto_low_pass) {
            gridLength *= seasonLoessSpace.low_pass.getLength();
        }
        return gridLength;
    }

    /**
     * Returns the maximum search strategy supported by this algorithm.
     *
     * @return always {@code RANDOM}
     */
    @Override
    public PredictionModelingParams.GridSearchParams.Strategy getMaximumSupportedSearchStrategy() {
        return PredictionModelingParams.GridSearchParams.Strategy.RANDOM;
    }
}

