/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.prediction.CausalPredictionParamsExpander;
import com.dataiku.dip.analysis.ml.prediction.PredictionMLTaskHandlingStrategy;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.utils.ErrorContext;
import java.util.Collections;
import java.util.Map;

/**
 * {@link PredictionMLTaskHandlingStrategy} for causal prediction ML tasks.
 *
 * <p>Training is routed to the Python function {@code train_causal_prediction}, and parameter
 * expansion is delegated to {@link CausalPredictionParamsExpander}. This strategy expects a
 * non-k-fold split; that precondition is expressed with {@code assert}, so it is only checked
 * when the JVM runs with assertions enabled.
 */
public class CausalPredictionMLTaskHandlingStrategy
implements PredictionMLTaskHandlingStrategy {
    /**
     * Environment variables returned by {@link #getForcedEnvVars()}: pins
     * {@code OPENBLAS_NUM_THREADS} to {@code "1"}. The map is immutable.
     *
     * <p>NOTE(review): presumably this keeps OpenBLAS from spawning its own thread pool on top of
     * concurrent model trainings — confirm with the training runtime owners.
     */
    public static final Map<String, String> FORCED_ENV_VARS = Collections.singletonMap("OPENBLAS_NUM_THREADS", "1");

    /** Metric name reported for causal prediction training. */
    private static final String METRIC_NAME = "dku.ml.predictionTrain.causalPredictionTrain";

    /** Python entry point invoked to train a causal prediction model. */
    private static final String PYTHON_TRAIN_FUNCTION = "train_causal_prediction";

    /**
     * Returns the metric name used for causal prediction training.
     *
     * @param task the ML task; expected not to use k-fold splitting (assert-checked only)
     * @return the constant metric name {@code dku.ml.predictionTrain.causalPredictionTrain}
     */
    @Override
    public String getDSSMetricName(PredictionMLTask task) {
        assertNotKFold(task);
        return METRIC_NAME;
    }

    /**
     * Validates that the computed split has both a non-empty train set and a non-empty test set.
     * The train set is checked first.
     *
     * @param task the ML task; expected not to use k-fold splitting (assert-checked only)
     * @param splitDesc row counts of the computed split
     * @throws RuntimeException whatever {@code ErrorContext.iae} produces, when either set has zero
     *     rows
     */
    @Override
    public void checkSplit(PredictionMLTask task, SplitDesc splitDesc) {
        assertNotKFold(task);
        requireRows(splitDesc.trainRows, "Train set is empty. Please check your train & validation settings");
        requireRows(splitDesc.testRows, "Test set is empty. Please check your train & validation settings");
    }

    /**
     * Returns the name of the Python function that trains the model.
     *
     * @param task the ML task (unused)
     * @return {@code train_causal_prediction}
     */
    @Override
    public String getPythonFunction(PredictionMLTask task) {
        return PYTHON_TRAIN_FUNCTION;
    }

    /**
     * Expands the task's prediction parameters into a {@link WorkSet}.
     *
     * @param task the ML task; must be a {@code PredictionMLTask.CausalPredictionMLTask}, otherwise
     *     the cast throws {@link ClassCastException}
     * @param sessionId the session the expansion belongs to
     * @return the result of {@link CausalPredictionParamsExpander#expand()}
     * @throws Exception propagated from the expander
     */
    @Override
    public WorkSet expandPredictionParams(PredictionMLTask task, String sessionId) throws Exception {
        PredictionMLTask.CausalPredictionMLTask causalTask = (PredictionMLTask.CausalPredictionMLTask) task;
        CausalPredictionParamsExpander expander = new CausalPredictionParamsExpander(causalTask, sessionId);
        return expander.expand();
    }

    /**
     * Returns how many threads to use for training: the task's configured maximum number of
     * concurrent model trainings, clamped to at least 1.
     *
     * @param task the ML task
     * @return {@code task.maxConcurrentModelTraining}, or 1 if that value is below 1
     */
    @Override
    public int getThreadCountToRun(PredictionMLTask task) {
        int requested = task.maxConcurrentModelTraining;
        if (requested < 1) {
            return 1;
        }
        return requested;
    }

    /**
     * Returns the environment variables this strategy forces.
     *
     * @return the shared immutable {@link #FORCED_ENV_VARS} map
     */
    @Override
    public Map<String, String> getForcedEnvVars() {
        return FORCED_ENV_VARS;
    }

    /** Asserts (only when assertions are enabled) that the task does not use k-fold splitting. */
    private static void assertNotKFold(PredictionMLTask task) {
        assert (!task.splitParams.kfold);
    }

    /** Throws {@code ErrorContext.iae(message)} when {@code rows} is zero. */
    private static void requireRows(long rows, String message) {
        if (rows == 0L) {
            throw ErrorContext.iae(message);
        }
    }
}

