/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.ml;

import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

/**
 * Example: trains a {@link MultilayerPerceptronClassifier} on a LIBSVM-format data file and prints
 * the test-set accuracy.
 */
public class JavaMultilayerPerceptronClassifierExample {

    /** Data file used when no path is supplied on the command line. */
    private static final String DEFAULT_DATA_PATH =
        "data/mllib/sample_multiclass_classification_data.txt";

    /**
     * Runs the example.
     *
     * <p>Loads the data, splits it 60/40 into train/test sets (seed 1234), fits the perceptron,
     * and prints the accuracy measured on the test split.
     *
     * @param args optional; {@code args[0]} overrides the LIBSVM data path, otherwise
     *     {@value #DEFAULT_DATA_PATH} is used
     */
    public static void main(String[] args) {
        String path = (args != null && args.length > 0) ? args[0] : DEFAULT_DATA_PATH;

        SparkSession spark =
            SparkSession.builder().appName("JavaMultilayerPerceptronClassifierExample").getOrCreate();
        try {
            Dataset<Row> dataFrame = spark.read().format("libsvm").load(path);

            // Fixed seed keeps the train/test split reproducible across runs.
            Dataset<Row>[] splits = dataFrame.randomSplit(new double[] {0.6, 0.4}, 1234L);
            Dataset<Row> train = splits[0];
            Dataset<Row> test = splits[1];

            // Per the Spark MultilayerPerceptronClassifier docs, the first entry is the input layer
            // size (feature count) and the last is the output layer size (number of classes).
            // The middle entries are hidden layers. NOTE(review): {4, ..., 3} assumes 4 features
            // and 3 classes in the data file; confirm if a different path is supplied.
            int[] layers = new int[] {4, 5, 4, 3};

            MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()
                .setLayers(layers)
                .setBlockSize(128)
                .setSeed(1234L)
                .setMaxIter(100);

            MultilayerPerceptronClassificationModel model = trainer.fit(train);

            Dataset<Row> result = model.transform(test);
            Dataset<Row> predictionAndLabels = result.select("prediction", "label");

            MulticlassClassificationEvaluator evaluator =
                new MulticlassClassificationEvaluator().setMetricName("accuracy");
            System.out.println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels));
        } finally {
            // Always release the session, even if loading, training, or evaluation fails.
            spark.stop();
        }
    }
}

