/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.ml.shared.ParameterAutoCompleter;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import java.util.ArrayList;
import java.util.List;

/**
 * Algorithm metadata for Lasso (L1) regression.
 *
 * <p>Provides display names, pre/post-training descriptions, parameter
 * validation, search-space expansion and the "regridify" helpers that copy
 * trained hyperparameters back into modeling parameters.
 */
public class LassoMeta
extends PyMemoryAlgorithmMeta {
    /**
     * Builds the display name of the model.
     *
     * @param rpmp pre-train modeling parameters
     * @return the "(Grid)" variant when a Lasso space is present and its search
     *         size is greater than one, the plain name otherwise
     */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        boolean searchesSeveralPoints = rpmp.lasso_grid != null
                && this.getSearchSize(rpmp.grid_search_params, rpmp.lasso_grid) > 1;
        return searchesSeveralPoints ? "Lasso (L1) regression (Grid)" : "Lasso (L1) regression";
    }

    /**
     * Describes the hyperparameter search before training.
     *
     * @param rpmp pre-train modeling parameters; {@code lasso_grid} is asserted to be non-null
     * @return a description carrying the search size and the "alpha" parameter,
     *         either as the configured values (manual mode) or as "Auto-optimized"
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        assert (rpmp.lasso_grid != null);
        ModelTrainInfo.PreSearchDescription description = new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, rpmp.lasso_grid));
        if (rpmp.lasso_grid.alphaMode != PredictionModelingParams.LassoSelectAlphaMode.MANUAL) {
            // Any non-manual mode is reported the same way.
            description.withSVParam("alpha", "Auto-optimized");
        } else {
            description.withMVParam("alpha", rpmp.lasso_grid.alpha);
        }
        return description;
    }

    /**
     * Describes the retained hyperparameters after training.
     *
     * @param descBefore description produced before training (not used here)
     * @param before     pre-train modeling parameters (not used here)
     * @param after      post-train modeling parameters; {@code after.lasso.alpha} is read
     * @return a description whose "alpha" is "Auto-optimized" when the trained
     *         alpha equals zero, and the trained alpha value otherwise
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        // A trained alpha of exactly zero is reported as "Auto-optimized".
        boolean alphaIsZero = (double)after.lasso.alpha == 0.0;
        ModelTrainInfo.PostSearchDescription description = new ModelTrainInfo.PostSearchDescription();
        if (alphaIsZero) {
            return description.withSVParam("alpha", "Auto-optimized");
        }
        return description.withSVParam("alpha", Float.valueOf(after.lasso.alpha));
    }

    /**
     * Validates the Lasso settings of the given modeling parameters.
     *
     * <p>Does nothing when the Lasso space is missing or disabled. For the two
     * automatic alpha modes a sparse-related warning is registered; in every
     * other case the alpha dimension is checked.
     *
     * @param pmp    modeling parameters holding {@code lasso_regression}
     * @param task   the prediction task (not used here)
     * @param checks collector receiving warnings and check results
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.LassoHyperparametersSpace space = pmp.lasso_regression;
        if (space != null && space.enabled) {
            boolean autoSelectsAlpha = space.alphaMode == PredictionModelingParams.LassoSelectAlphaMode.AUTO_CV
                    || space.alphaMode == PredictionModelingParams.LassoSelectAlphaMode.AUTO_IC;
            if (autoSelectsAlpha) {
                checks.addWarningSparse("Lasso regression with auto-select alpha");
            } else {
                checks.checkNumericalDimension(space.alpha, "Alpha regularization coefficient (Lasso regression)");
            }
        }
    }

    /**
     * @return a new {@link ParameterAutoCompleter.DummyDropAutoCompleter}
     */
    @Override
    public ParameterAutoCompleter autoCompleter() {
        return new ParameterAutoCompleter.DummyDropAutoCompleter();
    }

    /**
     * @param rpmp pre-train modeling parameters (not used here)
     * @return always {@code false}
     */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return false;
    }

    /**
     * Computes the number of grid points of a Lasso hyperparameter space.
     *
     * @param space the space, cast to {@link PredictionModelingParams.LassoHyperparametersSpace}
     * @return the length of the alpha dimension in manual mode, {@code 1} otherwise
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.LassoHyperparametersSpace lassoSpace = (PredictionModelingParams.LassoHyperparametersSpace)space;
        if (lassoSpace.alphaMode == PredictionModelingParams.LassoSelectAlphaMode.MANUAL) {
            return lassoSpace.alpha.getLength();
        }
        return 1;
    }

    /**
     * Expands the modeling parameters into the modeling sets to train.
     *
     * @param pmp     modeling parameters holding {@code lasso_regression}
     * @param task    the prediction task (not used here)
     * @param gsFolds multiplier applied to the search size when estimating the number of trainings
     * @return an empty list when the Lasso space is missing or disabled,
     *         otherwise a list with a single modeling set
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.LassoHyperparametersSpace space = pmp.lasso_regression;
        List<WorkSet.ModelingSet> modelingSets = new ArrayList<WorkSet.ModelingSet>();
        if (space != null && space.enabled) {
            PreTrainPredictionModelingParams preTrain = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.LASSO_REGRESSION, pmp);
            preTrain.lasso_grid = space;
            // The modeling set is created before gridLength is filled in, as in the original ordering.
            WorkSet.ModelingSet modelingSet = new WorkSet.ModelingSet(preTrain);
            preTrain.gridLength = this.getSearchSize(preTrain.grid_search_params, space);
            if (preTrain.gridLength > 1) {
                // One training per search point and per gsFolds unit, plus one more.
                modelingSet.estimatedTrains = preTrain.gridLength * gsFolds + 1;
            } else {
                modelingSet.estimatedTrains = 1;
            }
            modelingSets.add(modelingSet);
        }
        return modelingSets;
    }

    /**
     * @param coreParams resolved core parameters (not used here)
     * @return always {@code true}
     */
    @Override
    public boolean isJavaCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /**
     * @param coreParams resolved core parameters (not used here)
     * @return always {@code true}
     */
    @Override
    public boolean isSQLCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isPMMLCompatible() {
        return true;
    }

    /**
     * Produces pre-train parameters whose alpha dimension holds only the trained alpha.
     *
     * @param optimized   post-train parameters providing {@code lasso.alpha}
     * @param usedToTrain pre-train parameters passed to {@code getCopyWithGridStrategy}
     * @return the object returned by {@code getCopyWithGridStrategy}, with its alpha
     *         dimension set to a single-value grid
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams regridified = this.getCopyWithGridStrategy(usedToTrain);
        // Note: alphaMode is left as it is on the copy; only the alpha values change.
        regridified.lasso_grid.alpha.setToSingleValueGrid(
                Double.valueOf(optimized.lasso.alpha));
        return regridified;
    }

    /**
     * Writes the regridified Lasso space into the target modeling parameters and enables it.
     *
     * @param target      modeling parameters whose {@code lasso_regression} is replaced
     * @param optimized   post-train parameters providing the trained alpha
     * @param usedToTrain pre-train parameters used for the training
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        target.lasso_regression = this.regridifyToPreTrain(optimized, usedToTrain).lasso_grid;
        target.lasso_regression.enabled = true;
    }

    /**
     * Copies the Lasso space used for training into the target modeling parameters and enables it.
     *
     * <p>The space object is shared, not copied: enabling it also affects
     * {@code usedToTrain.lasso_grid}.
     *
     * @param target      modeling parameters whose {@code lasso_regression} is replaced
     * @param usedToTrain pre-train parameters providing {@code lasso_grid}
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        PredictionModelingParams.LassoHyperparametersSpace trainedSpace = usedToTrain.lasso_grid;
        // Assign first, then enable, so the target is updated even if the space is null.
        target.lasso_regression = trainedSpace;
        trainedSpace.enabled = true;
    }
}

