/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.prediction.ClassicalPredictionParamsExpander;
import com.dataiku.dip.analysis.ml.prediction.PredictionMLTaskHandlingStrategy;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.utils.ErrorContext;
import java.util.Collections;
import java.util.Map;

/**
 * {@link PredictionMLTaskHandlingStrategy} implementation whose answers are
 * derived from three properties of the task: whether k-fold is enabled in its
 * split params, whether it is partitioned, and its backend type.
 */
public class ClassicalPredictionMLTaskHandlingStrategy
implements PredictionMLTaskHandlingStrategy {

    /**
     * Selects the metric name for the given task.
     *
     * @param task the prediction task being handled
     * @return one of four fixed metric names, chosen from the combination of
     *         the task's k-fold flag and its partitioned status
     */
    @Override
    public String getDSSMetricName(PredictionMLTask task) {
        final boolean partitioned = task.isPartitioned();
        final boolean kfold = task.splitParams.kfold;
        if (kfold) {
            return partitioned
                    ? "dku.ml.predictionTrain.stratPyKFoldTrain"
                    : "dku.ml.predictionTrain.pyKFoldTrain";
        }
        return partitioned
                ? "dku.ml.predictionTrain.strat.pyRegularNoSaveTrain"
                : "dku.ml.predictionTrain.pyRegularNoSaveTrain";
    }

    /**
     * Validates the row counts of a split. Partitioned tasks and k-fold tasks
     * are not checked at all; for every other task both the train and the test
     * row counts must be non-zero.
     *
     * @param task      the prediction task the split belongs to
     * @param splitDesc the split whose {@code trainRows} / {@code testRows} are inspected
     * @throws RuntimeException as built by {@code ErrorContext.iae} when either
     *         count is zero (presumably an {@link IllegalArgumentException} -
     *         confirm against {@code ErrorContext})
     */
    @Override
    public void checkSplit(PredictionMLTask task, SplitDesc splitDesc) {
        // Neither partitioned nor k-fold tasks go through the emptiness checks.
        if (task.isPartitioned() || task.splitParams.kfold) {
            return;
        }
        // The train set is checked first, so its message wins when both are empty.
        if (splitDesc.trainRows == 0L) {
            throw ErrorContext.iae("Train set is empty. Please check your train & validation settings");
        }
        if (splitDesc.testRows == 0L) {
            throw ErrorContext.iae("Test set is empty. Please check your train & validation settings");
        }
    }

    /**
     * Selects the name of the Python training function for the given task.
     * K-fold takes precedence; the Keras function is only chosen for
     * non-partitioned tasks whose backend type is {@code KERAS}.
     *
     * @param task the prediction task being handled
     * @return the Python function name
     */
    @Override
    public String getPythonFunction(PredictionMLTask task) {
        if (task.splitParams.kfold) {
            return "train_prediction_kfold";
        }
        final boolean kerasOnUnpartitioned =
                !task.isPartitioned() && task.backendType == MLTask.BackendType.KERAS;
        return kerasOnUnpartitioned
                ? "train_prediction_keras"
                : "train_prediction_models_nosave";
    }

    /**
     * Delegates parameter expansion to a {@link ClassicalPredictionParamsExpander}.
     *
     * @param task      the task to expand; it is cast to
     *                  {@code PredictionMLTask.ClassicalPredictionMLTask}, so any
     *                  other runtime type fails with a {@link ClassCastException}
     * @param sessionId forwarded unchanged to the expander
     * @return the {@link WorkSet} produced by the expander
     * @throws Exception propagated from the expander
     */
    @Override
    public WorkSet expandPredictionParams(PredictionMLTask task, String sessionId) throws Exception {
        final PredictionMLTask.ClassicalPredictionMLTask classicalTask =
                (PredictionMLTask.ClassicalPredictionMLTask)task;
        final ClassicalPredictionParamsExpander expander =
                new ClassicalPredictionParamsExpander(classicalTask, sessionId);
        return expander.expand();
    }

    /**
     * Returns the number of threads to use for the task.
     *
     * @param task the prediction task being handled
     * @return the task's {@code maxConcurrentModelTraining}, raised to 1 when it is lower
     */
    @Override
    public int getThreadCountToRun(PredictionMLTask task) {
        final int configured = task.maxConcurrentModelTraining;
        return Math.max(configured, 1);
    }

    /**
     * Returns the environment variables forced by this strategy.
     *
     * @return always an empty map
     */
    @Override
    public Map<String, String> getForcedEnvVars() {
        return Collections.<String, String>emptyMap();
    }
}

