/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.spark;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.MLLibAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.ArrayList;
import java.util.List;

/**
 * Algorithm metadata for the Spark MLlib random forest ("Random Forest (MLLib)").
 *
 * <p>Describes, validates and expands the {@code mllib_rf} section of the prediction
 * modeling parameters, and writes trained/optimized hyperparameters back onto the ML task.
 */
public class MLLibRandomForestMeta extends MLLibAlgorithmMeta {

    /** Returns the fixed display name of this algorithm; {@code rpmp} is not consulted. */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Random Forest (MLLib)";
    }

    /**
     * Describes the hyperparameter search before training.
     *
     * <p>The number-of-trees and max-depth grids are reported as multi-valued parameters and
     * the subset strategy as a single value.
     *
     * @param rpmp pre-train parameters whose {@code mllib_rf_grid} is described
     * @return the pre-search description
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withMVParam("trees", rpmp.mllib_rf_grid.num_trees)
                .withMVParam("max_depth", rpmp.mllib_rf_grid.max_depth)
                // Keep the Object cast: presumably it pins the withSVParam overload the original used.
                .withSVParam("strategy", (Object) rpmp.mllib_rf_grid.subset_strategy);
    }

    /**
     * Describes the hyperparameters retained after training.
     *
     * <p>The chosen number of trees and max depth are read from {@code after}. The subset
     * strategy, which is not searched, is read from the pre-train grid in {@code before}.
     *
     * @param descBefore unused here
     * @param before     the parameters used to launch training
     * @param after      the optimized parameters
     * @return the post-search description
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        return new ModelTrainInfo.PostSearchDescription()
                .withSVParam("trees", after.mllib_rf.num_trees)
                .withSVParam("max_depth", after.mllib_rf.max_depth)
                .withSVParam("strategy", (Object) before.mllib_rf_grid.subset_strategy);
    }

    /**
     * Validates the random forest settings when the algorithm is enabled.
     *
     * <p>Does nothing if {@code pmp.mllib_rf} is null or disabled.
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.MLLibTreesEnsembleGridParams rfp = pmp.mllib_rf;
        if (rfp == null || !rfp.enabled) {
            return;
        }
        checkMLLibRFParams(rfp, task, checks, "Random Forest");
    }

    /**
     * Validates MLlib tree-ensemble parameters.
     *
     * <p>A non-regression task configured with {@code variance} impurity is rejected
     * immediately by throwing the exception built by {@code ErrorContext.iaef}. Every other
     * constraint is submitted to {@code checks}.
     *
     * @param params the tree-ensemble grid parameters to validate
     * @param task   the ML task; only its prediction type is read
     * @param checks collector/validator for parameter checks
     * @param algo   human-readable algorithm name used in check messages and labels
     */
    public static void checkMLLibRFParams(PredictionModelingParams.MLLibTreesEnsembleGridParams params, PredictionMLTask task, PredictionParameterChecks checks, String algo) {
        if (task.predictionType != PredictionMLTask.PredictionType.REGRESSION
                && params.impurity == PredictionModelingParams.MLLibRfImpurity.variance) {
            throw ErrorContext.iaef("%s classifier can't use 'variance' impurity", (Object) algo, new Object[0]);
        }
        checks.check(params.checkpoint_interval >= 1, algo + ": check point interval must be >= 1");
        checks.check(params.max_bins >= 2, algo + ": max bins must be >= 2");
        // The label must reflect the calling algorithm. It used to be algo concatenated with no
        // separator onto a hard-coded "(Random Forest)" suffix.
        checks.checkNumericalDimension(params.max_depth, "Max depth of trees (" + algo + ")");
        checks.checkPositive(params.max_memory_mb, algo + ": max memory must be > 0");
        checks.checkNonNegative(params.min_info_gain, algo + ": min info gain must be >= 0");
        checks.check(params.min_instance_per_node >= 1, algo + ": min instance per node must be >= 1");
        checks.checkNumericalDimension(params.num_trees, "Number of trees (" + algo + ")");
        checks.check(params.seed != 0L, algo + ": seed must be != 0");
        checks.checkPositive(params.subsampling_rate, algo + ": subsampling rate must be > 0");
        checks.checkBetween(params.subsampling_rate, 0.0, 1.0, algo + ": subsampling rate must be between 0 and 1 strictly");
    }

    /**
     * Expands the random forest settings into modeling sets.
     *
     * <p>Returns an empty (mutable) list when the algorithm is null or disabled. Otherwise it
     * returns exactly one modeling set whose grid length is the product of the max-depth and
     * number-of-trees grid lengths.
     *
     * <p>NOTE: for regression tasks this sets {@code pmp.mllib_rf.impurity} to
     * {@code variance} in place. The same grid object is then shared with the returned
     * pre-train parameters.
     *
     * @param pmp     the ML task modeling parameters
     * @param task    the ML task
     * @param gsFolds number of grid-search folds
     * @return the modeling sets to train
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.MLLibTreesEnsembleGridParams rfp = pmp.mllib_rf;
        List<WorkSet.ModelingSet> ret = new ArrayList<>();
        if (rfp == null || !rfp.enabled) {
            return ret;
        }
        if (task.predictionType == PredictionMLTask.PredictionType.REGRESSION) {
            rfp.impurity = PredictionModelingParams.MLLibRfImpurity.variance;
        }
        PreTrainPredictionModelingParams rcmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.MLLIB_RANDOM_FOREST, pmp);
        rcmp.max_ensemble_nodes_serialized = pmp.max_ensemble_nodes_serialized;
        rcmp.mllib_rf_grid = rfp;
        WorkSet.ModelingSet ms = new WorkSet.ModelingSet(rcmp);
        rcmp.gridLength = rfp.max_depth.getLength() * rfp.num_trees.getLength();
        // With a real grid: one train per (grid point, fold) plus one more (presumably the
        // final refit — confirm). Without a grid: a single train.
        ms.estimatedTrains = rcmp.gridLength > 1 ? rcmp.gridLength * gsFolds + 1 : 1;
        ret.add(ms);
        return ret;
    }

    /** Always {@code true} for this algorithm. */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    /** Always {@code true} for this algorithm. */
    @Override
    public boolean isJavaCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Always {@code true} for this algorithm. */
    @Override
    public boolean isPythonCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Returns {@code true} only for regression and binary classification. */
    @Override
    public boolean isSQLCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        switch (coreParams.prediction_type) {
            case REGRESSION:
            case BINARY_CLASSIFICATION:
                return true;
            default:
                return false;
        }
    }

    /** Always {@code true} for this algorithm. */
    @Override
    public boolean isPMMLCompatible() {
        return true;
    }

    /**
     * Builds pre-train parameters whose grid is pinned to the optimized values.
     *
     * <p>Starts from {@code getCopyWithGridStrategy(usedToTrain)}. It then replaces the
     * max-depth and number-of-trees grids with single-value grids taken from
     * {@code optimized.mllib_rf}.
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams ret = this.getCopyWithGridStrategy(usedToTrain);
        ret.mllib_rf_grid.max_depth.setToSingleValueGrid(Long.valueOf(optimized.mllib_rf.max_depth));
        ret.mllib_rf_grid.num_trees.setToSingleValueGrid(Long.valueOf(optimized.mllib_rf.num_trees));
        return ret;
    }

    /**
     * Writes the pinned (single-value) grid into {@code target.mllib_rf} and enables it.
     *
     * <p>NOTE: {@code target.mllib_rf} becomes the same object as the grid of the regridified
     * copy. Whether that grid is shared with {@code usedToTrain} depends on
     * {@code getCopyWithGridStrategy} (not visible here).
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        target.mllib_rf = preTrain.mllib_rf_grid;
        target.mllib_rf.enabled = true;
    }

    /**
     * Restores the grid used for training onto the ML task and enables it.
     *
     * <p>NOTE: this assigns the reference, so setting {@code enabled} also mutates
     * {@code usedToTrain.mllib_rf_grid}.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.mllib_rf = usedToTrain.mllib_rf_grid;
        target.mllib_rf.enabled = true;
    }
}

