/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.guess;

import com.dataiku.dip.analysis.ml.prediction.guess.ClassicalPredictionGuesser;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.datalayer.memimpl.MemTable;

/**
 * Prediction guesser that, for the task's backend (MLlib, H2O or in-memory Python), enables a
 * fixed set of algorithms and fills their hyperparameter search spaces with several candidate
 * values each.
 *
 * <p>When {@code keepExistingParams} is {@code true}, the task's current
 * {@link PredictionModelingParams} ({@code task.modeling}) is mutated in place. The search spaces of
 * the algorithms configured here are first replaced by fresh instances, then filled in. Otherwise a
 * new {@code PredictionModelingParams} is built from the prediction type and {@code task.modeling}.
 */
public class PerformanceGuesser extends ClassicalPredictionGuesser {
    public PerformanceGuesser(PredictionMLTask.ClassicalPredictionMLTask task, MemTable table) {
        super(task, table);
    }

    /**
     * Dispatches to the backend-specific guess for {@code task.backendType}.
     *
     * @param table              data used by the Python backend (its column count sizes the random forest depth)
     * @param task               the prediction task to guess algorithms for
     * @param keepExistingParams whether to mutate {@code task.modeling} in place rather than build new params
     * @return the modeling params with the guessed algorithms enabled
     * @throws IllegalArgumentException if the backend is not MLLIB, H2O or PY_MEMORY, or the
     *                                  prediction type is not supported by that backend
     */
    @Override
    public PredictionModelingParams guessAlgorithms(MemTable table, PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        return switch (task.backendType) {
            case MLLIB -> this.guessMLLibAlgorithms(task, keepExistingParams);
            case H2O -> this.guessH2OAlgorithms(task, keepExistingParams);
            case PY_MEMORY -> this.guessPythonAlgorithms(table, task, keepExistingParams);
            default -> throw new IllegalArgumentException(task.backendType + " not supported.");
        };
    }

    /**
     * Returns the params object the Python guess will fill in.
     *
     * <p>A freshly built object gets a RANDOM grid-search strategy. When existing params are kept,
     * the grid-search strategy is left as it was. Only the random forest and LightGBM spaces for the
     * task's prediction type are reset.
     *
     * @throws IllegalArgumentException if existing params are kept and the prediction type is unsupported
     */
    private PredictionModelingParams initPythonAlgorithmsParams(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        if (!keepExistingParams) {
            // NOTE(review): presumably this constructor derives defaults from task.modeling — confirm.
            PredictionModelingParams fresh = new PredictionModelingParams(task.predictionType, task.modeling);
            fresh.gridSearchParams.strategy = PredictionModelingParams.GridSearchParams.Strategy.RANDOM;
            return fresh;
        }
        PredictionModelingParams params = task.modeling;
        if (isRegression(task.predictionType)) {
            params.random_forest_regression = new PredictionModelingParams.RandomForestHyperparametersSpace();
            params.lightgbm_regression = new PredictionModelingParams.LightGBMHyperParametersSpace();
        } else {
            params.random_forest_classification = new PredictionModelingParams.RandomForestHyperparametersSpace();
            params.lightgbm_classification = new PredictionModelingParams.LightGBMHyperParametersSpace();
        }
        return params;
    }

    /**
     * Enables LightGBM and random forest, plus logistic regression for classification, on the
     * in-memory Python backend, with multi-valued search spaces.
     *
     * @throws IllegalArgumentException if the prediction type is not REGRESSION, BINARY_CLASSIFICATION or MULTICLASS
     */
    private PredictionModelingParams guessPythonAlgorithms(MemTable table, PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        PredictionModelingParams params = this.initPythonAlgorithmsParams(task, keepExistingParams);
        // Validates the prediction type before any search space is modified.
        boolean regression = isRegression(task.predictionType);

        PredictionModelingParams.LightGBMHyperParametersSpace lightgbm =
                regression ? params.lightgbm_regression : params.lightgbm_classification;
        lightgbm.enabled = true;
        lightgbm.num_leaves.updateValues(31L, 255L);
        lightgbm.learning_rate.updateValues(0.1, 0.25, 0.4);
        lightgbm.colsample_bytree.updateValues(0.6, 0.85);
        lightgbm.min_child_weight.updateValues(1.0);
        lightgbm.use_bagging = true;
        lightgbm.subsample = 0.75f;
        lightgbm.subsample_freq = 2;

        PredictionModelingParams.RandomForestHyperparametersSpace rf =
                regression ? params.random_forest_regression : params.random_forest_classification;
        int ncols = table.ncols();
        rf.enabled = true;
        // The deepest candidate grows with floor(sqrt(ncols)); with 0 columns it equals the 10L candidate.
        rf.max_tree_depth.updateValues(6L, 10L, 10L + (long) Math.sqrt(ncols));
        rf.min_samples_leaf.updateValues(1L, 5L, 10L, 25L);
        rf.max_feature_prop.updateValues(0.2, 0.7);
        rf.selection_mode = PredictionModelingParams.TreeSelectionMode.PROP;

        if (!regression) {
            // Note: unlike the LightGBM/RF spaces, this space is not reset when existing params are kept.
            PredictionModelingParams.LogisticRegressionHyperparametersSpace lr = params.logistic_regression;
            lr.enabled = true;
            // NOTE(review): presumably marks "l1" as selected and "l2" as not selected — confirm withValue semantics.
            lr.penalty.withValue("l1", true).withValue("l2", false);
        }
        return params;
    }

    /**
     * Returns the params object the MLlib guess will fill in.
     *
     * <p>When existing params are kept, the MLlib random forest grid is always reset. For regression
     * the linear-regression and GBT grids are also reset; for classification the logistic-regression
     * grid is also reset.
     *
     * @throws IllegalArgumentException if existing params are kept and the prediction type is unsupported
     */
    private PredictionModelingParams initMLLibAlgorithmsParams(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        if (!keepExistingParams) {
            return new PredictionModelingParams(task.predictionType, task.modeling);
        }
        PredictionModelingParams params = task.modeling;
        if (isRegression(task.predictionType)) {
            params.mllib_linreg = new PredictionModelingParams.MLLibLinearRegressionGridParams();
            params.mllib_gbt = new PredictionModelingParams.MLLibTreesEnsembleGridParams();
        } else {
            params.mllib_logit = new PredictionModelingParams.MLLibLogisticRegressionGridParams();
        }
        params.mllib_rf = new PredictionModelingParams.MLLibTreesEnsembleGridParams();
        return params;
    }

    /**
     * Enables a regularized linear model (linear regression or logistic regression) and a random
     * forest on the MLlib backend.
     *
     * @throws IllegalArgumentException if the prediction type is not REGRESSION, BINARY_CLASSIFICATION or MULTICLASS
     */
    private PredictionModelingParams guessMLLibAlgorithms(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        PredictionModelingParams params = this.initMLLibAlgorithmsParams(task, keepExistingParams);
        if (isRegression(task.predictionType)) {
            params.mllib_linreg.enabled = true;
            params.mllib_linreg.reg_param.updateValues(0.03, 0.1, 0.3);
            params.mllib_rf.impurity = PredictionModelingParams.MLLibRfImpurity.variance;
            // GBT impurity is set here, but GBT itself is not enabled by this guesser.
            params.mllib_gbt.impurity = PredictionModelingParams.MLLibRfImpurity.variance;
        } else {
            params.mllib_logit.enabled = true;
            params.mllib_logit.reg_param.updateValues(0.03, 0.1, 0.3);
        }
        params.mllib_rf.enabled = true;
        params.mllib_rf.num_trees.updateValues(50L, 100L);
        return params;
    }

    /**
     * Returns the params object the H2O guess will fill in.
     *
     * <p>When existing params are kept, all four Sparkling Water grids (GLM, GBM, deep learning,
     * random forest) are reset, regardless of prediction type.
     */
    private PredictionModelingParams initH2OAlgorithmsParams(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        if (!keepExistingParams) {
            return new PredictionModelingParams(task.predictionType, task.modeling);
        }
        PredictionModelingParams params = task.modeling;
        params.glm_sparkling = new PredictionModelingParams.H2OGLMGridParams();
        params.gbm_sparkling = new PredictionModelingParams.H2OGBMGridParams();
        params.deep_learning_sparkling = new PredictionModelingParams.H2ODeepLearningGridParams();
        params.rf_sparkling = new PredictionModelingParams.H2ORandomForestGridParams();
        return params;
    }

    /**
     * Sets the GLM/GBM distribution families for the prediction type. Then enables deep learning,
     * GLM and a 100-tree random forest on the H2O backend. GBM gets its family set but is not enabled.
     *
     * @throws IllegalArgumentException if the prediction type is not REGRESSION, BINARY_CLASSIFICATION or MULTICLASS
     */
    protected PredictionModelingParams guessH2OAlgorithms(PredictionMLTask.ClassicalPredictionMLTask task, boolean keepExistingParams) {
        PredictionModelingParams params = this.initH2OAlgorithmsParams(task, keepExistingParams);
        switch (task.predictionType) {
            case REGRESSION -> {
                params.glm_sparkling.family = "gaussian";
                params.gbm_sparkling.family = "gaussian";
            }
            case BINARY_CLASSIFICATION -> {
                params.glm_sparkling.family = "binomial";
                params.gbm_sparkling.family = "bernoulli";
            }
            case MULTICLASS -> {
                params.glm_sparkling.family = "multinomial";
                params.gbm_sparkling.family = "multinomial";
            }
            default -> throw unsupportedPredictionType(task.predictionType);
        }
        params.deep_learning_sparkling.enabled = true;
        params.glm_sparkling.enabled = true;
        params.rf_sparkling.enabled = true;
        params.rf_sparkling.ntrees = 100;
        return params;
    }

    /**
     * Tells regression apart from the two supported classification types.
     *
     * @return {@code true} for REGRESSION, {@code false} for BINARY_CLASSIFICATION or MULTICLASS
     * @throws IllegalArgumentException for any other prediction type
     */
    private static boolean isRegression(PredictionMLTask.PredictionType predictionType) {
        return switch (predictionType) {
            case REGRESSION -> true;
            case BINARY_CLASSIFICATION, MULTICLASS -> false;
            default -> throw unsupportedPredictionType(predictionType);
        };
    }

    /** Builds the single exception used for every prediction type this guesser does not handle. */
    private static IllegalArgumentException unsupportedPredictionType(PredictionMLTask.PredictionType predictionType) {
        return new IllegalArgumentException("Unsupported prediction type: " + predictionType);
    }
}

