/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import java.util.ArrayList;
import java.util.List;

/**
 * Algorithm metadata for the neural network prediction algorithm, shown to users as
 * "Single Layer Perceptron".
 *
 * <p>Only {@code layer_sizes} contributes to the grid length computed by
 * {@link #getGridLength}; it is also the only value pinned back into the grid by
 * {@link #regridifyToPreTrain}.
 */
public class NeuralNetworkMeta
extends PyMemoryAlgorithmMeta {

    /**
     * Returns the fixed display name of this algorithm.
     *
     * @param rpmp pre-train parameters (unused; the name does not depend on them)
     * @return always {@code "Single Layer Perceptron"}
     */
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Single Layer Perceptron";
    }

    /**
     * Builds the description shown before the hyperparameter search.
     *
     * @param rpmp pre-train parameters; {@code neural_network_grid} is read from it
     * @return a description carrying the search size and the candidate layer sizes
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.NeuralNetworkHyperparametersSpace space = rpmp.neural_network_grid;
        return new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, space))
                .withMVParam("layer_sizes", space.layer_sizes);
    }

    /**
     * Builds the description shown after training.
     *
     * <p>{@code layer_sizes} is taken from the post-train (selected) parameters. {@code activation},
     * {@code alpha} and {@code solver} are read from the pre-train space; only {@code layer_sizes}
     * is counted in {@link #getGridLength}.
     *
     * @param descBefore the pre-search description (unused here)
     * @param before     the parameters used to launch training
     * @param after      the parameters resulting from training; {@code after.neural_network} is
     *                   dereferenced without a null check
     * @return the post-search description
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        ModelTrainInfo.PostSearchDescription ps2 = new ModelTrainInfo.PostSearchDescription();
        return ps2
                .withSVParam("layer_sizes", after.neural_network.layer_sizes)
                .withSVParam("activation", before.neural_network_grid.activation)
                .withSVParam("alpha", Float.valueOf(before.neural_network_grid.alpha))
                .withSVParam("solver", before.neural_network_grid.solver);
    }

    /**
     * Records range violations of the neural network settings into {@code checks}.
     *
     * <p>Alpha may be zero, so it is checked for non-negativity. Batch size, learning rate,
     * momentum, both betas and epsilon must be strictly positive.
     *
     * @param pmp    modeling parameters; {@code pmp.neural_network} is dereferenced without a
     *               null check
     * @param task   the ML task (unused here)
     * @param checks the collector the checks are reported to
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.NeuralNetworkHyperparametersSpace nn = pmp.neural_network;
        checks.checkNumericalDimension(nn.layer_sizes, "Layer size");
        // The check accepts 0, so the message says "non-negative" (it previously said "positive").
        checks.checkNonNegative(nn.alpha, "Neural network alpha must be non-negative.");
        checks.checkPositive(nn.batch_size, "Batch size should be positive.");
        checks.checkAllPositive(new float[]{nn.beta_1, nn.beta_2, nn.epsilon}, "Beta parameters and epsilon should all be positive.");
        checks.checkPositive(nn.learning_rate_init, "Init learning rate should be positive");
        checks.checkPositive(nn.momentum, "Momentum should be positive");
    }

    /**
     * Returns the number of grid points, which is the number of candidate layer sizes.
     *
     * @param space the hyperparameter space; it is cast to
     *              {@code NeuralNetworkHyperparametersSpace}, and any other type throws
     *              {@link ClassCastException}
     * @return {@code layer_sizes.getLength()}
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.NeuralNetworkHyperparametersSpace neural_network_space = (PredictionModelingParams.NeuralNetworkHyperparametersSpace)space;
        return neural_network_space.layer_sizes.getLength();
    }

    /**
     * Expands the neural network settings into modeling sets to train.
     *
     * @param pmp     modeling parameters
     * @param task    the ML task (unused here)
     * @param gsFolds number of grid-search folds, used to estimate the number of trainings
     * @return a new mutable list. It is empty when the neural network is absent or disabled;
     *         otherwise it holds a single modeling set
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.NeuralNetworkHyperparametersSpace rp = pmp.neural_network;
        ArrayList<WorkSet.ModelingSet> ret = new ArrayList<WorkSet.ModelingSet>();
        if (rp == null || !rp.enabled) {
            return ret;
        }
        PreTrainPredictionModelingParams rcmp = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.NEURAL_NETWORK, pmp);
        rcmp.neural_network_grid = rp;
        rcmp.gridLength = this.getSearchSize(rcmp.grid_search_params, rp);
        WorkSet.ModelingSet ms = new WorkSet.ModelingSet(rcmp);
        // With a real grid, each point is trained once per fold, plus one extra training
        // (presumably the final refit on the best point; confirm). Otherwise a single training.
        ms.estimatedTrains = rcmp.gridLength > 1 ? rcmp.gridLength * gsFolds + 1 : 1;
        ret.add(ms);
        return ret;
    }

    /** Always {@code true} for this algorithm. */
    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    /** Always {@code true}, whatever the core parameters. */
    @Override
    public boolean isJavaCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Always {@code true}, whatever the core parameters. */
    @Override
    public boolean isPythonCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /** Always {@code true}, whatever the core parameters. */
    @Override
    public boolean isSQLCompatible(ResolvedClassicalPredictionCoreParams coreParams) {
        return true;
    }

    /**
     * Builds pre-train parameters from a copy of {@code usedToTrain}, with {@code layer_sizes}
     * pinned to the value selected by training.
     *
     * @param optimized   the post-train parameters. When {@code neural_network} is null or its
     *                    {@code layer_sizes} is not positive, the copied grid is left as-is
     * @param usedToTrain the parameters used to train; they are copied via
     *                    {@code getCopyWithGridStrategy}
     * @return the copied (and possibly pinned) pre-train parameters
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams params = this.getCopyWithGridStrategy(usedToTrain);
        if (optimized.neural_network != null && optimized.neural_network.layer_sizes > 0) {
            params.neural_network_grid.layer_sizes.setToSingleValueGrid(Long.valueOf(optimized.neural_network.layer_sizes));
        }
        return params;
    }

    /**
     * Writes the regridified neural network space into {@code target} and enables it.
     *
     * @param target      the modeling parameters to update in place
     * @param optimized   the post-train parameters
     * @param usedToTrain the parameters used to train
     */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrain = this.regridifyToPreTrain(optimized, usedToTrain);
        target.neural_network = preTrain.neural_network_grid;
        target.neural_network.enabled = true;
    }

    /**
     * Restores into {@code target} the neural network space used to train, and enables it.
     *
     * <p>NOTE(review): {@code target.neural_network} ends up referencing the same object as
     * {@code usedToTrain.neural_network_grid}. Setting {@code enabled} therefore also mutates
     * {@code usedToTrain}. Confirm that this aliasing is intended.
     *
     * @param target      the modeling parameters to update in place
     * @param usedToTrain the parameters used to train
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.neural_network = usedToTrain.neural_network_grid;
        target.neural_network.enabled = true;
    }
}

