/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.coreservices.PredictionService;
import com.dataiku.dip.analysis.ml.prediction.PredictionMLTaskHandlingStrategy;
import com.dataiku.dip.analysis.ml.prediction.TimeseriesForecastingParamsExpander;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import java.util.Collections;
import java.util.Map;

/**
 * {@link PredictionMLTaskHandlingStrategy} used for time series forecasting prediction ML tasks.
 *
 * <p>Provides the metric name, split sanity check, Python training entry point, parameter
 * expansion, thread count and forced environment variables for this kind of task.
 */
public class TimeseriesForecastingMLTaskHandlingStrategy
implements PredictionMLTaskHandlingStrategy {
    /**
     * Environment variables returned by {@link #getForcedEnvVars()}: pins {@code OMP_NUM_THREADS}
     * to {@code "1"}. Immutable (built with {@link Collections#singletonMap}).
     * NOTE(review): presumably applied to the training process environment by the caller — confirm.
     */
    public static final Map<String, String> FORCED_ENV_VARS = Collections.singletonMap("OMP_NUM_THREADS", "1");
    private static final DKULogger logger = DKULogger.getLogger("dku.analysis.prediction");

    /** Metric name reported for partitioned ("stratified") tasks. */
    private static final String PARTITIONED_METRIC_NAME = "dku.ml.predictionTrain.strat.timeseriesTrain";
    /** Metric name reported for non-partitioned tasks. */
    private static final String SINGLE_METRIC_NAME = "dku.ml.predictionTrain.timeseriesTrain";
    /** Name of the Python training function invoked for time series tasks. */
    private static final String PYTHON_TRAIN_FUNCTION = "train_prediction_timeseries";

    /**
     * Returns the DSS metric name for this task.
     *
     * @param task the prediction task
     * @return the "strat" metric name when {@code task.isPartitioned()} is true, otherwise the
     *     plain time series training metric name
     */
    @Override
    public String getDSSMetricName(PredictionMLTask task) {
        final boolean partitioned = task.isPartitioned();
        return partitioned ? PARTITIONED_METRIC_NAME : SINGLE_METRIC_NAME;
    }

    /**
     * Rejects a split whose train set is empty.
     *
     * <p>Partitioned tasks are not checked here; {@code splitDesc} is not read for them.
     *
     * @param task the prediction task
     * @param splitDesc the split description; only {@code fullRows} is inspected
     * @throws RuntimeException the exception built by {@code ErrorContext.iae} when the task is not
     *     partitioned and {@code splitDesc.fullRows} is zero
     */
    @Override
    public void checkSplit(PredictionMLTask task, SplitDesc splitDesc) {
        // Short-circuit keeps splitDesc untouched for partitioned tasks.
        final boolean emptyTrainSet = !task.isPartitioned() && splitDesc.fullRows == 0L;
        if (emptyTrainSet) {
            throw ErrorContext.iae("Train set is empty. Please check your train & validation settings");
        }
    }

    /**
     * Returns the name of the Python function used to train this kind of task.
     *
     * @param task unused
     * @return {@code "train_prediction_timeseries"}
     */
    @Override
    public String getPythonFunction(PredictionMLTask task) {
        return PYTHON_TRAIN_FUNCTION;
    }

    /**
     * Expands the task's prediction parameters into a {@link WorkSet}.
     *
     * <p>When no hyperparameter search is needed (per {@code PredictionService.isHPSearchNeeded}) and
     * fold offset is enabled, fold offset is switched off on the task itself (in-place mutation) and a
     * warning is logged. Expansion is then delegated to {@link TimeseriesForecastingParamsExpander}.
     *
     * @param task the task; cast to {@code PredictionMLTask.TimeseriesForecastingMLTask}, so any
     *     other subtype raises {@link ClassCastException}
     * @param sessionId session identifier forwarded to the expander
     * @return the work set produced by the expander
     * @throws Exception whatever the expander throws
     */
    @Override
    public WorkSet expandPredictionParams(PredictionMLTask task, String sessionId) throws Exception {
        final PredictionMLTask.TimeseriesForecastingMLTask forecastingTask =
                (PredictionMLTask.TimeseriesForecastingMLTask)task;
        // Evaluate the HP-search check first; foldOffset is only read when no search is needed.
        if (!PredictionService.isHPSearchNeeded(forecastingTask)) {
            if (forecastingTask.modeling.gridSearchParams.foldOffset) {
                logger.warn((Object)"No hyperparameter search space: disabling fold offset");
                forecastingTask.modeling.gridSearchParams.foldOffset = false;
            }
        }
        final TimeseriesForecastingParamsExpander expander =
                new TimeseriesForecastingParamsExpander(forecastingTask, sessionId);
        return expander.expand();
    }

    /**
     * Returns how many training threads to run.
     *
     * @param task the prediction task
     * @return {@code task.maxConcurrentModelTraining}, raised to 1 if it is lower than 1
     */
    @Override
    public int getThreadCountToRun(PredictionMLTask task) {
        final int requested = task.maxConcurrentModelTraining;
        return requested < 1 ? 1 : requested;
    }

    /**
     * Returns the environment variables this strategy forces.
     *
     * @return the shared immutable {@link #FORCED_ENV_VARS} map
     */
    @Override
    public Map<String, String> getForcedEnvVars() {
        return TimeseriesForecastingMLTaskHandlingStrategy.FORCED_ENV_VARS;
    }
}

