/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.guess;

import com.dataiku.dip.analysis.ml.prediction.guess.AnalyseUtils;
import com.dataiku.dip.analysis.ml.prediction.guess.OutcomeStats;
import com.dataiku.dip.analysis.ml.prediction.guess.TabularPredictionGuesser;
import com.dataiku.dip.analysis.ml.shared.FeatureGuessUtils;
import com.dataiku.dip.analysis.model.GuessStatus;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.analysis.model.prediction.MetricParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.preprocessing.CausalPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.PredictionPreprocessingParams;
import com.dataiku.dip.code.CodeEnvSelection;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.utils.DKULogger;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

/**
 * Settings guesser for causal prediction ML tasks.
 *
 * <p>On top of the checks and guesses inherited from {@link TabularPredictionGuesser}, this class
 * validates the treatment variable, guesses the causal prediction type from the outcome column,
 * initialises the modeling parameters and guesses the control value and the positive class.
 */
public class CausalPredictionGuesser
extends TabularPredictionGuesser<PredictionMLTask.CausalPredictionMLTask> {
    private static final DKULogger logger = DKULogger.getLogger("dku.analysis");

    /**
     * @param task the causal prediction task whose settings are checked and guessed
     * @param table the in-memory table the guesses are computed from
     */
    public CausalPredictionGuesser(PredictionMLTask.CausalPredictionMLTask task, MemTable table) {
        super(task, table);
    }

    /** Returns the task of this guesser typed as a causal prediction task. */
    private PredictionMLTask.CausalPredictionMLTask causalTask() {
        return (PredictionMLTask.CausalPredictionMLTask)this.task;
    }

    /**
     * Runs the inherited target checks, then verifies that the outcome column is compatible with
     * the task's prediction type: numerical values for causal regression, at most two distinct
     * values for causal binary classification.
     *
     * @param throwException forwarded to {@code throwOrAddMessage} for every detected problem
     */
    @Override
    protected void checkTargetColumn(boolean throwException) {
        super.checkTargetColumn(throwException);
        PredictionMLTask.CausalPredictionMLTask causalTask = this.causalTask();
        MemColumn column = this.table.columns.get(causalTask.targetVariable);
        if (column == null || column.selectedType == null) {
            return;
        }
        boolean hasNonNumericalValues = !FeatureGuessUtils.isNumerical(column);
        int cardinality = AnalyseUtils.getValueFreqs(this.table, column).size();
        if (PredictionMLTask.PredictionType.CAUSAL_REGRESSION.equals(causalTask.predictionType) && hasNonNumericalValues) {
            this.throwOrAddMessage("Outcome column contains non-numerical values, incompatible with causal regression models", throwException);
        }
        if (PredictionMLTask.PredictionType.CAUSAL_BINARY_CLASSIFICATION.equals(causalTask.predictionType) && cardinality > 2) {
            this.throwOrAddMessage(String.format("Outcome column cardinality is greater than 2 (cardinality=%s), incompatible with causal binary classification models", cardinality), throwException);
        }
    }

    /**
     * Verifies that a treatment variable is set and that the table has a column with that name.
     *
     * @param throwException forwarded to {@code throwOrAddMessage} for every detected problem
     */
    private void checkTreatmentColumn(boolean throwException) {
        String treatmentVariable = this.causalTask().treatmentVariable;
        if (StringUtils.isBlank(treatmentVariable)) {
            this.throwOrAddMessage("No treatment variable", throwException);
            // Nothing to look up: avoids a second, misleading "column 'null'" message and a
            // lookup with a null key.
            return;
        }
        if (!this.table.columns.containsKey(treatmentVariable)) {
            this.throwOrAddMessage("Dataset does not contain the treatment variable column '" + treatmentVariable + "'", throwException);
        }
    }

    /** Adds a message when the outcome variable and the treatment variable are the same column. */
    private void checkNoSpecialRoleCollision() {
        PredictionMLTask.CausalPredictionMLTask causalTask = this.causalTask();
        String targetVariable = causalTask.targetVariable;
        // A missing target is not a collision; it must not make this check fail with an NPE.
        if (targetVariable != null && targetVariable.equals(causalTask.treatmentVariable)) {
            this.messages.add("Outcome variable cannot be the treatment variable '" + causalTask.treatmentVariable + "'");
        }
    }

    /**
     * Runs the inherited fixable-settings checks, then the treatment column check and the
     * outcome/treatment collision check.
     *
     * @param throwException forwarded to the inherited checks and to the treatment column check
     */
    @Override
    protected void checkAllFixableSettings(boolean throwException) {
        super.checkAllFixableSettings(throwException);
        this.checkTreatmentColumn(throwException);
        this.checkNoSpecialRoleCollision();
    }

    /**
     * Guesses the causal prediction type from the outcome column and sets it on the task:
     * binary classification when the column has exactly two distinct values, regression when it
     * has more and is numerical.
     *
     * @param column the outcome column
     * @throws IllegalArgumentException if the column or its selected type is null, or if the
     *         column has fewer than two distinct values
     * @throws IllegalStateException if the column has more than two distinct values and is not
     *         numerical
     */
    @Override
    protected void guessPredictionType(MemColumn column) {
        if (column == null || column.selectedType == null) {
            throw new IllegalArgumentException("Invalid outcome variable column. Cannot guess prediction type.");
        }
        boolean isAllNumeric = FeatureGuessUtils.isNumerical(column);
        int cardinality = AnalyseUtils.getValueFreqs(this.table, column).size();
        logger.infoV("Guess prediction type: allNumeric=%s cardinality=%s", new Object[]{isAllNumeric, cardinality});
        if (cardinality < 2) {
            throw new IllegalArgumentException("All values of the outcome are equal. Try refreshing the sample.");
        }
        PredictionMLTask.PredictionType guessedType;
        if (cardinality == 2) {
            guessedType = PredictionMLTask.PredictionType.CAUSAL_BINARY_CLASSIFICATION;
        } else if (isAllNumeric) {
            guessedType = PredictionMLTask.PredictionType.CAUSAL_REGRESSION;
        } else {
            throw new IllegalStateException("Multiple non-numerical values of the outcome detected. Causal prediction does not support multiclass classification.");
        }
        this.causalTask().setPredictionType(guessedType);
    }

    /**
     * Guesses the preprocessing of one column: columns holding a special role (outcome or
     * treatment) go through {@code guessSpecialFeature}, all others through
     * {@link FeatureGuessUtils#guessSingleFeature}.
     *
     * @param column the column to guess preprocessing for
     * @return the guessed preprocessing parameters
     */
    @Override
    public FeaturePreprocessingParams guessSingleFeature(MemColumn column) {
        Optional<FeaturePreprocessingParams.Role> specialRole = this.getSpecialFeatureRole(column);
        if (specialRole.isPresent()) {
            return this.guessSpecialFeature(FeatureGuessUtils.isNumerical(column), specialRole.get());
        }
        return FeatureGuessUtils.guessSingleFeature(this.table, column, this.task);
    }

    /**
     * Returns {@code TARGET} when {@code column} is the task's outcome column, {@code TREATMENT}
     * when it is the treatment column, and an empty optional otherwise. The outcome role takes
     * precedence. Columns are compared by identity.
     */
    @Override
    protected Optional<FeaturePreprocessingParams.Role> getSpecialFeatureRole(MemColumn column) {
        PredictionMLTask.CausalPredictionMLTask causalTask = this.causalTask();
        if (column == this.table.getColumn(causalTask.targetVariable)) {
            return Optional.of(FeaturePreprocessingParams.Role.TARGET);
        }
        if (column == this.table.getColumn(causalTask.treatmentVariable)) {
            return Optional.of(FeaturePreprocessingParams.Role.TREATMENT);
        }
        return Optional.empty();
    }

    /**
     * Provides the modeling parameters that {@link #guessAlgorithms} then tunes.
     *
     * <p>When {@code keepExistingParams} is false, a new parameters object is built from the
     * task's prediction type and current modeling, with the T-learner as only meta-learner.
     * Otherwise the task's current modeling object is reused, and its causal forest, random
     * forest and linear model hyperparameter spaces for the task's prediction type are replaced
     * by new instances.
     *
     * @throws IllegalArgumentException if existing parameters are kept and the prediction type
     *         is neither causal binary classification nor causal regression
     */
    private PredictionModelingParams initAlgorithmsParams(PredictionMLTask.CausalPredictionMLTask task, boolean keepExistingParams) {
        if (!keepExistingParams) {
            PredictionModelingParams fresh = new PredictionModelingParams(task.predictionType, task.modeling);
            fresh.meta_learners = Collections.singleton(PredictionModelingParams.CausalMetaLearner.T_LEARNER);
            return fresh;
        }
        PredictionModelingParams params = task.modeling;
        params.causal_forest = new PredictionModelingParams.CausalForestHyperparameterSpace();
        switch (task.predictionType) {
            case CAUSAL_BINARY_CLASSIFICATION: {
                params.random_forest_classification = new PredictionModelingParams.RandomForestHyperparametersSpace();
                params.logistic_regression = new PredictionModelingParams.LogisticRegressionHyperparametersSpace();
                break;
            }
            case CAUSAL_REGRESSION: {
                params.random_forest_regression = new PredictionModelingParams.RandomForestHyperparametersSpace();
                params.ridge_regression = new PredictionModelingParams.RidgeRegressionHyperparametersSpace();
                break;
            }
            default: {
                throw new IllegalArgumentException("Unsupported prediction type:" + task.predictionType);
            }
        }
        return params;
    }

    /**
     * Enables and configures the default algorithms for the task's prediction type: the causal
     * forest, a random forest and a linear model (ridge regression for regression, logistic
     * regression for classification). Early stopping is switched off for XGBoost and for the
     * LightGBM space matching the prediction type.
     *
     * @param table unused by this implementation
     * @param task the task whose prediction type and modeling drive the guess
     * @param keepExistingParams whether the task's current modeling object is reused
     * @return the configured modeling parameters
     * @throws IllegalArgumentException if the prediction type is not a causal prediction type
     */
    @Override
    protected PredictionModelingParams guessAlgorithms(MemTable table, PredictionMLTask.CausalPredictionMLTask task, boolean keepExistingParams) {
        PredictionModelingParams params = this.initAlgorithmsParams(task, keepExistingParams);
        PredictionModelingParams.XGBoostHyperparametersSpace xgboost = params.xgboost;
        params.causal_forest.enabled = true;
        // Both locals are assigned in every non-throwing branch and used after the switch, so
        // they have to be declared in the enclosing scope.
        PredictionModelingParams.RandomForestHyperparametersSpace rf;
        PredictionModelingParams.LightGBMHyperParametersSpace lightgbm;
        switch (task.predictionType) {
            case CAUSAL_REGRESSION: {
                rf = params.random_forest_regression;
                rf.enabled = true;
                rf.selection_mode = PredictionModelingParams.TreeSelectionMode.PROP;
                rf.max_feature_prop.setToSingleValueGrid(1.0);
                PredictionModelingParams.RidgeRegressionHyperparametersSpace linear = params.ridge_regression;
                linear.enabled = true;
                linear.alphaMode = PredictionModelingParams.RidgeSelectAlphaMode.AUTO;
                linear.alpha.updateValues(0.1, 1.0);
                lightgbm = params.lightgbm_regression;
                break;
            }
            case CAUSAL_BINARY_CLASSIFICATION: {
                rf = params.random_forest_classification;
                rf.enabled = true;
                rf.selection_mode = PredictionModelingParams.TreeSelectionMode.SQRT;
                PredictionModelingParams.LogisticRegressionHyperparametersSpace lr = params.logistic_regression;
                lr.enabled = true;
                lr.penalty.withValue("l1", true).withValue("l2", false);
                lightgbm = params.lightgbm_classification;
                break;
            }
            default: {
                throw new IllegalArgumentException("Unsupported prediction type:" + task.predictionType);
            }
        }
        rf.max_tree_depth.updateValues(4L, 7L);
        rf.min_samples_leaf.updateValues(10L, 25L);
        xgboost.enable_early_stopping = false;
        lightgbm.early_stopping = false;
        return params;
    }

    /**
     * Recomputes the target remapping, then resets the positive class for causal regression or
     * guesses it for any other prediction type.
     *
     * @param throwException forwarded to {@code setTargetRemapping}
     */
    private void refreshTargetRemappingAndPositiveClass(boolean throwException) {
        boolean isRegression = PredictionMLTask.PredictionType.CAUSAL_REGRESSION.equals(this.causalTask().predictionType);
        this.setTargetRemapping(isRegression, throwException);
        if (isRegression) {
            this.causalTask().positiveClass = null;
        } else {
            this.guessPositiveClassValue();
        }
    }

    /**
     * Applies a prediction type change without a full re-guess: after the inherited handling,
     * the algorithms are guessed again on the existing modeling object, grid search is turned
     * on, and the target remapping and positive class are refreshed.
     */
    @Override
    public void changeTypeNoReguess(PredictionMLTask.PredictionType previousPredictionType, @Nullable GuessStatus previousGuessStatus) {
        super.changeTypeNoReguess(previousPredictionType, previousGuessStatus);
        logger.info((Object)"Guessing algorithms");
        PredictionMLTask.CausalPredictionMLTask causalTask = this.causalTask();
        causalTask.modeling = PredictionModelingParams.fromPredictionModelingParamsNoEnabledAlgos(causalTask.modeling, causalTask.predictionType);
        causalTask.modeling = this.guessAlgorithms(this.table, causalTask, true);
        causalTask.modeling.grid_search = true;
        this.refreshTargetRemappingAndPositiveClass(true);
    }

    /**
     * Applies an outcome variable change without a full re-guess: after the inherited handling,
     * the target remapping and positive class are refreshed.
     */
    @Override
    public void changeTargetNoReguess(String previousTarget, @Nullable GuessStatus previousGuessStatus) {
        super.changeTargetNoReguess(previousTarget, previousGuessStatus);
        this.refreshTargetRemappingAndPositiveClass(true);
    }

    /**
     * Guesses every setting of the task for its current prediction type: split parameters,
     * control value, algorithms, metrics (AUUC as evaluation metric), per-feature preprocessing
     * for all table columns, target remapping and positive class.
     *
     * @param throwException forwarded to the fixable-settings checks and to the target remapping
     */
    @Override
    protected void guessAllSettingsWithFixedPredictionType(boolean throwException) {
        PredictionMLTask.CausalPredictionMLTask causalTask = this.causalTask();
        causalTask.splitParams = SplitParams.buildStd();
        // NOTE(review): the control value is guessed from the treatment column before the
        // treatment column is checked below - confirm this ordering is intended.
        this.guessControlValue();
        this.checkAllFixableSettings(throwException);
        logger.info((Object)"Guessing causal prediction");
        causalTask.modeling = this.guessAlgorithms(this.table, causalTask, false);
        this.setNIterRandom();
        causalTask.modeling.metrics = new MetricParams();
        causalTask.modeling.metrics.evaluationMetric = MetricParams.EvaluationMetric.AUUC;
        causalTask.preprocessing = new CausalPredictionPreprocessingParams();
        causalTask.preprocessing.per_feature = new HashMap<>();
        for (String name : this.table.columns.keySet()) {
            this.guessFeature(name);
        }
        this.refreshTargetRemappingAndPositiveClass(throwException);
    }

    /**
     * Applies a treatment variable change without a full re-guess: the new treatment column is
     * checked (throwing on failure), the previous treatment column is guessed again as a
     * feature, and the new treatment column gets the treatment role and a guessed control value.
     *
     * @param previousTreatment name of the column that was the treatment variable before
     * @param previousGuessStatus previous guess status, may be null
     */
    public void changeTreatmentNoReguess(String previousTreatment, @Nullable GuessStatus previousGuessStatus) {
        this.checkTreatmentColumn(true);
        this.retrievePreviousGuessStatusBooleans(previousGuessStatus);
        this.checkAllFixableSettings(false);
        this.guessFeature(previousTreatment);
        this.changeTreatmentNoReguess();
    }

    /** Gives the treatment role to the current treatment column and guesses the control value. */
    private void changeTreatmentNoReguess() {
        PredictionMLTask.CausalPredictionMLTask causalTask = this.causalTask();
        FeaturePreprocessingParams treatmentParams = (FeaturePreprocessingParams)causalTask.preprocessing.per_feature.get(causalTask.treatmentVariable);
        treatmentParams.role = FeaturePreprocessingParams.Role.TREATMENT;
        this.guessControlValue();
    }

    /**
     * Computes outcome statistics for the task's prediction type.
     *
     * @param controlValue forwarded to the statistics computation
     * @param dropMissingTreatmentValues forwarded to the statistics computation
     * @return the outcome statistics, or {@code null} (after logging an error) when the
     *         prediction type is neither causal binary classification nor causal regression
     */
    public OutcomeStats getOutcomeStats(String controlValue, boolean dropMissingTreatmentValues) {
        PredictionMLTask.CausalPredictionMLTask causalTask = this.causalTask();
        PredictionMLTask.PredictionType predictionType = causalTask.predictionType;
        if (predictionType == PredictionMLTask.PredictionType.CAUSAL_BINARY_CLASSIFICATION) {
            return OutcomeStats.BinaryClassification.getOutcomeStatsPerClass(this.table, causalTask, controlValue, dropMissingTreatmentValues);
        }
        if (predictionType == PredictionMLTask.PredictionType.CAUSAL_REGRESSION) {
            return OutcomeStats.Regression.getOutcomeStatsByInterquartileInterval(this.table, causalTask, controlValue, dropMissingTreatmentValues);
        }
        // Plain concatenation so that a null prediction type is logged instead of raising an NPE.
        logger.error((Object)("Unknown predictionType: " + predictionType));
        return null;
    }

    /** Tells whether a raw outcome value reads like a positive class label. */
    private static boolean looksLikePositiveLabel(String rawValue) {
        if (rawValue == null) {
            return false;
        }
        String normalized = rawValue.toLowerCase(Locale.ROOT).trim();
        return normalized.equals("1") || normalized.equals("1.") || normalized.equals("1.0")
                || normalized.equals("true") || normalized.equals("yes")
                || normalized.startsWith("pos") || normalized.equals("+");
    }

    /**
     * Sets the positive class to the first target remapping entry whose source value reads like
     * a positive label, falling back to the second entry. When no entry matches and the
     * remapping has fewer than two entries, the positive class is set to null.
     */
    private void guessPositiveClassValue() {
        PredictionMLTask.CausalPredictionMLTask causalTask = this.causalTask();
        if (causalTask.preprocessing.target_remapping == null) {
            causalTask.positiveClass = null;
            return;
        }
        for (PredictionPreprocessingParams.MappingValue mapping : causalTask.preprocessing.target_remapping) {
            if (looksLikePositiveLabel(mapping.sourceValue)) {
                causalTask.positiveClass = mapping.sourceValue;
                return;
            }
        }
        if (causalTask.preprocessing.target_remapping.size() < 2) {
            // No second class to fall back to; do not fail with an index error.
            causalTask.positiveClass = null;
            return;
        }
        causalTask.positiveClass = ((PredictionPreprocessingParams.MappingValue)causalTask.preprocessing.target_remapping.get(1)).sourceValue;
    }

    /**
     * Counts the rows of the table per value of the treatment column. Blank values are counted
     * under the empty string.
     *
     * @return a map from treatment value to number of rows
     */
    public Map<String, Long> getTreatmentColDistribution() {
        MemColumn treatment = this.table.column(this.causalTask().treatmentVariable);
        return this.table.rows.stream()
                .map(row -> StringUtils.defaultIfBlank((String)row.get(treatment), ""))
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
    }

    /** Tells whether a raw treatment value reads like a control group label. */
    private static boolean looksLikeControlLabel(String rawValue) {
        String normalized = rawValue.toLowerCase(Locale.ROOT).trim();
        return normalized.equals("0") || normalized.equals("0.") || normalized.equals("0.0")
                || normalized.equals("false") || normalized.equals("no")
                || normalized.equals("control") || normalized.equals("-");
    }

    /**
     * Refreshes the task's treatment values from the treatment column, enables multi-treatment
     * when there are more than two of them, and guesses the control value: the first value that
     * reads like a control label, otherwise the most frequent value (the last one seen on ties).
     */
    private void guessControlValue() {
        PredictionMLTask.CausalPredictionMLTask causalTask = this.causalTask();
        Map<String, Long> treatmentDistribution = this.getTreatmentColDistribution();
        causalTask.treatmentValues.clear();
        causalTask.treatmentValues.addAll(treatmentDistribution.keySet());
        causalTask.enableMultiTreatment = causalTask.treatmentValues.size() > 2;
        long max = 0L;
        for (String val : causalTask.treatmentValues) {
            if (val == null) {
                continue;
            }
            if (looksLikeControlLabel(val)) {
                causalTask.controlValue = val;
                return;
            }
            long count = treatmentDistribution.get(val);
            if (count >= max) {
                max = count;
                causalTask.controlValue = val;
            }
        }
    }

    /**
     * Adds a message when the task's code environment selection is not an explicit environment,
     * then returns the inherited status.
     */
    @Override
    public GuessStatus checkStatus() {
        if (this.causalTask().envSelection.envMode != CodeEnvSelection.EnvMode.EXPLICIT_ENV) {
            this.messages.add("No code environment suitable to run causal prediction models available. Please ask your administrator to create one and/or grant you access to it.");
        }
        return super.checkStatus();
    }
}

