/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.Collections;
import java.util.List;
import java.util.Set;

public class ArimaMeta
extends PyMemoryAlgorithmMeta {
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "ARIMA";
    }

    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.ArimaSpace params = rpmp.arima_grid;
        ModelTrainInfo.PreSearchDescription preSearchDescription = new ModelTrainInfo.PreSearchDescription(rpmp);
        preSearchDescription = preSearchDescription.withSVParam("p", params.p);
        preSearchDescription = preSearchDescription.withSVParam("d", params.d);
        preSearchDescription = preSearchDescription.withSVParam("q", params.q);
        preSearchDescription = preSearchDescription.withSVParam("P", params.P);
        preSearchDescription = preSearchDescription.withSVParam("D", params.D);
        preSearchDescription = preSearchDescription.withSVParam("Q", params.Q);
        preSearchDescription = preSearchDescription.withSVParam("Season length", params.s);
        preSearchDescription = preSearchDescription.withSVParam("Trend", params.trend);
        preSearchDescription = preSearchDescription.withSVParam("Trend offset", params.trend_offset);
        preSearchDescription = preSearchDescription.withSVParam("Enforce stationarity", params.enforce_stationarity);
        preSearchDescription = preSearchDescription.withSVParam("Enforce invertibility", params.enforce_invertibility);
        preSearchDescription = preSearchDescription.withSVParam("Concentrate scale", params.concentrate_scale);
        return preSearchDescription;
    }

    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PostTrainPredictionModelingParams.ArimaParams params = after.arima_timeseries_params;
        ModelTrainInfo.PostSearchDescription postSearchDescription = new ModelTrainInfo.PostSearchDescription();
        postSearchDescription = postSearchDescription.withSVParam("p", params.p);
        postSearchDescription = postSearchDescription.withSVParam("d", params.d);
        postSearchDescription = postSearchDescription.withSVParam("q", params.q);
        postSearchDescription = postSearchDescription.withSVParam("P", params.P);
        postSearchDescription = postSearchDescription.withSVParam("D", params.D);
        postSearchDescription = postSearchDescription.withSVParam("Q", params.Q);
        postSearchDescription = postSearchDescription.withSVParam("Season length", params.s);
        postSearchDescription = postSearchDescription.withSVParam("Trend", params.trend);
        postSearchDescription = postSearchDescription.withSVParam("Trend offset", params.trend_offset);
        postSearchDescription = postSearchDescription.withSVParam("Enforce stationarity", params.enforce_stationarity);
        postSearchDescription = postSearchDescription.withSVParam("Enforce invertibility", params.enforce_invertibility);
        postSearchDescription = postSearchDescription.withSVParam("Concentrate scale", params.concentrate_scale);
        return postSearchDescription;
    }

    private boolean isTrendValid(String trend, int d, int D) {
        int sum = d + D;
        switch (trend) {
            case "n": {
                return true;
            }
            case "c": 
            case "ct": {
                return sum <= 0;
            }
            case "t": {
                return sum <= 1;
            }
        }
        return false;
    }

    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.ArimaSpace space = pmp.arima_timeseries;
        if (space == null || !space.enabled) {
            return;
        }
        checks.checkNonNegative(space.p, "Order value p must be positive");
        checks.checkNonNegative(space.d, "Order value d must be positive");
        checks.checkNonNegative(space.q, "Order value q must be positive");
        checks.checkNonNegative(space.P, "Seasonal order value P must be positive");
        checks.checkNonNegative(space.D, "Seasonal order value D must be positive");
        checks.checkNonNegative(space.Q, "Seasonal order value Q must be positive");
        checks.checkNonNegative(space.s - 2, "Season length s must be greater than 2");
        ErrorContext.check((boolean)Set.of("n", "c", "t", "ct").contains(space.trend), (String)"Trend must be in [\u2018n\u2019,\u2019c\u2019,\u2019t\u2019,\u2019ct\u2019]");
        ErrorContext.check((boolean)this.isTrendValid(space.trend, space.d, space.D), (String)"Trend shouldn't have coefficients of degree lower than d+D");
    }

    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.ArimaSpace space = pmp.arima_timeseries;
        if (space == null || !space.enabled) {
            return Collections.emptyList();
        }
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.ARIMA, pmp);
        preTrainParams.arima_grid = space;
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, space);
        WorkSet.ModelingSet modelingSet = new WorkSet.ModelingSet(preTrainParams);
        return Collections.singletonList(modelingSet);
    }

    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrainParams = this.getCopyWithGridStrategy(usedToTrain);
        return preTrainParams;
    }

    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrainParams = this.regridifyToPreTrain(optimized, usedToTrain);
        target.arima_timeseries = preTrainParams.arima_grid;
        target.arima_timeseries.enabled = true;
    }

    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.arima_timeseries = usedToTrain.arima_grid;
        target.arima_timeseries.enabled = true;
    }

    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }
}

