/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.guess;

import com.dataiku.dip.analysis.ml.prediction.CustomPythonPredictionAlgoService;
import com.dataiku.dip.analysis.ml.prediction.LoadedCustomPythonPredictionAlgoDesc;
import com.dataiku.dip.analysis.ml.prediction.guess.AnalyseUtils;
import com.dataiku.dip.analysis.ml.prediction.guess.TabularPredictionGuesser;
import com.dataiku.dip.analysis.ml.shared.FeatureGuessUtils;
import com.dataiku.dip.analysis.model.GuessStatus;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.analysis.model.prediction.MetricParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.TimeOrderingParams;
import com.dataiku.dip.analysis.model.prediction.WeightParams;
import com.dataiku.dip.analysis.model.prediction.assertions.MLAssertionsParams;
import com.dataiku.dip.analysis.model.prediction.overrides.MLOverridesParams;
import com.dataiku.dip.analysis.model.preprocessing.ClassicalPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.TextFeaturePreprocessingParams;
import com.dataiku.dip.code.CodeEnvSelection;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.datasets.StreamableDatasetSelection;
import com.dataiku.dip.plugins.IPluginsRegistryService;
import com.dataiku.dip.plugins.PluginConfigUtils;
import com.dataiku.dip.plugins.model.PluginDesc;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

/**
 * Guesser for classical tabular prediction tasks: binary classification, multiclass classification and regression.
 *
 * <p>On top of the generic tabular guessing it:
 * <ul>
 *   <li>infers the prediction type from the target column's cardinality;</li>
 *   <li>chooses default weighting, evaluation metrics and algorithms;</li>
 *   <li>keeps custom Python plugin algorithms and Keras text inputs in sync with the task settings.</li>
 * </ul>
 */
public abstract class ClassicalPredictionGuesser extends TabularPredictionGuesser<PredictionMLTask.ClassicalPredictionMLTask> {
    /** Two-class limit; in this class it is the target cardinality that characterises a binary target. */
    public static final int NB_LIMIT_TARGET_CLASS = 2;
    /** A numeric target must have strictly more distinct values than this to be guessed as a regression. */
    private static final int REGRESSION_CARDINALITY_THRESHOLD = 5;
    /** Threshold passed to {@code AnalyseUtils.getValuesPresentMoreThan} (presumably a minimum occurrence count) when counting candidate classes. */
    private static final int CLASS_OCCURRENCE_THRESHOLD = 2;
    /** Multiclass tasks with more remapped classes than this get a warning message. */
    private static final int MANY_CLASSES_WARNING_THRESHOLD = 50;
    private static final DKULogger logger = DKULogger.getLogger("dku.analysis");

    protected ClassicalPredictionGuesser(PredictionMLTask.ClassicalPredictionMLTask task, MemTable table) {
        super(task, table);
    }

    /**
     * Returns the guessed task with its classical-prediction type.
     * The field is re-read on every call, so a reassigned task is always picked up.
     */
    private PredictionMLTask.ClassicalPredictionMLTask currentClassicalTask() {
        return (PredictionMLTask.ClassicalPredictionMLTask) this.task;
    }

    /** Returns true when the task's prediction type is REGRESSION (false for null or any other type). */
    private boolean classicalTaskIsRegression() {
        return this.currentClassicalTask().predictionType == PredictionMLTask.PredictionType.REGRESSION;
    }

    /**
     * Guesses text preprocessing through the parent guesser. For Keras tasks, it then registers the input
     * that the feature is sent to.
     */
    @Override
    public TextFeaturePreprocessingParams guessText(MemColumn column) {
        final TextFeaturePreprocessingParams textParams = super.guessText(column);
        this.fixKerasInputsForTextFeature(textParams);
        return textParams;
    }

    /**
     * Reports the TARGET role for the column that the table resolves for the task's target variable.
     * The comparison is by reference. Any other column has no special role.
     */
    @Override
    protected Optional<FeaturePreprocessingParams.Role> getSpecialFeatureRole(MemColumn column) {
        final MemColumn target = this.table.getColumn(this.currentClassicalTask().targetVariable);
        return column == target ? Optional.of(FeaturePreprocessingParams.Role.TARGET) : Optional.empty();
    }

    /**
     * For Keras-backed tasks, adds the feature's {@code sendToInput} to the Keras modeling inputs if it is
     * not already listed. This does nothing for other backends.
     */
    private void fixKerasInputsForTextFeature(FeaturePreprocessingParams preprocessingParams) {
        if (this.currentClassicalTask().backendType == MLTask.BackendType.KERAS
                && !this.currentClassicalTask().modeling.keras.kerasInputs.contains(preprocessingParams.sendToInput)) {
            this.currentClassicalTask().modeling.keras.kerasInputs.add(preprocessingParams.sendToInput);
        }
    }

    /**
     * Clears the sample-weight column reference when that column is missing from the table, then delegates
     * to the parent fix-up.
     */
    @Override
    public void fixupAfterDuplication() {
        if (StringUtils.isNotBlank(this.currentClassicalTask().weight.sampleWeightVariable)
                && !this.table.columns.containsKey(this.currentClassicalTask().weight.sampleWeightVariable)) {
            this.currentClassicalTask().weight.sampleWeightVariable = null;
        }
        super.fixupAfterDuplication();
    }

    /**
     * Guesses the preprocessing of one column.
     *
     * <p>Columns with a special role (see {@link #getSpecialFeatureRole}) go through
     * {@code guessSpecialFeature}. All other columns use {@code FeatureGuessUtils.guessSingleFeature}, and text
     * features are also registered as Keras inputs when relevant.
     */
    @Override
    public FeaturePreprocessingParams guessSingleFeature(MemColumn column) {
        final Optional<FeaturePreprocessingParams.Role> role = this.getSpecialFeatureRole(column);
        if (role.isPresent()) {
            return this.guessSpecialFeature(FeatureGuessUtils.isNumerical(column), role.get());
        }
        final FeaturePreprocessingParams guessed = FeatureGuessUtils.guessSingleFeature(this.table, column, this.task);
        if (guessed.type == FeaturePreprocessingParams.FeatureType.TEXT) {
            this.fixKerasInputsForTextFeature(guessed);
        }
        return guessed;
    }

    /**
     * Infers the prediction type from the target column and stores it on the task.
     *
     * <ul>
     *   <li>Fewer than 2 distinct values: IllegalArgumentException.</li>
     *   <li>Exactly 2 distinct values: binary classification.</li>
     *   <li>All-numeric with more than 5 distinct values: regression.</li>
     *   <li>Otherwise: binary classification if {@code getValuesPresentMoreThan(..., 2)} returns exactly two
     *       values, multiclass otherwise.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the column or its selected type is null, or if the target is constant
     */
    @Override
    protected void guessPredictionType(MemColumn column) {
        if (column == null || column.selectedType == null) {
            throw new IllegalArgumentException("Invalid target variable column. Cannot guess prediction type.");
        }
        final boolean allNumeric = FeatureGuessUtils.isNumerical(column);
        final int distinctValues = AnalyseUtils.getValueFreqs(this.table, column).size();
        logger.infoV("Guess prediction type: allNumeric=%s cardinality=%s", new Object[]{allNumeric, distinctValues});
        if (distinctValues < NB_LIMIT_TARGET_CLASS) {
            throw new IllegalArgumentException("All values of the target are equal. Try refreshing the sample.");
        }

        final PredictionMLTask.PredictionType guessed;
        if (distinctValues == NB_LIMIT_TARGET_CLASS) {
            guessed = PredictionMLTask.PredictionType.BINARY_CLASSIFICATION;
        } else if (allNumeric && distinctValues > REGRESSION_CARDINALITY_THRESHOLD) {
            guessed = PredictionMLTask.PredictionType.REGRESSION;
        } else {
            final int frequentClasses = AnalyseUtils.getValuesPresentMoreThan(this.table, column, CLASS_OCCURRENCE_THRESHOLD).size();
            logger.infoV("Guess prediction type: classes=%s", new Object[]{frequentClasses});
            guessed = frequentClasses == NB_LIMIT_TARGET_CLASS
                    ? PredictionMLTask.PredictionType.BINARY_CLASSIFICATION
                    : PredictionMLTask.PredictionType.MULTICLASS;
        }
        this.currentClassicalTask().setPredictionType(guessed);
    }

    /**
     * Runs the parent checks, then checks whether the target column's values are compatible with the chosen
     * prediction type.
     *
     * <p>Cardinality mismatches for binary and multiclass only add messages. Non-numerical values with
     * regression go through {@code throwOrAddMessage}.
     */
    @Override
    protected void checkTargetColumn(boolean throwException) {
        super.checkTargetColumn(throwException);
        final MemColumn target = this.table.columns.get(this.currentClassicalTask().targetVariable);
        if (target == null || target.selectedType == null) {
            return;
        }
        final boolean hasNonNumericalValues = !FeatureGuessUtils.isNumerical(target);
        final int cardinality = AnalyseUtils.getValueFreqs(this.table, target).size();
        final PredictionMLTask.PredictionType type = this.currentClassicalTask().predictionType;

        if (type == PredictionMLTask.PredictionType.BINARY_CLASSIFICATION && cardinality > NB_LIMIT_TARGET_CLASS) {
            this.messages.add(String.format("Target column cardinality is strictly greater than 2 (cardinality=%s), incompatible with binary classification models", cardinality));
        }
        if (type == PredictionMLTask.PredictionType.MULTICLASS && cardinality == NB_LIMIT_TARGET_CLASS) {
            this.messages.add("Target column is binary (cardinality=2) in the sample, which is incompatible with multiclass models. Make sure the complete dataset contains more than two classes.");
        }
        if (type == PredictionMLTask.PredictionType.REGRESSION && hasNonNumericalValues) {
            this.throwOrAddMessage("Target column contains non-numerical values, incompatible with regression models", throwException);
        }
    }

    /**
     * Re-guesses every setting for the task's current prediction type, in this order:
     * <ol>
     *   <li>weighting, split, time ordering, assertions and overrides;</li>
     *   <li>algorithms and metrics;</li>
     *   <li>per-column preprocessing;</li>
     *   <li>target remapping.</li>
     * </ol>
     */
    @Override
    protected void guessAllSettingsWithFixedPredictionType(boolean throwException) {
        this.currentClassicalTask().weight.weightMethod = this.guessWeightMethod(this.currentClassicalTask());
        this.currentClassicalTask().splitParams = SplitParams.buildStd();
        final MLTask.BackendType backend = this.currentClassicalTask().backendType;
        if (backend == MLTask.BackendType.MLLIB || backend == MLTask.BackendType.H2O) {
            // MLlib and H2O tasks use a full streamable-dataset selection for the split.
            this.currentClassicalTask().splitParams.ssdSelection = StreamableDatasetSelection.full();
        }
        this.currentClassicalTask().time = new TimeOrderingParams();
        this.currentClassicalTask().assertionsParams = new MLAssertionsParams();
        this.currentClassicalTask().overridesParams = new MLOverridesParams();
        this.checkAllFixableSettings(throwException);

        logger.info((Object) "Guessing algorithms");
        this.guessAlgorithmRelatedFields(false);
        this.setNIterRandom();
        this.guessMetrics();

        // Preprocessing is rebuilt from scratch, then filled with one guess per table column.
        final ClassicalPredictionPreprocessingParams preprocessing = new ClassicalPredictionPreprocessingParams();
        preprocessing.per_feature = new HashMap<>();
        this.currentClassicalTask().preprocessing = preprocessing;
        for (String columnName : this.table.columns.keySet()) {
            this.guessFeature(columnName);
        }

        this.setTargetRemapping(this.classicalTaskIsRegression(), throwException);
    }

    /**
     * Applies a target change without a full re-guess. It does the following:
     * <ul>
     *   <li>resets assertions and overrides;</li>
     *   <li>recomputes the target remapping;</li>
     *   <li>re-guesses the weighting method when sample weighting is enabled on the column that is now the target.</li>
     * </ul>
     */
    @Override
    public void changeTargetNoReguess(String previousTarget, @Nullable GuessStatus previousGuessStatus) {
        super.changeTargetNoReguess(previousTarget, previousGuessStatus);
        this.currentClassicalTask().assertionsParams = new MLAssertionsParams();
        this.currentClassicalTask().overridesParams = new MLOverridesParams();
        this.setTargetRemapping(this.classicalTaskIsRegression(), true);

        final boolean weightColumnIsTarget = this.currentClassicalTask().weight.isSampleWeightEnabled()
                && StringUtils.isNotBlank(this.currentClassicalTask().weight.sampleWeightVariable)
                && this.currentClassicalTask().weight.sampleWeightVariable.equals(this.currentClassicalTask().targetVariable);
        if (weightColumnIsTarget) {
            this.currentClassicalTask().weight.weightMethod = this.guessWeightMethod(this.currentClassicalTask());
        }
    }

    /**
     * Applies a prediction-type change without a full re-guess.
     *
     * <p>When the type's category changes:
     * <ul>
     *   <li>class-based weighting is re-guessed;</li>
     *   <li>for non-Keras backends, algorithms are re-guessed while keeping existing parameters.</li>
     * </ul>
     * In every case, metrics, target remapping and algorithm fix-ups are refreshed.
     */
    @Override
    public void changeTypeNoReguess(PredictionMLTask.PredictionType previousPredictionType, @Nullable GuessStatus previousGuessStatus) {
        super.changeTypeNoReguess(previousPredictionType, previousGuessStatus);
        this.currentClassicalTask().assertionsParams = new MLAssertionsParams();
        this.currentClassicalTask().overridesParams = new MLOverridesParams();

        final boolean categoryChanged = !this.currentClassicalTask().predictionType.category.equals(previousPredictionType.category);
        final boolean classWeighted = EnumSet.of(WeightParams.WeightMethod.CLASS_WEIGHT, WeightParams.WeightMethod.CLASS_AND_SAMPLE_WEIGHT)
                .contains(this.currentClassicalTask().weight.weightMethod);
        if (categoryChanged) {
            if (classWeighted) {
                this.currentClassicalTask().weight.weightMethod = this.guessWeightMethod(this.currentClassicalTask());
            }
            if (this.currentClassicalTask().backendType != MLTask.BackendType.KERAS) {
                logger.info((Object) "Guessing algorithms");
                this.currentClassicalTask().modeling = PredictionModelingParams.fromPredictionModelingParamsNoEnabledAlgos(
                        this.currentClassicalTask().modeling, this.currentClassicalTask().predictionType);
                this.guessAlgorithmRelatedFields(true);
            }
        }

        this.guessMetrics();
        this.setTargetRemapping(this.classicalTaskIsRegression(), true);
        this.fixUpAlgoDifferentPredType();
    }

    /**
     * Replaces the modeling params with freshly guessed algorithms and refreshes plugin algorithms.
     * Plugin algorithms are re-detected from scratch only when {@code keepExistingParams} is false.
     * Grid search is always enabled.
     */
    private void guessAlgorithmRelatedFields(boolean keepExistingParams) {
        this.currentClassicalTask().modeling = this.guessAlgorithms(this.table, this.currentClassicalTask(), keepExistingParams);
        this.addPluginAlgorithms(!keepExistingParams);
        this.currentClassicalTask().modeling.grid_search = true;
    }

    /**
     * Adjusts algorithm settings that depend on the prediction type, then refreshes plugin algorithms
     * without re-detection:
     * <ul>
     *   <li>binary + MLlib disables MLlib naive Bayes;</li>
     *   <li>binary + in-memory Python forces one-vs-rest logistic regression;</li>
     *   <li>multiclass disables MLlib GBT.</li>
     * </ul>
     */
    public void fixUpAlgoDifferentPredType() {
        final PredictionMLTask.PredictionType type = this.currentClassicalTask().predictionType;
        if (type == PredictionMLTask.PredictionType.BINARY_CLASSIFICATION) {
            final MLTask.BackendType backend = this.currentClassicalTask().backendType;
            if (backend == MLTask.BackendType.MLLIB) {
                this.currentClassicalTask().modeling.mllib_naive_bayes.enabled = false;
            } else if (backend == MLTask.BackendType.PY_MEMORY) {
                this.currentClassicalTask().modeling.logistic_regression.multi_class = PredictionModelingParams.LogisticRegressionClassifierMultiClass.OVR;
            }
        } else if (type == PredictionMLTask.PredictionType.MULTICLASS) {
            this.currentClassicalTask().modeling.mllib_gbt.enabled = false;
        }
        this.addPluginAlgorithms(false);
    }

    /**
     * Adds warnings in two cases, then returns the parent status:
     * <ul>
     *   <li>the multiclass target remapping has more than 50 classes;</li>
     *   <li>a Keras task does not use an explicit code environment.</li>
     * </ul>
     */
    @Override
    public GuessStatus checkStatus() {
        if (this.currentClassicalTask().predictionType == PredictionMLTask.PredictionType.MULTICLASS) {
            final int classCount = this.currentClassicalTask().getPreprocessingParams().target_remapping.size();
            if (classCount > MANY_CLASSES_WARNING_THRESHOLD) {
                this.messages.add("A large number  (" + classCount + ") of classes has been detected. Training may fail, or performance may be very poor.");
            }
        }
        if (this.currentClassicalTask().backendType == MLTask.BackendType.KERAS
                && this.currentClassicalTask().envSelection.envMode != CodeEnvSelection.EnvMode.EXPLICIT_ENV) {
            this.messages.add("No code environment suitable to run deep learning models available. Please ask your administrator to create one and/or grant you access to it.");
        }
        return super.checkStatus();
    }

    /**
     * Sets the default evaluation metrics for the prediction type:
     * <ul>
     *   <li>binary: ROC AUC, with F1 for threshold optimization;</li>
     *   <li>multiclass: ROC AUC with macro class averaging;</li>
     *   <li>regression: R2.</li>
     * </ul>
     *
     * @throws IllegalArgumentException for any other prediction type
     */
    private void guessMetrics() {
        final PredictionMLTask.ClassicalPredictionMLTask classicalTask = this.currentClassicalTask();
        switch (classicalTask.predictionType) {
            case BINARY_CLASSIFICATION:
                classicalTask.modeling.metrics.evaluationMetric = MetricParams.EvaluationMetric.ROC_AUC;
                classicalTask.modeling.metrics.thresholdOptimizationMetric = MetricParams.ThresholdOptimizationMetric.F1;
                break;
            case MULTICLASS:
                classicalTask.modeling.metrics.evaluationMetric = MetricParams.EvaluationMetric.ROC_AUC;
                classicalTask.modeling.metrics.classAveragingMethod = MetricParams.ClassAveragingMethod.MACRO;
                break;
            case REGRESSION:
                classicalTask.modeling.metrics.evaluationMetric = MetricParams.EvaluationMetric.R2;
                break;
            default:
                throw new IllegalArgumentException("Unsupported prediction type: " + String.valueOf(classicalTask.predictionType));
        }
    }

    /** Refreshes plugin algorithms. Only the in-memory Python backend supports them. */
    private void addPluginAlgorithms(boolean redetect) {
        if (this.currentClassicalTask().backendType == MLTask.BackendType.PY_MEMORY) {
            this.addPythonPluginAlgorithms(redetect);
        }
    }

    /**
     * Synchronises {@code modeling.plugin_python} with the custom Python algorithms of the loaded plugins.
     *
     * @param redetect when true, existing entries are discarded. When false, entries are removed if their
     *                 algorithm no longer exists or no longer supports the task's prediction type. In both
     *                 cases, missing compatible algorithms are then added.
     */
    private void addPythonPluginAlgorithms(boolean redetect) {
        final CustomPythonPredictionAlgoService algoService = (CustomPythonPredictionAlgoService) SpringUtils.getBean(CustomPythonPredictionAlgoService.class);
        if (redetect) {
            this.currentClassicalTask().modeling.plugin_python = new HashMap<>();
        } else {
            // Collect first and remove afterwards, so the map is left untouched if a lookup throws partway.
            final ArrayList<String> unusableAlgoIds = new ArrayList<>();
            for (Map.Entry<String, PredictionModelingParams.CustomPythonPluginParams> entry : this.currentClassicalTask().modeling.plugin_python.entrySet()) {
                final PredictionModelingParams.CustomPythonPluginParams configured = entry.getValue();
                final boolean stillUsable = algoService.exists(configured.pluginId, configured.elementId)
                        && algoService.get(configured.pluginId, configured.elementId).desc.predictionTypes.contains(this.currentClassicalTask().predictionType);
                if (!stillUsable) {
                    unusableAlgoIds.add(entry.getKey());
                }
            }
            this.currentClassicalTask().modeling.plugin_python.keySet().removeAll(unusableAlgoIds);
        }

        for (LoadedCustomPythonPredictionAlgoDesc pluginAlgo : algoService.listAlgosFromLoadedPlugins()) {
            final String customModelId = pluginAlgo.getType();
            final boolean alreadyConfigured = this.currentClassicalTask().modeling.plugin_python.containsKey(customModelId);
            if (!alreadyConfigured && pluginAlgo.desc.predictionTypes.contains(this.currentClassicalTask().predictionType)) {
                this.registerPythonPluginAlgo(customModelId, pluginAlgo);
            }
        }
    }

    /**
     * Adds one plugin algorithm to {@code modeling.plugin_python}, filling its params through
     * {@code PluginConfigUtils.setDefaultValues}. Any failure is logged and reported as a guess message
     * instead of being propagated.
     */
    private void registerPythonPluginAlgo(String customModelId, LoadedCustomPythonPredictionAlgoDesc pluginAlgo) {
        try {
            final PluginDesc pluginDesc = ((IPluginsRegistryService) SpringUtils.getBean(IPluginsRegistryService.class)).getDesc(pluginAlgo.ownerPluginId);
            final PredictionModelingParams.CustomPythonPluginParams algoParams = new PredictionModelingParams.CustomPythonPluginParams(pluginDesc, pluginAlgo);
            algoParams.params = PluginConfigUtils.setDefaultValues(pluginAlgo.desc.params, algoParams.params);
            this.currentClassicalTask().modeling.plugin_python.put(customModelId, algoParams);
        } catch (Exception e) {
            logger.warn((Object) ("Failed to load python plugin " + pluginAlgo.ownerPluginId), (Throwable) e);
            this.messages.add("Failed to load python plugin " + pluginAlgo.ownerPluginId + ". Please check the plugin config");
        }
    }

    /** Runs the parent update, then refreshes plugin algorithms (without re-detection) when modeling params exist. */
    @Override
    public void updateGuess(@Nullable GuessStatus previousGuessStatus) {
        super.updateGuess(previousGuessStatus);
        if (this.currentClassicalTask().modeling != null) {
            this.addPluginAlgorithms(false);
        }
    }

    /**
     * Chooses the default weighting method:
     * <ul>
     *   <li>non in-memory-Python backends and regressions: no weighting;</li>
     *   <li>binary and multiclass on in-memory Python: class weights.</li>
     * </ul>
     *
     * @throws IllegalArgumentException for any other prediction type on the in-memory Python backend
     */
    private WeightParams.WeightMethod guessWeightMethod(PredictionMLTask predictionTask) {
        if (predictionTask.backendType != MLTask.BackendType.PY_MEMORY) {
            return WeightParams.WeightMethod.NO_WEIGHTING;
        }
        switch (predictionTask.predictionType) {
            case BINARY_CLASSIFICATION:
            case MULTICLASS:
                return WeightParams.WeightMethod.CLASS_WEIGHT;
            case REGRESSION:
                return WeightParams.WeightMethod.NO_WEIGHTING;
            default:
                throw new IllegalArgumentException("Prediction type should be either REGRESSION, BINARY_CLASSIFICATION or MULTICLASS");
        }
    }
}

