/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.ArrayList;
import java.util.List;

/**
 * Algorithm metadata for the random forest model.
 *
 * <p>One class serves both tasks. The {@code isClassification} flag chosen at construction
 * decides which grids every method reads and writes:
 * <ul>
 *   <li>classifier: {@code random_forest_classification} / {@code rf_classifier_grid}</li>
 *   <li>regressor: {@code random_forest_regression} / {@code rf_regressor_grid}</li>
 * </ul>
 */
public class PyRandomForestMeta
extends PyMemoryAlgorithmMeta {
    /** {@code true} for the classifier variant, {@code false} for the regressor variant. */
    private final boolean isClassification;

    /**
     * @param isClassification {@code true} to describe the random forest classifier,
     *     {@code false} to describe the random forest regressor
     */
    public PyRandomForestMeta(boolean isClassification) {
        this.isClassification = isClassification;
    }

    /** Returns the fixed display name {@code "Random forest"}, whatever the parameters. */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Random forest";
    }

    @Override
    public boolean isShiftWindowsCompatible() {
        return true;
    }

    /**
     * Describes the hyperparameter search before training.
     *
     * <p>The description contains the search size plus the multi-valued "trees", "depth" and
     * "min_samples" dimensions of this variant's grid.
     *
     * @param rpmp pre-train params; the grid for this variant is expected to be non-null
     * @return the pre-search description
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.RandomForestHyperparametersSpace rfp = this.selectPreTrainGrid(rpmp);
        // Only enforced when the JVM runs with assertions enabled (-ea).
        assert (rfp != null);
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, rfp))
                .withMVParam("trees", rfp.n_estimators)
                .withMVParam("depth", rfp.max_tree_depth)
                .withMVParam("min_samples", rfp.min_samples_leaf);
    }

    /**
     * Describes the hyperparameters actually selected by training.
     *
     * <p>The values are read from {@code after.rf} and reported as "trees", "depth" and
     * "min_samples".
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        return new ModelTrainInfo.PostSearchDescription()
                .withSVParam("trees", after.rf.estimators)
                .withSVParam("depth", after.rf.max_tree_depth)
                .withSVParam("min_samples", after.rf.min_samples_leaf);
    }

    /**
     * Validates this variant's hyperparameter space in the ML task settings.
     *
     * <p>Nothing is checked when the grid is missing or disabled. Otherwise the method:
     * <ul>
     *   <li>requires a positive {@code n_jobs};</li>
     *   <li>checks the depth, min-samples and estimator dimensions;</li>
     *   <li>checks {@code max_features} only in {@code NUMBER} selection mode;</li>
     *   <li>checks {@code max_feature_prop} only in {@code PROP} selection mode.</li>
     * </ul>
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.RandomForestHyperparametersSpace rfp = this.selectTaskGrid(pmp);
        if (rfp == null || !rfp.enabled) {
            return;
        }
        ErrorContext.check(rfp.n_jobs > 0, "Random forest parallelism must be > 0");
        checks.checkNumericalDimension(rfp.max_tree_depth, "Max depth of trees (Random Forest)");
        checks.checkNumericalDimension(rfp.min_samples_leaf, "Min samples per leaf (Random Forest)");
        checks.checkNumericalDimension(rfp.n_estimators, "Number of estimators (Random Forest)");
        if (rfp.selection_mode == PredictionModelingParams.TreeSelectionMode.NUMBER) {
            checks.checkNumericalDimension(rfp.max_features, "Max number of features (Random Forest)");
        }
        if (rfp.selection_mode == PredictionModelingParams.TreeSelectionMode.PROP) {
            checks.checkNumericalDimension(rfp.max_feature_prop, "Max proportion of features (Random Forest)");
        }
    }

    /**
     * Returns the number of points in the full cartesian hyperparameter grid.
     *
     * <p>The count is the product of the depth, min-samples and estimator dimension lengths. It
     * is also multiplied by the feature-selection dimension that matches the selection mode:
     * {@code max_features} for {@code NUMBER}, {@code max_feature_prop} for {@code PROP}, none
     * otherwise.
     *
     * @param space expected to be a {@code RandomForestHyperparametersSpace}
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.RandomForestHyperparametersSpace rfp = (PredictionModelingParams.RandomForestHyperparametersSpace) space;
        int gridLength = rfp.max_tree_depth.getLength()
                * rfp.min_samples_leaf.getLength()
                * rfp.n_estimators.getLength();
        if (rfp.selection_mode == PredictionModelingParams.TreeSelectionMode.NUMBER) {
            gridLength *= rfp.max_features.getLength();
        }
        if (rfp.selection_mode == PredictionModelingParams.TreeSelectionMode.PROP) {
            gridLength *= rfp.max_feature_prop.getLength();
        }
        return gridLength;
    }

    /**
     * Expands the ML task settings into the modeling sets to train for this variant.
     *
     * @param gsFolds number of grid-search folds, used for the train-count estimate
     * @return an empty list when the grid is missing or disabled; otherwise a single set
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.RandomForestHyperparametersSpace rfp = this.selectTaskGrid(pmp);
        List<WorkSet.ModelingSet> out = new ArrayList<>();
        if (rfp == null || !rfp.enabled) {
            return out;
        }
        PreTrainPredictionModelingParams.Algorithm algorithm = this.isClassification
                ? PreTrainPredictionModelingParams.Algorithm.RANDOM_FOREST_CLASSIFICATION
                : PreTrainPredictionModelingParams.Algorithm.RANDOM_FOREST_REGRESSION;
        PreTrainPredictionModelingParams rcmp = new PreTrainPredictionModelingParams(algorithm, pmp);
        if (this.isClassification) {
            rcmp.rf_classifier_grid = rfp;
        } else {
            rcmp.rf_regressor_grid = rfp;
        }
        rcmp.max_ensemble_nodes_serialized = pmp.max_ensemble_nodes_serialized;
        // Keep this order: the ModelingSet is built before gridLength is assigned, as in the
        // original code.
        WorkSet.ModelingSet ms = new WorkSet.ModelingSet(rcmp);
        rcmp.gridLength = this.getSearchSize(rcmp.grid_search_params, rfp);
        // With a real search: one train per grid point per fold, plus one more. The extra one is
        // presumably the final refit (NOTE(review): confirm). Without a search: a single train.
        ms.estimatedTrains = rcmp.gridLength > 1 ? rcmp.gridLength * gsFolds + 1 : 1;
        out.add(ms);
        return out;
    }

    /** Probabilities are reported only by the classifier variant. */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return this.isClassification;
    }

    @Override
    public boolean isJavaCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    @Override
    public boolean isPythonCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** SQL export is supported for regression and binary classification only. */
    @Override
    public boolean isSQLCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        switch (coreParams.prediction_type) {
            case REGRESSION:
            case BINARY_CLASSIFICATION:
                return true;
            default:
                return false;
        }
    }

    @Override
    public boolean isPMMLCompatible() {
        return true;
    }

    /**
     * Builds pre-train params pinned to the hyperparameters found by training.
     *
     * <p>The params are copied from {@code usedToTrain} via {@code getCopyWithGridStrategy}.
     * Then every dimension of this variant's grid is set to the single value in
     * {@code optimized.rf}.
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams ret = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.RandomForestHyperparametersSpace grid = this.selectPreTrainGrid(ret);
        grid.n_estimators.setToSingleValueGrid(Long.valueOf(optimized.rf.estimators));
        grid.max_tree_depth.setToSingleValueGrid(Long.valueOf(optimized.rf.max_tree_depth));
        grid.min_samples_leaf.setToSingleValueGrid(Long.valueOf(optimized.rf.min_samples_leaf));
        grid.max_features.setToSingleValueGrid(Long.valueOf(optimized.rf.max_features));
        grid.max_feature_prop.setToSingleValueGrid(optimized.rf.max_feature_prop);
        return ret;
    }

    /**
     * Writes the single-value grid from {@link #regridifyToPreTrain} into {@code target} and
     * enables it.
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        this.installEnabledGrid(target, this.selectPreTrainGrid(preTrain));
    }

    /**
     * Writes the grid used for training back into {@code target} and enables it.
     *
     * <p>NOTE(review): the grid object of {@code usedToTrain} is stored by reference, not copied.
     * Setting {@code enabled = true} therefore also mutates {@code usedToTrain}'s grid. Confirm
     * that this sharing is intended.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        this.installEnabledGrid(target, this.selectPreTrainGrid(usedToTrain));
    }

    /** Returns this variant's hyperparameter space from the ML task settings (may be null). */
    private PredictionModelingParams.RandomForestHyperparametersSpace selectTaskGrid(PredictionModelingParams pmp) {
        return this.isClassification ? pmp.random_forest_classification : pmp.random_forest_regression;
    }

    /** Returns this variant's hyperparameter grid from pre-train params (may be null). */
    private PredictionModelingParams.RandomForestHyperparametersSpace selectPreTrainGrid(PreTrainPredictionModelingParams params) {
        return this.isClassification ? params.rf_classifier_grid : params.rf_regressor_grid;
    }

    /**
     * Stores {@code grid} as this variant's space in {@code target}, then sets its
     * {@code enabled} flag to {@code true}.
     */
    private void installEnabledGrid(PredictionModelingParams target, PredictionModelingParams.RandomForestHyperparametersSpace grid) {
        if (this.isClassification) {
            target.random_forest_classification = grid;
        } else {
            target.random_forest_regression = grid;
        }
        grid.enabled = true;
    }
}

