/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.spark.ml.prediction.ensembles;

import com.dataiku.dip.analysis.model.prediction.EnsembleParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.spark.ml.prediction.ensembles.AverageEnsembler;
import com.dataiku.dip.spark.ml.prediction.ensembles.Ensembler;
import com.dataiku.dip.spark.ml.prediction.ensembles.EnsemblerModel;
import com.dataiku.dip.spark.ml.prediction.ensembles.LinearEnsembler;
import com.dataiku.dip.spark.ml.prediction.ensembles.LogisticEnsembler;
import com.dataiku.dip.spark.ml.prediction.ensembles.MedianEnsembler;
import com.dataiku.dip.spark.ml.prediction.ensembles.ProbaAverageEnsembler;
import com.dataiku.dip.spark.ml.prediction.ensembles.VotingEnsembler;
import java.io.Serializable;
import java.util.List;
import org.apache.log4j.Logger;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.GenTraversableOnce;
import scala.collection.JavaConversions$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/**
 * Singleton holder (Scala-object style, exposed through {@link #MODULE$}) for the helpers used to
 * build ensembles of Spark ML pipelines: tagging rows with a join id, collecting the output
 * column of every pipeline side by side, and choosing the {@link Ensembler} implementation that
 * matches an {@link EnsembleParams.EnsembleMethod}.
 */
public final class Ensemble$
implements scala.Serializable {
    /** The unique instance of this class. */
    public static final Ensemble$ MODULE$ = new Ensemble$();

    private final Logger logger = Logger.getLogger((String)"dku.spark.mllib");
    private final String JOIN_ID = "__join_id";
    private final String PRED_PREFIX = "__pred_";
    private final String PROBA_PREFIX = "__proba_";

    private Ensemble$() {
    }

    /**
     * Produces a fresh random Spark ML identifier built from the {@code "ensemble"} prefix.
     *
     * <p>NOTE(review): the method name is the JVM encoding of Scala's {@code <init>$default$4},
     * so this is presumably the default value of the fourth constructor argument of the
     * companion {@code Ensemble} class -- confirm against that class.
     *
     * @return a new identifier returned by {@code Identifiable.randomUID("ensemble")}
     */
    public String $lessinit$greater$default$4() {
        return Identifiable$.MODULE$.randomUID("ensemble");
    }

    /** @return the log4j logger registered under the name {@code "dku.spark.mllib"} */
    public Logger logger() {
        return this.logger;
    }

    /** @return the name of the row-index column added by {@link #zipWithId}: {@code "__join_id"} */
    public String JOIN_ID() {
        return this.JOIN_ID;
    }

    /** @return the prefix of the per-pipeline prediction columns: {@code "__pred_"} */
    public String PRED_PREFIX() {
        return this.PRED_PREFIX;
    }

    /** @return the prefix of the per-pipeline probability columns: {@code "__proba_"} */
    public String PROBA_PREFIX() {
        return this.PROBA_PREFIX;
    }

    /**
     * Returns a copy of {@code df} with one extra, non-nullable {@code long} column named
     * {@link #JOIN_ID()} holding the index assigned to each row by {@code RDD.zipWithIndex}.
     *
     * @param df the dataset to tag
     * @return a dataset with the schema of {@code df} followed by the join-id column
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Dataset<Row> zipWithId(Dataset<Row> df) {
        // zipWithIndex yields (row, index) pairs; rebuild each row with the index appended last.
        Function1 appendIndex = (Function1 & Serializable & scala.Serializable)ln -> {
            Tuple2 rowAndIndex = (Tuple2)ln;
            Seq rowValues = ((Row)rowAndIndex._1()).toSeq();
            Seq indexValue = (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapLongArray(new long[]{rowAndIndex._2$mcJ$sp()}));
            return Row$.MODULE$.fromSeq((Seq)rowValues.$plus$plus((GenTraversableOnce)indexValue, Seq$.MODULE$.canBuildFrom()));
        };
        StructField joinIdField = new StructField(this.JOIN_ID(), (DataType)LongType$.MODULE$, false, StructField$.MODULE$.apply$default$4());
        // StructType.add returns a new schema with the field appended after the existing ones.
        StructType schemaWithId = df.schema().add(joinIdField);
        return df.sqlContext().createDataFrame(df.rdd().zipWithIndex().map(appendIndex, ClassTag$.MODULE$.apply(Row.class)), schemaWithId);
    }

    /**
     * Runs every pipeline on {@code dataset} and gathers their {@code "prediction"} columns.
     *
     * <p>The output of the pipeline at position {@code i} is exposed as column
     * {@code PRED_PREFIX() + i}; rows are matched through the {@link #JOIN_ID()} column added by
     * {@link #zipWithId}.
     *
     * @param dataset the rows to score
     * @param pipelines the fitted pipelines, applied in iteration order
     * @return the indexed dataset joined with one prediction column per pipeline
     * @throws MatchError if the fold accumulator is {@code null}
     */
    public Dataset<Row> generateAllPredictions(Dataset<Row> dataset, Seq<PipelineModel> pipelines) {
        return this.joinPipelineOutputs(dataset, pipelines, "prediction", this.PRED_PREFIX());
    }

    /**
     * Runs every pipeline on {@code dataset} and gathers their {@code "probability"} columns.
     *
     * <p>The output of the pipeline at position {@code i} is exposed as column
     * {@code PROBA_PREFIX() + i}; rows are matched through the {@link #JOIN_ID()} column added by
     * {@link #zipWithId}.
     *
     * @param dataset the rows to score
     * @param pipelines the fitted pipelines, applied in iteration order
     * @return the indexed dataset joined with one probability column per pipeline
     * @throws MatchError if the fold accumulator is {@code null}
     */
    public Dataset<Row> generateAllProbas(Dataset<Row> dataset, Seq<PipelineModel> pipelines) {
        return this.joinPipelineOutputs(dataset, pipelines, "probability", this.PROBA_PREFIX());
    }

    /**
     * Shared implementation of {@link #generateAllPredictions} and {@link #generateAllProbas}:
     * folds over the pipelines, renaming {@code sourceColumn} of each pipeline's output to
     * {@code targetPrefix + position} and joining it onto the accumulated dataset by join id.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Dataset<Row> joinPipelineOutputs(Dataset<Row> dataset, Seq<PipelineModel> pipelines, String sourceColumn, String targetPrefix) {
        Dataset<Row> withIndex = this.zipWithId(dataset);
        // Read into a local so the serializable lambda does not need to capture this instance.
        String joinId = this.JOIN_ID();
        Tuple2 seed = new Tuple2((Object)BoxesRunTime.boxToInteger((int)0), withIndex);
        Tuple2 folded = (Tuple2)pipelines.foldLeft((Object)seed, (Function2 & Serializable & scala.Serializable)(acc, model) -> {
            // The accumulator is (position of the next pipeline, dataset joined so far).
            Tuple2 state = (Tuple2)acc;
            PipelineModel pipe = (PipelineModel)model;
            if (state == null) {
                throw new MatchError((Object)new Tuple2(acc, model));
            }
            int i = state._1$mcI$sp();
            Dataset joinedSoFar = (Dataset)state._2();
            String targetColumn = targetPrefix + i;
            // Every pipeline scores the same indexed input; only its renamed output column and
            // the join id are kept before joining.
            Dataset pipelineOutput = pipe.transform(withIndex).withColumnRenamed(sourceColumn, targetColumn).select(targetColumn, (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{joinId}));
            return new Tuple2((Object)BoxesRunTime.boxToInteger((int)(i + 1)), (Object)joinedSoFar.join(pipelineOutput, joinId));
        });
        return (Dataset)folded._2();
    }

    /**
     * Instantiates the ensembler that corresponds to {@code ens.method}.
     *
     * <p>{@code VOTE} sizes its ensembler from the {@code target_remapping} of the first entry of
     * {@code ens.preprocessing_params}, whereas {@code LOGISTIC_MODEL} uses the
     * {@code target_remapping} of {@code rppp} together with {@code ens.proba_inputs}. The other
     * methods take no argument.
     *
     * @param ens the ensemble settings; its {@code method} field selects the implementation
     * @param rppp preprocessing parameters, read only for {@code LOGISTIC_MODEL}
     * @return a new ensembler for the requested method
     * @throws MatchError if {@code ens.method} equals none of the handled methods
     */
    public Ensembler<? extends EnsemblerModel> getEnsembler(EnsembleParams ens, ResolvedPredictionPreprocessingParams rppp) {
        EnsembleParams.EnsembleMethod method = ens.method;
        if (EnsembleParams.EnsembleMethod.AVERAGE.equals(method)) {
            return new AverageEnsembler();
        }
        if (EnsembleParams.EnsembleMethod.PROBA_AVERAGE.equals(method)) {
            return new ProbaAverageEnsembler();
        }
        if (EnsembleParams.EnsembleMethod.MEDIAN.equals(method)) {
            return new MedianEnsembler();
        }
        if (EnsembleParams.EnsembleMethod.VOTE.equals(method)) {
            ResolvedPredictionPreprocessingParams firstParams = (ResolvedPredictionPreprocessingParams)JavaConversions$.MODULE$.deprecated$u0020asScalaBuffer((List)ens.preprocessing_params).head();
            return new VotingEnsembler(firstParams.target_remapping.size());
        }
        if (EnsembleParams.EnsembleMethod.LINEAR_MODEL.equals(method)) {
            return new LinearEnsembler();
        }
        if (EnsembleParams.EnsembleMethod.LOGISTIC_MODEL.equals(method)) {
            return new LogisticEnsembler(ens.proba_inputs, rppp.target_remapping.size());
        }
        throw new MatchError((Object)method);
    }

    /** Java serialization hook: substitutes any deserialized copy with {@link #MODULE$}. */
    private Object readResolve() {
        return MODULE$;
    }
}

