/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.ml.shared.ParameterAutoCompleter;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import java.util.ArrayList;
import java.util.List;

public class LarsMeta
extends PyMemoryAlgorithmMeta {
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "LASSO-LARS";
    }

    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        String info = rpmp.lars_grid.max_features > 0 ? "" + rpmp.lars_grid.max_features : "no limit";
        return new ModelTrainInfo.PreSearchDescription(rpmp).withSVParam("Max features", info);
    }

    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        String info = before.lars_grid.max_features > 0 ? "" + before.lars_grid.max_features : "no limit";
        return new ModelTrainInfo.PostSearchDescription().withSVParam("Max features", info);
    }

    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.LarsHyperparametersSpace lp = pmp.lars_params;
        checks.checkNonNegative(lp.max_features, "Lasso-path Max Features must be nonnegative");
        checks.checkPositive(lp.K, "Lasso-path number of tested values must be positive");
    }

    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.LarsHyperparametersSpace lp = pmp.lars_params;
        ArrayList<WorkSet.ModelingSet> ret = new ArrayList<WorkSet.ModelingSet>();
        if (lp == null || !lp.enabled) {
            return ret;
        }
        PreTrainPredictionModelingParams rcmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.LARS, pmp);
        rcmp.lars_grid = new PredictionModelingParams.LarsHyperparametersSpace();
        rcmp.lars_grid.max_features = lp.max_features;
        rcmp.lars_grid.K = lp.K;
        WorkSet.ModelingSet ms = new WorkSet.ModelingSet(rcmp);
        ms.estimatedTrains = 1;
        ret.add(ms);
        return ret;
    }

    @Override
    public ParameterAutoCompleter autoCompleter() {
        return new ParameterAutoCompleter.DummyDropAutoCompleter();
    }

    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        return this.getCopyWithGridStrategy(usedToTrain);
    }

    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        target.lars_params = preTrain.lars_grid;
        target.lars_params.enabled = true;
    }

    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.lars_params = usedToTrain.lars_grid;
        target.lars_params.enabled = true;
    }
}

