/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;

/**
 * Algorithm metadata for LightGBM.
 *
 * <p>One class serves both task families: the {@code isClassification} flag given at construction
 * selects which hyperparameter space is read and written ({@code lightgbm_classification*} vs
 * {@code lightgbm_regression*}) and which {@code PreTrainPredictionModelingParams.Algorithm} value
 * ({@code LIGHTGBM_CLASSIFICATION} vs {@code LIGHTGBM_REGRESSION}) is used for training.
 */
public class LightGBMMeta
extends PyMemoryAlgorithmMeta {
    /** Prediction types for which early stopping is rejected. Never mutated. */
    private static final EnumSet<PredictionMLTask.PredictionType> CAUSAL_PREDICTION_TYPES =
            EnumSet.of(PredictionMLTask.PredictionType.CAUSAL_BINARY_CLASSIFICATION,
                    PredictionMLTask.PredictionType.CAUSAL_REGRESSION);

    /** True for the classification variant, false for the regression variant. */
    private final boolean isClassification;

    /**
     * @param isClassification {@code true} to handle the LightGBM classification space,
     *                         {@code false} to handle the LightGBM regression space
     */
    public LightGBMMeta(boolean isClassification) {
        this.isClassification = isClassification;
    }

    /** Returns the fixed display name {@code "LightGBM"}, regardless of parameters. */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "LightGBM";
    }

    /**
     * Describes the search space before training: the search size plus the candidate values of
     * boosting type, number of estimators, number of leaves and learning rate.
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.LightGBMHyperParametersSpace space = this.getHyperparametersSpace(rpmp);
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, space))
                .withMVParam("boosting_type", space.boosting_type)
                .withMVParam("n_estimators", space.n_estimators)
                .withMVParam("num_leaves", space.num_leaves)
                .withMVParam("learning_rate", space.learning_rate);
    }

    /**
     * Describes the values actually retained after training for boosting type, number of
     * estimators, number of leaves and learning rate, read from {@code after.lightgbm}.
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PostTrainPredictionModelingParams.LightGBMParams lightgbmParams = after.lightgbm;
        return new ModelTrainInfo.PostSearchDescription()
                .withSVParam("boosting_type", lightgbmParams.boosting_type)
                .withSVParam("n_estimators", lightgbmParams.n_estimators)
                .withSVParam("num_leaves", lightgbmParams.num_leaves)
                .withSVParam("learning_rate", lightgbmParams.learning_rate);
    }

    /**
     * Validates the LightGBM hyperparameter space of {@code pmp}. Does nothing when the space is
     * disabled.
     *
     * <p>Scalar constraints are checked through {@code ErrorContext.check}. Each grid dimension is
     * checked through {@code checks.checkNumericalDimension}.
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.LightGBMHyperParametersSpace space = this.getHyperparametersSpace(pmp);
        if (!space.enabled) {
            return;
        }
        if (CAUSAL_PREDICTION_TYPES.contains(task.predictionType)) {
            ErrorContext.check(!space.early_stopping, "Early stopping cannot be activated for causal predictions");
        }
        ErrorContext.check(space.early_stopping_rounds > 0, "Early stopping rounds (LightGBM) must be positive");
        checks.checkPositive(space.boosting_type.getLength(), "At least one boosting type (LightGBM) must be selected");
        // NOTE(review): the condition also accepts 0, which the message does not mention. Kept as-is
        // so previously accepted settings are not rejected — confirm the intended range.
        ErrorContext.check(space.n_jobs >= -1, "Parallelism (LightGBM) must be either -1 or a positive number");
        ErrorContext.check(space.subsample > 0.0f && space.subsample <= 1.0f, "Bagging fraction (LightGBM) must be between 0 (excluded) and 1 (included)");
        checks.checkNumericalDimension(space.num_leaves, "Number of leaves (LightGBM)");
        checks.checkNumericalDimension(space.learning_rate, "Learning rate (LightGBM)");
        checks.checkNumericalDimension(space.n_estimators, "Number of estimators (LightGBM)");
        checks.checkNumericalDimension(space.min_split_gain, "Minimum split gain (LightGBM)");
        checks.checkNumericalDimension(space.min_child_weight, "Minimum child weight (LightGBM)");
        checks.checkNumericalDimension(space.min_child_samples, "Minimum leaf samples (LightGBM)");
        checks.checkNumericalDimension(space.colsample_bytree, "Colsample by tree (LightGBM)");
        checks.checkNumericalDimension(space.reg_alpha, "L1 regularization (LightGBM)");
        checks.checkNumericalDimension(space.reg_lambda, "L2 regularization (LightGBM)");
    }

    /**
     * Builds the modeling sets to train for this algorithm.
     *
     * @return an empty list when the space is disabled, otherwise a single modeling set whose
     *         {@code estimatedTrains} accounts for the hyperparameter search, if any
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.LightGBMHyperParametersSpace space = this.getHyperparametersSpace(pmp);
        if (!space.enabled) {
            return Collections.emptyList();
        }
        PreTrainPredictionModelingParams.Algorithm algorithm = this.isClassification
                ? PreTrainPredictionModelingParams.Algorithm.LIGHTGBM_CLASSIFICATION
                : PreTrainPredictionModelingParams.Algorithm.LIGHTGBM_REGRESSION;
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(algorithm, pmp);
        // Only the grid field matching this meta's task family is populated.
        if (this.isClassification) {
            preTrainParams.lightgbm_classification_grid = space;
        } else {
            preTrainParams.lightgbm_regression_grid = space;
        }
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, space);
        WorkSet.ModelingSet modelingSet = new WorkSet.ModelingSet(preTrainParams);
        // Early stopping is treated as a search even when the grid has a single point.
        boolean hasHPSearch = preTrainParams.gridLength > 1 || space.early_stopping;
        // With a search: one train per (grid point, fold), plus one extra train
        // (presumably the final fit with the selected parameters — TODO confirm).
        modelingSet.estimatedTrains = hasHPSearch ? preTrainParams.gridLength * gsFolds + 1 : 1;
        return Collections.singletonList(modelingSet);
    }

    /**
     * Returns a copy of {@code usedToTrain} (via {@code getCopyWithGridStrategy}) whose LightGBM
     * space is pinned to the single values found in {@code optimized.lightgbm}.
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PostTrainPredictionModelingParams.LightGBMParams optimizedParameters = optimized.lightgbm;
        PreTrainPredictionModelingParams preTrainParams = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.LightGBMHyperParametersSpace space = this.getHyperparametersSpace(preTrainParams);
        // NOTE(review): only "gbdt" and "goss" are passed as possible boosting types here — confirm
        // this matches every boosting type the UI can offer.
        space.boosting_type = CategoricalHyperparameterDimension.create(optimizedParameters.boosting_type, "gbdt", "goss");
        space.num_leaves.setToSingleValueGrid(optimizedParameters.num_leaves);
        space.learning_rate.setToSingleValueGrid(optimizedParameters.learning_rate);
        space.n_estimators.setToSingleValueGrid(optimizedParameters.n_estimators);
        space.min_split_gain.setToSingleValueGrid(optimizedParameters.min_split_gain);
        space.min_child_weight.setToSingleValueGrid(optimizedParameters.min_child_weight);
        space.min_child_samples.setToSingleValueGrid(optimizedParameters.min_child_samples);
        space.colsample_bytree.setToSingleValueGrid(optimizedParameters.colsample_bytree);
        space.reg_alpha.setToSingleValueGrid(optimizedParameters.reg_alpha);
        space.reg_lambda.setToSingleValueGrid(optimizedParameters.reg_lambda);
        return preTrainParams;
    }

    /**
     * Writes the optimized (single-value) LightGBM space into {@code target} and enables it.
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrainParams = this.regridifyToPreTrain(optimized, usedToTrain);
        this.installSpaceInMLTask(target, this.getHyperparametersSpace(preTrainParams));
    }

    /**
     * Writes the LightGBM space that was used to train into {@code target} and enables it.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        this.installSpaceInMLTask(target, this.getHyperparametersSpace(usedToTrain));
    }

    /** Only the classification variant reports probabilities. */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return this.isClassification;
    }

    /** Always Java-compatible, whatever the core params. */
    @Override
    public boolean isJavaCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Always Python-compatible, whatever the core params. */
    @Override
    public boolean isPythonCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** SQL-compatible only for regression and binary classification. */
    @Override
    public boolean isSQLCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        switch (coreParams.prediction_type) {
            case REGRESSION:
            case BINARY_CLASSIFICATION:
                return true;
            default:
                return false;
        }
    }

    /** Always PMML-compatible. */
    @Override
    public boolean isPMMLCompatible() {
        return true;
    }

    /**
     * Returns the number of grid points: the product of the number of candidate values of every
     * LightGBM dimension.
     *
     * @param space expected to be a {@code LightGBMHyperParametersSpace}; a different subtype makes
     *              the cast throw {@code ClassCastException}
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.LightGBMHyperParametersSpace lightgbmSpace = (PredictionModelingParams.LightGBMHyperParametersSpace)space;
        return lightgbmSpace.boosting_type.getLength()
                * lightgbmSpace.num_leaves.getLength()
                * lightgbmSpace.learning_rate.getLength()
                * lightgbmSpace.n_estimators.getLength()
                * lightgbmSpace.min_split_gain.getLength()
                * lightgbmSpace.min_child_weight.getLength()
                * lightgbmSpace.min_child_samples.getLength()
                * lightgbmSpace.colsample_bytree.getLength()
                * lightgbmSpace.reg_alpha.getLength()
                * lightgbmSpace.reg_lambda.getLength();
    }

    /**
     * Stores {@code space} as this task family's LightGBM space in {@code target}, then marks it
     * enabled. The assignment happens before the flag is set, so a null space is stored and then
     * throws {@code NullPointerException}.
     */
    private void installSpaceInMLTask(PredictionModelingParams target, PredictionModelingParams.LightGBMHyperParametersSpace space) {
        if (this.isClassification) {
            target.lightgbm_classification = space;
        } else {
            target.lightgbm_regression = space;
        }
        space.enabled = true;
    }

    /** Returns the LightGBM space of the ML-task-level params for this task family. */
    private PredictionModelingParams.LightGBMHyperParametersSpace getHyperparametersSpace(PredictionModelingParams modelingParams) {
        return this.isClassification ? modelingParams.lightgbm_classification : modelingParams.lightgbm_regression;
    }

    /** Returns the LightGBM grid of the pre-train params for this task family. */
    private PredictionModelingParams.LightGBMHyperParametersSpace getHyperparametersSpace(PreTrainPredictionModelingParams modelingParams) {
        return this.isClassification ? modelingParams.lightgbm_classification_grid : modelingParams.lightgbm_regression_grid;
    }
}

