/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.GpuConfig;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;

public class XGBoostMeta
extends PyMemoryAlgorithmMeta {
    private final boolean isClassification;

    public XGBoostMeta(boolean isClassification) {
        this.isClassification = isClassification;
    }

    @Override
    public String generateName(PreTrainPredictionModelingParams ptpmp) {
        return "XGBoost";
    }

    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams ptpmp) {
        PredictionModelingParams.XGBoostHyperparametersSpace xgb = ptpmp.xgboost_grid;
        assert (xgb != null);
        return new ModelTrainInfo.PreSearchDescription(ptpmp).withGridLength(this.getSearchSize(ptpmp.grid_search_params, xgb)).withSVParam("trees", xgb.n_estimators).withMVParam("max_depth", xgb.max_depth);
    }

    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        return new ModelTrainInfo.PostSearchDescription().withSVParam("trees", after.xgboost.n_estimators).withSVParam("max_depth", after.xgboost.max_depth);
    }

    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.XGBoostHyperparametersSpace xgb = pmp.xgboost;
        if (xgb == null || !xgb.enabled) {
            return;
        }
        EnumSet<PredictionMLTask.PredictionType> causalTypes = EnumSet.of(PredictionMLTask.PredictionType.CAUSAL_BINARY_CLASSIFICATION, PredictionMLTask.PredictionType.CAUSAL_REGRESSION);
        if (causalTypes.contains((Object)task.predictionType)) {
            ErrorContext.check((!xgb.enable_early_stopping ? 1 : 0) != 0, (String)"Early stopping cannot be activated for causal predictions");
        }
        if (task.gpuConfig.shouldUseGpu(GpuConfig.GpuSupportingCapability.XGBOOST)) {
            ErrorContext.check((xgb.tree_method.equals("hist") || xgb.tree_method.equals("exact") ? 1 : 0) != 0, (String)"XGBoost tree method must be 'Exact' or 'Histogram' to execute on a GPU. Either disable GPU, or change the selected tree method.");
        }
        ErrorContext.check((xgb.n_estimators > 0 ? 1 : 0) != 0, (String)"XGBoost maximum number of trees must be positive");
        ErrorContext.check((!xgb.enable_early_stopping || xgb.early_stopping_rounds >= 1 ? 1 : 0) != 0, (String)"XGBoost early stopping rounds parameter must be >= 1");
        checks.checkNumericalDimension(xgb.max_depth, "Max depth of trees (XGBoost) ");
        checks.checkNumericalDimension(xgb.learning_rate, "Learning rate (XGBoost)");
        checks.checkNumericalDimension(xgb.alpha, "Alpha L1 regularization parameter (XGBoost)");
        checks.checkNumericalDimension(xgb.lambda, "Lambda L2 regularization parameter (XGBoost)");
        checks.checkNumericalDimension(xgb.gamma, "Gamma regularization parameter (XGBoost)");
        checks.checkNumericalDimension(xgb.min_child_weight, "Minimum child weight (XGBoost)");
        checks.checkNumericalDimension(xgb.subsample, "Subsample rate (XGBoost)");
        checks.checkNumericalDimension(xgb.colsample_bytree, "Columns subsample rate (XGBoost)");
    }

    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.XGBoostHyperparametersSpace xgb = (PredictionModelingParams.XGBoostHyperparametersSpace)space;
        return xgb.max_depth.getLength() * xgb.learning_rate.getLength() * xgb.gamma.getLength() * xgb.min_child_weight.getLength() * xgb.max_delta_step.getLength() * xgb.subsample.getLength() * xgb.colsample_bytree.getLength() * xgb.colsample_bylevel.getLength() * xgb.alpha.getLength() * xgb.lambda.getLength() * xgb.booster.getLength() * xgb.objective.getLength();
    }

    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        boolean isRegressionTask = EnumSet.of(PredictionMLTask.PredictionType.REGRESSION, PredictionMLTask.PredictionType.CAUSAL_REGRESSION).contains((Object)task.predictionType);
        boolean isProperMetaForPredictionType = isRegressionTask != this.isClassification;
        PredictionModelingParams.XGBoostHyperparametersSpace xgb = pmp.xgboost;
        if (xgb == null || !xgb.enabled || !isProperMetaForPredictionType) {
            return new ArrayList<WorkSet.ModelingSet>();
        }
        ArrayList<WorkSet.ModelingSet> ret = new ArrayList<WorkSet.ModelingSet>();
        PreTrainPredictionModelingParams.Algorithm algo = this.isClassification ? PreTrainPredictionModelingParams.Algorithm.XGBOOST_CLASSIFICATION : PreTrainPredictionModelingParams.Algorithm.XGBOOST_REGRESSION;
        PreTrainPredictionModelingParams ptpmp = new PreTrainPredictionModelingParams(algo, pmp);
        ptpmp.xgboost_grid = xgb;
        ptpmp.gridLength = this.getSearchSize(ptpmp.grid_search_params, xgb);
        WorkSet.ModelingSet ms = new WorkSet.ModelingSet(ptpmp);
        boolean hasHPSearch = ptpmp.gridLength > 1 || xgb.enable_early_stopping;
        ms.estimatedTrains = hasHPSearch ? ptpmp.gridLength * gsFolds + 1 : 1;
        ret.add(ms);
        return ret;
    }

    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return this.isClassification;
    }

    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams ret = this.getCopyWithGridStrategy(usedToTrain);
        ret.xgboost_grid.max_depth.setToSingleValueGrid(Long.valueOf(optimized.xgboost.max_depth));
        ret.xgboost_grid.learning_rate.setToSingleValueGrid(Double.valueOf(optimized.xgboost.learning_rate));
        ret.xgboost_grid.gamma.setToSingleValueGrid(Double.valueOf(optimized.xgboost.gamma));
        ret.xgboost_grid.min_child_weight.setToSingleValueGrid(Double.valueOf(optimized.xgboost.min_child_weight));
        ret.xgboost_grid.max_delta_step.setToSingleValueGrid(Double.valueOf(optimized.xgboost.max_delta_step));
        ret.xgboost_grid.subsample.setToSingleValueGrid(Double.valueOf(optimized.xgboost.subsample));
        ret.xgboost_grid.colsample_bytree.setToSingleValueGrid(Double.valueOf(optimized.xgboost.colsample_bytree));
        ret.xgboost_grid.colsample_bylevel.setToSingleValueGrid(Double.valueOf(optimized.xgboost.colsample_bylevel));
        ret.xgboost_grid.alpha.setToSingleValueGrid(Double.valueOf(optimized.xgboost.alpha));
        ret.xgboost_grid.lambda.setToSingleValueGrid(Double.valueOf(optimized.xgboost.lambda));
        ret.xgboost_grid.booster = CategoricalHyperparameterDimension.create(optimized.xgboost.booster, "gbtree", "dart");
        ret.xgboost_grid.objective = CategoricalHyperparameterDimension.create(optimized.xgboost.objective, "reg_linear", "reg_logistic", "reg_gamma", "reg_tweedie", "count_poisson", "binary_logistic", "multi_softprob");
        return ret;
    }

    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        if (this.isClassification) {
            target.xgboost = preTrain.xgboost_grid;
            target.xgboost.enabled = true;
        } else {
            target.xgboost = preTrain.xgboost_grid;
            target.xgboost.enabled = true;
        }
    }

    @Override
    public boolean isJavaCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    @Override
    public boolean isPythonCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    @Override
    public boolean isSQLCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        switch (coreParams.prediction_type) {
            case REGRESSION: 
            case BINARY_CLASSIFICATION: {
                return true;
            }
        }
        return false;
    }

    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        if (this.isClassification) {
            target.xgboost = usedToTrain.xgboost_grid;
            target.xgboost.enabled = true;
        } else {
            target.xgboost = usedToTrain.xgboost_grid;
            target.xgboost.enabled = true;
        }
    }
}

