/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.spark;

import com.dataiku.dip.analysis.ml.shared.ParameterAutoCompleter;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.MLLibAlgorithmMeta;
import com.dataiku.dip.utils.ErrorContext;
import java.util.ArrayList;
import java.util.List;

/**
 * Algorithm metadata for logistic regression trained with Spark MLLib.
 *
 * <p>Provides the display name, the pre-/post-training parameter descriptions, parameter
 * validation, expansion of the hyper-parameter grid into a modeling set, backend compatibility
 * flags, and the logic that folds optimized parameters back into a single-point grid.
 */
public class MLLibLogisticRegressionMeta
extends MLLibAlgorithmMeta {

    /** Returns the fixed display name of this algorithm. */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Logistic Regression (MLLib)";
    }

    /**
     * Describes the search space before training.
     *
     * <p>{@code lambda} (regularization) and {@code alpha} (elastic net) are reported as
     * multi-valued grid dimensions; {@code max_iters} is reported as a single value.
     *
     * @param rpmp pre-train parameters holding the {@code mllib_logit_grid}
     * @return the pre-search description
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withMVParam("lambda", rpmp.mllib_logit_grid.reg_param)
                .withSVParam("max_iters", rpmp.mllib_logit_grid.max_iter)
                .withMVParam("alpha", rpmp.mllib_logit_grid.enet_param);
    }

    /**
     * Describes the parameters actually selected by training.
     *
     * <p>{@code lambda} and {@code alpha} come from the optimized parameters. {@code max_iters}
     * is read from the pre-train grid because it is not a searched dimension (see
     * {@link #expandModeling}, where only regularization and elastic net sizes contribute to the
     * grid length).
     *
     * @param descBefore the pre-search description (unused here)
     * @param before     the parameters used to launch training
     * @param after      the optimized parameters produced by training
     * @return the post-search description
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        return new ModelTrainInfo.PostSearchDescription()
                .withSVParam("lambda", after.mllib_logit.reg_param)
                .withSVParam("max_iters", before.mllib_logit_grid.max_iter)
                .withSVParam("alpha", after.mllib_logit.enet_param);
    }

    /** Returns a {@link ParameterAutoCompleter.DummyDropAutoCompleter} for this algorithm. */
    @Override
    public ParameterAutoCompleter autoCompleter() {
        return new ParameterAutoCompleter.DummyDropAutoCompleter();
    }

    /** Always {@code true}: this algorithm is declared to produce class probabilities. */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    /**
     * Validates the MLLib logistic regression settings of an ML task.
     *
     * <p>Does nothing when the algorithm is absent or disabled. Otherwise it does three things:
     * <ul>
     *   <li>adds a sparse-related warning when the target remapping has more than two entries
     *       (multiclass);</li>
     *   <li>requires a strictly positive {@code max_iter};</li>
     *   <li>checks the numerical dimensions of the regularization and elastic net grids.</li>
     * </ul>
     *
     * @param pmp    the modeling parameters to validate
     * @param task   the ML task, used to inspect the target remapping
     * @param checks the accumulator for warnings/checks
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.MLLibLogisticRegressionGridParams lp = pmp.mllib_logit;
        if (lp == null || !lp.enabled) {
            return;
        }
        // NOTE(review): assumes target_remapping is non-null whenever this algorithm is enabled — confirm.
        if (task.getPreprocessingParams().target_remapping.size() > 2) {
            checks.addWarningSparse("Multiclass logistic regression (MLLib)");
        }
        ErrorContext.check(lp.max_iter > 0, "MLLib Logistic regression: the max iterations parameter must be > 0");
        checks.checkNumericalDimension(lp.reg_param, "Regularization parameter (Logistic Regression)");
        checks.checkNumericalDimension(lp.enet_param, "Elastic net parameter (Logistic Regression)");
    }

    /**
     * Expands the task's logistic regression settings into modeling sets.
     *
     * <p>Returns an empty list when the algorithm is absent or disabled; otherwise, a single
     * modeling set whose grid length is the product of the regularization and elastic net grid
     * sizes.
     *
     * @param pmp     the modeling parameters
     * @param task    the ML task (unused here)
     * @param gsFolds number of grid-search folds
     * @return a mutable list with zero or one modeling set
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.MLLibLogisticRegressionGridParams lp = pmp.mllib_logit;
        List<WorkSet.ModelingSet> ret = new ArrayList<>();
        if (lp == null || !lp.enabled) {
            return ret;
        }
        PreTrainPredictionModelingParams rcmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.MLLIB_LOGISTIC_REGRESSION, pmp);
        rcmp.mllib_logit_grid = lp;
        // max_iter is single-valued, so only these two dimensions span the grid.
        rcmp.gridLength = lp.reg_param.getLength() * lp.enet_param.getLength();
        WorkSet.ModelingSet ms = new WorkSet.ModelingSet(rcmp);
        // With a real search: one train per grid point per fold, plus one extra
        // (presumably the final refit — confirm). Without a search: a single train.
        ms.estimatedTrains = rcmp.gridLength > 1 ? rcmp.gridLength * gsFolds + 1 : 1;
        ret.add(ms);
        return ret;
    }

    /** Always {@code true}: declared compatible with the Java backend. */
    @Override
    public boolean isJavaCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Always {@code true}: declared compatible with the Python backend. */
    @Override
    public boolean isPythonCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Always {@code true}: declared compatible with SQL scoring. */
    @Override
    public boolean isSQLCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Always {@code true}: declared compatible with PMML export. */
    @Override
    public boolean isPMMLCompatible() {
        return true;
    }

    /**
     * Builds pre-train parameters whose elastic net and regularization grids are pinned to the
     * optimized values.
     *
     * <p>NOTE(review): whether {@code getCopyWithGridStrategy} deep-copies
     * {@code mllib_logit_grid} is not visible here. If it does not, the two
     * {@code setToSingleValueGrid} calls also modify {@code usedToTrain}'s grid.
     *
     * @param optimized   parameters selected by training
     * @param usedToTrain parameters that were used to launch training
     * @return a copy with single-value grids for {@code enet_param} and {@code reg_param}
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams ret = this.getCopyWithGridStrategy(usedToTrain);
        ret.mllib_logit_grid.enet_param.setToSingleValueGrid(optimized.mllib_logit.enet_param);
        ret.mllib_logit_grid.reg_param.setToSingleValueGrid(optimized.mllib_logit.reg_param);
        return ret;
    }

    /**
     * Writes the optimized parameters back into an ML task as an enabled, single-point grid.
     *
     * @param target      the ML task modeling parameters to update
     * @param optimized   parameters selected by training
     * @param usedToTrain parameters that were used to launch training
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        target.mllib_logit = preTrain.mllib_logit_grid;
        target.mllib_logit.enabled = true;
    }

    /**
     * Restores the grid that was used for training into an ML task and enables it.
     *
     * <p>NOTE(review): {@code target.mllib_logit} becomes the same object as
     * {@code usedToTrain.mllib_logit_grid}. Setting {@code enabled} therefore also mutates
     * {@code usedToTrain}'s grid — confirm this aliasing is intended.
     *
     * @param target      the ML task modeling parameters to update
     * @param usedToTrain parameters that were used to launch training
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.mllib_logit = usedToTrain.mllib_logit_grid;
        target.mllib_logit.enabled = true;
    }
}

