/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.Collections;
import java.util.List;

/**
 * Algorithm metadata for the TabICL classification algorithm
 * ({@code PreTrainPredictionModelingParams.Algorithm.TABICL_CLASSIFICATION}).
 *
 * <p>The searchable hyperparameters are {@code n_estimators} (a numerical dimension) and
 * {@code class_shift} (a categorical dimension). The other fields of the space
 * ({@code random_state}, {@code n_jobs}, {@code softmax_temperature}, {@code batch_size},
 * {@code outlier_threshold}) are only range-checked in {@link #validateParameters}.
 */
public class TabICLMeta
extends PyMemoryAlgorithmMeta {
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "TabICL";
    }

    /**
     * Describes the search before training: the grid length plus the multi-valued
     * {@code n_estimators} and {@code class_shift} dimensions of the TabICL grid.
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams usedToTrain) {
        PredictionModelingParams.TabICLHyperparametersSpace space = this.getHyperparametersSpace(usedToTrain);
        return new ModelTrainInfo.PreSearchDescription(usedToTrain)
                .withGridLength(this.getSearchSize(usedToTrain.grid_search_params, space))
                .withMVParam("n_estimators", space.n_estimators)
                .withMVParam("class_shift", space.class_shift);
    }

    /**
     * Describes the result of the search: the single selected value of {@code n_estimators}
     * and {@code class_shift}, read from {@code after.tabicl}.
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PostTrainPredictionModelingParams.TabICLParams tabiclParams = after.tabicl;
        return new ModelTrainInfo.PostSearchDescription()
                .withSVParam("n_estimators", tabiclParams.n_estimators)
                .withSVParam("class_shift", tabiclParams.class_shift);
    }

    /**
     * Validates the TabICL hyperparameter space of {@code pmp}. Nothing is checked when the
     * algorithm is disabled.
     *
     * @param pmp    modeling params holding the {@code tabicl_classification} space
     * @param task   the ML task (unused here)
     * @param checks collector used for the dimension-level checks
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.TabICLHyperparametersSpace space = this.getHyperparametersSpace(pmp);
        if (!space.enabled) {
            return;
        }
        checks.checkNumericalDimension(space.n_estimators, "Number of estimators (TabICL)");
        checks.checkPositive(space.class_shift.getLength(), "At least one class shift method must be selected");
        ErrorContext.check(space.random_state >= 0, "Random state must be >= 0");
        // NOTE(review): this condition also accepts n_jobs == 0, which the message does not
        // mention. Left unchanged so existing configs are not rejected — confirm intended range.
        ErrorContext.check(space.n_jobs >= -1, "n_jobs must be either -1 or a positive number");
        ErrorContext.check((double) space.softmax_temperature >= 0.1 && space.softmax_temperature <= 1.0f,
                "Softmax Temperature must be between 0.1 and 1.0");
        ErrorContext.check(space.batch_size > 0, "batch size must be > 0");
        ErrorContext.check(space.outlier_threshold >= 3.0f && (double) space.outlier_threshold <= 6.0,
                "Outlier Threshold must be between 3.0 and 6.0");
    }

    /**
     * Expands the TabICL space into at most one modeling set.
     *
     * @param gsFolds number of grid-search folds, used to estimate the number of trains
     * @return an empty list when TabICL is disabled, otherwise a single-element list
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.TabICLHyperparametersSpace space = this.getHyperparametersSpace(pmp);
        if (!space.enabled) {
            return Collections.emptyList();
        }
        PreTrainPredictionModelingParams.Algorithm algorithm = PreTrainPredictionModelingParams.Algorithm.TABICL_CLASSIFICATION;
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(algorithm, pmp);
        preTrainParams.tabicl_classification_grid = space;
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, space);
        WorkSet.ModelingSet modelingSet = new WorkSet.ModelingSet(preTrainParams);
        if (preTrainParams.gridLength > 1) {
            // One train per (grid point, fold), plus one more — presumably the final refit on
            // the best point (TODO confirm).
            modelingSet.estimatedTrains = preTrainParams.gridLength * gsFolds + 1;
        }
        return Collections.singletonList(modelingSet);
    }

    /**
     * Builds pre-train params whose TabICL grid is pinned to the optimized values: a single
     * {@code n_estimators} value, and a {@code class_shift} categorical dimension created from
     * the optimized value with the options {@code "true"} / {@code "false"}.
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PostTrainPredictionModelingParams.TabICLParams optimizedParams = optimized.tabicl;
        // Presumably a copy, so the mutations below do not touch usedToTrain — verify in base class.
        PreTrainPredictionModelingParams preTrainParams = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.TabICLHyperparametersSpace space = this.getHyperparametersSpace(preTrainParams);
        space.n_estimators.setToSingleValueGrid(optimizedParams.n_estimators);
        space.class_shift = CategoricalHyperparameterDimension.create(String.valueOf(optimizedParams.class_shift), "true", "false");
        return preTrainParams;
    }

    /**
     * Installs the optimized (single-point) TabICL grid into {@code target} and enables it.
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrainParams = this.regridifyToPreTrain(optimized, usedToTrain);
        // Assign first, then enable. Enabling before the assignment would flag the object
        // being replaced instead of the newly installed space (same order as refreshMLTask).
        target.tabicl_classification = preTrainParams.tabicl_classification_grid;
        target.tabicl_classification.enabled = true;
    }

    /**
     * Copies the grid used for training back into {@code target} and enables it.
     * Note: this shares (does not copy) the space object with {@code usedToTrain}.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.tabicl_classification = usedToTrain.tabicl_classification_grid;
        target.tabicl_classification.enabled = true;
    }

    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams usedToTrain) {
        return true;
    }

    /**
     * Grid length is the cartesian product of the searched dimensions:
     * {@code n_estimators} x {@code class_shift}.
     *
     * @param space expected to be a {@code TabICLHyperparametersSpace}; the cast throws
     *              {@code ClassCastException} otherwise
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.TabICLHyperparametersSpace tabiclSpace = (PredictionModelingParams.TabICLHyperparametersSpace) space;
        return tabiclSpace.n_estimators.getLength() * tabiclSpace.class_shift.getLength();
    }

    /** Returns the TabICL space of the ML-task-level modeling params. */
    private PredictionModelingParams.TabICLHyperparametersSpace getHyperparametersSpace(PredictionModelingParams modelingParams) {
        return modelingParams.tabicl_classification;
    }

    /** Returns the TabICL grid of the pre-train modeling params. */
    private PredictionModelingParams.TabICLHyperparametersSpace getHyperparametersSpace(PreTrainPredictionModelingParams modelingParams) {
        return modelingParams.tabicl_classification_grid;
    }
}

