/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.spark;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.SparklingAlgorithmMeta;
import com.dataiku.dip.analysis.model.preprocessing.CatFeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import java.util.ArrayList;
import java.util.List;

/**
 * Algorithm metadata for the H2O Sparkling Water GLM ("GLM (H2O)").
 *
 * <p>Builds descriptions, validates and expands the {@code glm_sparkling} modeling parameters of
 * a tabular prediction ML task. Only a single modeling set is produced per task, and the
 * post-train description reuses the pre-train {@code alpha}/{@code lambda} values.
 */
public class SparklingGLMMeta
extends SparklingAlgorithmMeta {
    /**
     * Returns the fixed display name of this algorithm.
     *
     * @param rpmp pre-train parameters (unused)
     * @return {@code "GLM (H2O)"}
     */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "GLM (H2O)";
    }

    /**
     * Describes the model before training, exposing the configured {@code alpha} and
     * {@code lambda} as single-value parameters.
     *
     * @param rpmp pre-train parameters; {@code glm_sparkling_grid} must be non-null
     * @return the pre-search description
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        return new ModelTrainInfo.PreSearchDescription(rpmp).withSVParam("alpha", Float.valueOf(rpmp.glm_sparkling_grid.alpha)).withSVParam("lambda", Float.valueOf(rpmp.glm_sparkling_grid.lambda));
    }

    /**
     * Describes the model after training. The values are read from the pre-train parameters
     * ({@code before}); {@code descBefore} and {@code after} are not consulted.
     *
     * @param descBefore pre-search description (unused)
     * @param before parameters used to train; {@code glm_sparkling_grid} must be non-null
     * @param after post-train parameters (unused)
     * @return the post-search description with {@code alpha} and {@code lambda}
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        return new ModelTrainInfo.PostSearchDescription().withSVParam("alpha", Float.valueOf(before.glm_sparkling_grid.alpha)).withSVParam("lambda", Float.valueOf(before.glm_sparkling_grid.lambda));
    }

    /**
     * Validates the H2O GLM settings when the algorithm is enabled.
     *
     * <p>Adds a single warning if any INPUT categorical feature uses no category handling,
     * delegates to {@code glm_sparkling.validate()}, then adds the sparse-data warning.
     * Does nothing when {@code glm_sparkling} is absent or disabled.
     *
     * @param pmp modeling parameters of the ML task
     * @param task the ML task whose per-feature preprocessing is inspected
     * @param checks sink for warnings
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        if (pmp.glm_sparkling == null || !pmp.glm_sparkling.enabled) {
            return;
        }
        // The message does not name a feature, so one warning is enough no matter how many
        // features are affected.
        if (hasUnhandledCategoricalInput(task)) {
            checks.addWarning("Unhandled categorical feature in H2O GLM", "H2O's GLM may give inconsistent result if no category handling is chosen (in particular if unknown categories are encountered while scoring). You may want to choose another processing method if such a situation may arise.");
        }
        pmp.glm_sparkling.validate();
        checks.addWarningSparse("GLM (H2O)");
    }

    /**
     * Tells whether any feature used as model INPUT is categorical with category handling
     * {@code NONE}.
     *
     * <p>Only INPUT features matter: unknown categories at scoring time can only hurt features
     * the model actually consumes.
     */
    private static boolean hasUnhandledCategoricalInput(PredictionMLTask.TabularPredictionMLTask task) {
        for (FeaturePreprocessingParams par : task.getPreprocessingParams().per_feature.values()) {
            // Short-circuit evaluation guarantees the cast only runs for CATEGORY features.
            if (par.role == FeaturePreprocessingParams.Role.INPUT
                    && par.type == FeaturePreprocessingParams.FeatureType.CATEGORY
                    && ((CatFeaturePreprocessingParams) par).category_handling == CatFeaturePreprocessingParams.CategoryHandlingMethod.NONE) {
                return true;
            }
        }
        return false;
    }

    /**
     * Expands the ML task settings into the modeling sets to train.
     *
     * @param pmp modeling parameters of the ML task
     * @param task the ML task (unused)
     * @param gsFolds number of grid-search folds (unused)
     * @return an empty mutable list when the algorithm is absent or disabled; otherwise a list
     *     with one {@code SPARKLING_GLM} modeling set whose grid is {@code pmp.glm_sparkling}
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        List<WorkSet.ModelingSet> ret = new ArrayList<>();
        if (pmp.glm_sparkling == null || !pmp.glm_sparkling.enabled) {
            return ret;
        }
        PreTrainPredictionModelingParams rcmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.SPARKLING_GLM, pmp);
        // NOTE(review): the grid object is shared by reference with pmp, not copied.
        rcmp.glm_sparkling_grid = pmp.glm_sparkling;
        ret.add(new WorkSet.ModelingSet(rcmp));
        return ret;
    }

    /**
     * Rebuilds pre-train parameters from a trained model. {@code optimized} is ignored; the
     * result is {@code getCopyWithGridStrategy(usedToTrain)} (presumably a copy of
     * {@code usedToTrain} — see {@code SparklingAlgorithmMeta}).
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        return this.getCopyWithGridStrategy(usedToTrain);
    }

    /**
     * Writes the regridified GLM settings back into the ML task and marks them enabled.
     *
     * @param target ML task parameters to update
     * @param optimized post-train parameters (passed through, ignored downstream)
     * @param usedToTrain parameters the model was trained with
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        target.glm_sparkling = preTrain.glm_sparkling_grid;
        target.glm_sparkling.enabled = true;
    }

    /**
     * Copies the GLM settings used to train back into the ML task and marks them enabled.
     *
     * <p>NOTE(review): this assigns the same object referenced by
     * {@code usedToTrain.glm_sparkling_grid}, so setting {@code enabled} also mutates
     * {@code usedToTrain}; it also throws NPE if that grid is null. Confirm this is intended.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.glm_sparkling = usedToTrain.glm_sparkling_grid;
        target.glm_sparkling.enabled = true;
    }
}

