/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.spark;

import com.dataiku.dip.analysis.ml.MLFlowUtils;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionScoringJobDef;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.spark.ml.preprocessing.CoercionHandler$;
import com.dataiku.dip.spark.ml.preprocessing.DateNormalizationHandler$;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.apache.commons.lang3.tuple.Pair;
import com.dataiku.scoring.Observation;
import com.dataiku.scoring.Try;
import com.dataiku.scoring.builders.Build;
import com.dataiku.scoring.pipelines.Pipeline;
import com.dataiku.scoring.pipelines.Result;
import com.dataiku.scoring.util.RawObservation;
import java.io.File;
import java.io.Serializable;
import java.net.URL;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.ShortType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.StructType$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.PartialFunction;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.GenTraversableOnce;
import scala.collection.IndexedSeq;
import scala.collection.JavaConversions$;
import scala.collection.Map;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Set;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.Buffer$;
import scala.collection.mutable.Iterable$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;

/**
 * Spark DataFrame scoring helpers built on {@code com.dataiku.scoring} pipelines.
 *
 * <p>Rows are scored one by one with a {@link Pipeline}. Selected input columns are kept
 * (preferring their backed-up {@code "__dku_in_"}-prefixed copies when present), and the
 * pipeline's computed columns are appended after them.
 *
 * <p>NOTE(review): this is CFR decompiler output of a Scala {@code object}.
 * <ul>
 *   <li>{@code MODULE$} is the singleton instance. The static initializer creates it, and the
 *       private constructor assigns it.</li>
 *   <li>Lambdas cast to {@code Function1 & Serializable} are Scala closures.</li>
 *   <li>Anonymous {@code scala.Serializable} classes with {@code applyOrElse}/{@code isDefinedAt}
 *       are Scala partial functions.</li>
 *   <li>Some constructs are decompiler artifacts and not valid Java as written. Examples are
 *       anonymous classes receiving constructor arguments, and {@code partCols} typed as
 *       {@code Nil$} while holding a {@code Seq}.</li>
 * </ul>
 */
public final class ScoringUtils$ {
    /** Singleton instance of the Scala object; set by the private constructor. */
    public static ScoringUtils$ MODULE$;

    static {
        // Running the private constructor publishes the instance into MODULE$.
        new ScoringUtils$();
    }

    /**
     * Builds a raw scoring observation from some columns of a Spark row.
     *
     * <p>Each column is read with {@code row.getAs(col)}:
     * <ul>
     *   <li>{@code String} values are added as-is.</li>
     *   <li>{@code Double} values are added as numbers.</li>
     *   <li>{@code null} values are skipped (not added to the builder).</li>
     * </ul>
     *
     * @param row the Spark row to read values from
     * @param columns names of the columns to copy into the observation
     * @return the observation built from the collected values
     * @throws MatchError if a column holds a value that is neither a String, a Double nor null
     */
    public RawObservation createObservation(Row row, Seq<String> columns) {
        Observation.Builder builder = Observation.Utils.rawBuilder();
        columns.foreach((Function1 & Serializable & scala.Serializable)col -> {
            double d;
            String string;
            Object object = row.getAs(col);
            // Decompiled Scala `case s: String`.
            if (object instanceof String && (string = (String)object) != null) {
                return builder.with(col, string);
            }
            // Decompiled Scala `case d: Double`; the boxed `!= null` test is always true.
            if (object instanceof Double && BoxesRunTime.boxToDouble((double)(d = BoxesRunTime.unboxToDouble((Object)object))) != null) {
                return builder.with(col, (Number)Predef$.MODULE$.double2Double(d));
            }
            // Decompiled Scala `case null => ()`: missing values are simply not added.
            if (object == null) {
                return BoxedUnit.UNIT;
            }
            // Any other runtime type (e.g. Integer, Long, Boolean) is rejected.
            throw new MatchError(object);
        });
        return (RawObservation)builder.build();
    }

    /**
     * Computes the Spark schema of the input columns kept in the scored output.
     *
     * <p>The requested names are processed in order. For each name:
     * <ul>
     *   <li>If the input schema has a field {@code "__dku_in_" + name}, that field is used,
     *       renamed to {@code name}.</li>
     *   <li>Otherwise, if the input has a field {@code name}, that field is used unchanged.</li>
     *   <li>Otherwise the name is dropped, because {@code collect} skips undefined inputs.</li>
     * </ul>
     *
     * @param schemaIn schema of the incoming rows
     * @param columnsOut requested output column names
     * @return the schema of the kept columns, in {@code columnsOut} order
     */
    public StructType scoredSchema(StructType schemaIn, Seq<String> columnsOut) {
        Set inCols = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])schemaIn.fieldNames())).toSet();
        // Decompiled Scala partial function, applied through Seq.collect.
        return StructType$.MODULE$.apply((Seq)columnsOut.collect((PartialFunction)new scala.Serializable(inCols, schemaIn){
            public static final long serialVersionUID = 0L;
            private final Set inCols$1;
            private final StructType schemaIn$1;

            public final <A1 extends String, B1> B1 applyOrElse(A1 x1, Function1<A1, B1> function1) {
                A1 A1 = x1;
                // Prefer the backed-up "__dku_in_" copy, renamed back to the plain name.
                if (this.inCols$1.contains((Object)new StringBuilder(9).append("__dku_in_").append(A1).toString())) {
                    StructField qual$1 = this.schemaIn$1.apply(new StringBuilder(9).append("__dku_in_").append(A1).toString());
                    A1 x$1 = A1;
                    DataType x$2 = qual$1.copy$default$2();
                    boolean x$3 = qual$1.copy$default$3();
                    Metadata x$4 = qual$1.copy$default$4();
                    return (B1)qual$1.copy(x$1, x$2, x$3, x$4);
                }
                // Otherwise fall back to the column of the same name.
                if (this.inCols$1.contains(A1)) {
                    return (B1)this.schemaIn$1.apply(A1);
                }
                return (B1)function1.apply(x1);
            }

            // Must agree with applyOrElse: defined when either name variant exists.
            public final boolean isDefinedAt(String x1) {
                String string = x1;
                if (this.inCols$1.contains((Object)new StringBuilder(9).append("__dku_in_").append(string).toString())) {
                    return true;
                }
                return this.inCols$1.contains((Object)string);
            }
            {
                this.inCols$1 = inCols$1;
                this.schemaIn$1 = schemaIn$1;
            }
        }, Seq$.MODULE$.canBuildFrom()));
    }

    /**
     * Resolves, for each requested output column, the input column name to read its value from.
     *
     * <p>Despite its name, this method returns column names, not a schema. For each name in
     * {@code columnsOut}, in order:
     * <ul>
     *   <li>It returns {@code "__dku_in_" + name} if that column exists in {@code schemaIn}.</li>
     *   <li>Otherwise it returns {@code name} if that column exists.</li>
     *   <li>Otherwise the name is dropped.</li>
     * </ul>
     * This mirrors the selection done by {@link #scoredSchema}.
     *
     * @param schemaIn schema of the incoming rows
     * @param columnsOut requested output column names
     * @return the input column names to read, in {@code columnsOut} order
     */
    public Seq<String> backedUpSchema(StructType schemaIn, Seq<String> columnsOut) {
        Set inCols = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])schemaIn.fieldNames())).toSet();
        // Decompiled Scala partial function, applied through Seq.collect.
        return (Seq)columnsOut.collect((PartialFunction)new scala.Serializable(inCols){
            public static final long serialVersionUID = 0L;
            private final Set inCols$2;

            public final <A1 extends String, B1> B1 applyOrElse(A1 x1, Function1<A1, B1> function1) {
                A1 A1 = x1;
                if (this.inCols$2.contains((Object)new StringBuilder(9).append("__dku_in_").append(A1).toString())) {
                    return (B1)new StringBuilder(9).append("__dku_in_").append(A1).toString();
                }
                if (this.inCols$2.contains(A1)) {
                    return (B1)A1;
                }
                return (B1)function1.apply(x1);
            }

            public final boolean isDefinedAt(String x1) {
                String string = x1;
                if (this.inCols$2.contains((Object)new StringBuilder(9).append("__dku_in_").append(string).toString())) {
                    return true;
                }
                return this.inCols$2.contains((Object)string);
            }
            {
                this.inCols$2 = inCols$2;
            }
        }, Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Scores every row of {@code data} with {@code pipeline} and builds the output DataFrame.
     *
     * <p>{@code pipeline.init()} is called at the start of each RDD partition. Each row is then
     * scored with {@link #scoreRow}, and rows whose prediction fails are dropped from the
     * output.
     *
     * <p>The result schema is the output of {@link #outputSchema}: the kept input columns
     * followed by the pipeline's computed columns.
     *
     * <p>NOTE(review): {@code pipeline} is captured by the Spark closure, so it presumably must
     * be serializable. Verify this against the Pipeline implementations.
     *
     * @param data the (already coerced) input rows
     * @param pipeline the scoring pipeline
     * @param columnsIn feature column names fed into each observation
     * @param columnsOut columns of the output dataset schema
     * @return a DataFrame of the scored rows
     */
    public <RT, R extends Result<RT>> Dataset<Row> doScoreDataFrame(Dataset<Row> data, Pipeline<RT, R> pipeline, Seq<String> columnsIn, Seq<SchemaColumn> columnsOut) {
        Seq columnsOutNames = (Seq)columnsOut.map((Function1 & Serializable & scala.Serializable)x$3 -> x$3.getName(), Seq$.MODULE$.canBuildFrom());
        // Actual input column to read for each kept output column (backed-up copy preferred).
        Seq<String> backedUpColumns = this.backedUpSchema(data.schema(), (Seq<String>)columnsOutNames);
        RDD scored = data.rdd().mapPartitions((Function1 & Serializable & scala.Serializable)partitionRowsIterator -> {
            pipeline.init();
            // flatMap over Option: a None (failed prediction) yields no output row.
            return partitionRowsIterator.flatMap((Function1 & Serializable & scala.Serializable)row -> Option$.MODULE$.option2Iterable(MODULE$.scoreRow((Row)row, pipeline, columnsIn, backedUpColumns, (Seq<String>)columnsOutNames)));
        }, data.rdd().mapPartitions$default$2(), ClassTag$.MODULE$.apply(Row.class));
        return data.sqlContext().createDataFrame(scored, this.outputSchema(data.schema(), columnsOut, (Seq<Pair<String, Class<Object>>>)JavaConversions$.MODULE$.deprecated$u0020asScalaBuffer(pipeline.getComputedColumnsTypes(JavaConversions$.MODULE$.deprecated$u0020seqAsJavaList(columnsOutNames)))));
    }

    /**
     * Maps the Java class of a computed column to a Spark SQL type.
     *
     * <p>The supported classes map as follows:
     * <ul>
     *   <li>{@code String} maps to {@code StringType}.</li>
     *   <li>{@code Short} maps to {@code ShortType}.</li>
     *   <li>{@code Double} and {@code Float} both map to {@code DoubleType}.</li>
     * </ul>
     *
     * <p>NOTE(review): a {@code Float} column is declared as {@code DoubleType}. If the pipeline
     * actually emits {@code java.lang.Float} values, Spark may reject them when materializing
     * rows. Confirm which value class {@code getComputedColumnsValues} produces.
     *
     * @param javaClass the value class reported by the pipeline
     * @param columnName the computed column name, used only in the error message
     * @return the corresponding Spark data type
     * @throws UnsupportedOperationException for any other class
     */
    public DataType toSparkType(Class<?> javaClass, String columnName) {
        boolean bl;
        Class<String> JString = String.class;
        Class<Short> JShort = Short.class;
        Class<Double> JDouble = Double.class;
        Class<Float> JFloat = Float.class;
        Class<?> clazz = javaClass;
        // Each `!(a != null ? !a.equals(b) : b != null)` below is the decompiled Scala `a == b`.
        Class<String> clazz2 = JString;
        Class<?> clazz3 = clazz;
        if (!(clazz2 != null ? !clazz2.equals(clazz3) : clazz3 != null)) {
            return StringType$.MODULE$;
        }
        Class<Short> clazz4 = JShort;
        Class<?> clazz5 = clazz;
        if (!(clazz4 != null ? !clazz4.equals(clazz5) : clazz5 != null)) {
            return ShortType$.MODULE$;
        }
        Class<Double> clazz6 = JDouble;
        Class<?> clazz7 = clazz;
        // Decompiled `case JDouble | JFloat`.
        if (!(clazz6 != null ? !clazz6.equals(clazz7) : clazz7 != null)) {
            bl = true;
        } else {
            Class<Float> clazz8 = JFloat;
            Class<?> clazz9 = clazz;
            bl = !(clazz8 != null ? !clazz8.equals(clazz9) : clazz9 != null);
        }
        if (bl) {
            return DoubleType$.MODULE$;
        }
        throw new UnsupportedOperationException(new StringBuilder(48).append("Unsupported class '").append(javaClass).append("' used for computed column '").append(columnName).append("'").toString());
    }

    /**
     * Builds the full output schema.
     *
     * <p>The schema is the kept input columns (see {@link #scoredSchema}) followed by one field
     * per computed column. Each computed field is named by the pair's key, typed with
     * {@link #toSparkType}, and uses {@code StructField}'s default nullability and metadata.
     *
     * @param schemaIn schema of the incoming rows
     * @param columnsOut columns of the output dataset schema
     * @param computedColumnTypes (name, value class) pairs reported by the pipeline
     * @return the output schema
     */
    public StructType outputSchema(StructType schemaIn, Seq<SchemaColumn> columnsOut, Seq<Pair<String, Class<Object>>> computedColumnTypes) {
        Seq computedOutputsFields = (Seq)computedColumnTypes.map((Function1 & Serializable & scala.Serializable)computedOutputType -> new StructField((String)computedOutputType.getKey(), MODULE$.toSparkType((Class)computedOutputType.getValue(), (String)computedOutputType.getKey()), StructField$.MODULE$.apply$default$3(), StructField$.MODULE$.apply$default$4()), Seq$.MODULE$.canBuildFrom());
        return StructType$.MODULE$.apply((Seq)this.scoredSchema(schemaIn, (Seq<String>)((Seq)columnsOut.map((Function1 & Serializable & scala.Serializable)x$4 -> x$4.getName(), Seq$.MODULE$.canBuildFrom()))).$plus$plus((GenTraversableOnce)computedOutputsFields, Seq$.MODULE$.canBuildFrom()));
    }

    /**
     * Scores a single row.
     *
     * <p>An observation is built from {@code columnsIn} and passed to the pipeline. On success,
     * the result row contains two parts, in this order:
     * <ol>
     *   <li>the values of {@code backedUpColumns} read from {@code row};</li>
     *   <li>the computed column values, where an empty {@code Optional} becomes {@code null}.</li>
     * </ol>
     *
     * <p>NOTE(review): a failed prediction returns {@code None}, and the failure held in
     * {@code predResults} is neither logged nor rethrown. The row then silently disappears from
     * the output.
     *
     * @param row the input row
     * @param pipeline the scoring pipeline
     * @param columnsIn feature column names used to build the observation
     * @param backedUpColumns input column names copied into the output row
     * @param columnsOut output column names passed to the pipeline to compute extra columns
     * @return the scored row, or {@code None} if the prediction failed
     */
    public <RT, R extends Result<RT>> Option<Row> scoreRow(Row row, Pipeline<RT, R> pipeline, Seq<String> columnsIn, Seq<String> backedUpColumns, Seq<String> columnsOut) {
        RawObservation obs = this.createObservation(row, columnsIn);
        Try predResults = pipeline.getPredictionResults(obs);
        if (predResults.isSuccess()) {
            Buffer computedOutputsValues = (Buffer)JavaConversions$.MODULE$.deprecated$u0020asScalaBuffer(pipeline.getComputedColumnsValues(obs, (Result)predResults.get(), JavaConversions$.MODULE$.deprecated$u0020seqAsJavaList(columnsOut))).map((Function1 & Serializable & scala.Serializable)r -> ((Optional)r.getValue()).orElse(null), Buffer$.MODULE$.canBuildFrom());
            return new Some((Object)Row$.MODULE$.apply((Seq)((TraversableLike)backedUpColumns.map((Function1 & Serializable & scala.Serializable)fieldName -> row.getAs(fieldName), Seq$.MODULE$.canBuildFrom())).$plus$plus((GenTraversableOnce)computedOutputsValues, Seq$.MODULE$.canBuildFrom())));
        }
        return None$.MODULE$;
    }

    /**
     * Scores a DataFrame with the model stored in {@code modelFolder}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Read {@code core_params.json}, {@code rpreprocessing_params.json} and
     *       {@code rmodeling_params.json} from {@code modelFolder}.</li>
     *   <li>Normalize dates with the per-feature params, unless the algorithm backend is
     *       {@code MLLIB}.</li>
     *   <li>Coerce and remap columns with the per-feature params.</li>
     *   <li>For {@code PARTITIONED_DISPATCH} scoring, require a partitioned model. Then restore
     *       each partition dimension column from its {@code "__dku_in_"} copy in
     *       {@code shaken}, and add the dimensions to the feature columns.</li>
     *   <li>Select a pipeline from the prediction type and whether the algorithm has
     *       probabilities, then delegate to {@link #doScoreDataFrame}.</li>
     * </ol>
     *
     * @param shaken the input rows, before date normalization and coercion
     * @param modelFolder local folder holding the model files and params JSON
     * @param outputDsSchema schema of the output dataset; its columns drive the output
     * @param partitionUrls passed through to the pipeline builders
     * @param scoringDef scoring job definition (partition mode, threshold override)
     * @return the scored DataFrame
     * @throws IllegalArgumentException if the prediction type is unsupported. The same type is
     *     also thrown by {@code Predef.require} when dispatch scoring is requested on a
     *     non-partitioned model.
     */
    public Dataset<Row> scoreDataFrame(Dataset<Row> shaken, String modelFolder, Schema outputDsSchema, java.util.Map<String, URL> partitionUrls, PredictionScoringJobDef scoringDef) {
        Nil$ nil$;
        ResolvedClassicalPredictionCoreParams coreParams = (ResolvedClassicalPredictionCoreParams)JSON.parseFile((String)new StringBuilder(17).append(modelFolder).append("/core_params.json").toString(), ResolvedClassicalPredictionCoreParams.class);
        ResolvedClassicalPredictionPreprocessingParams rppp = (ResolvedClassicalPredictionPreprocessingParams)JSON.parseFile((String)new StringBuilder(27).append(modelFolder).append("/rpreprocessing_params.json").toString(), ResolvedClassicalPredictionPreprocessingParams.class);
        PreTrainPredictionModelingParams rpmp = (PreTrainPredictionModelingParams)JSON.parseFile((String)new StringBuilder(22).append(modelFolder).append("/rmodeling_params.json").toString(), PreTrainPredictionModelingParams.class);
        MLTask.BackendType backendType = rpmp.algorithm.backendType;
        MLTask.BackendType backendType2 = MLTask.BackendType.MLLIB;
        // Date normalization is skipped for MLLIB-backed models and applied otherwise.
        Dataset<Row> normalized = !(backendType != null ? !backendType.equals(backendType2) : backendType2 != null) ? shaken : DateNormalizationHandler$.MODULE$.normalizeDates(shaken, (Map<String, FeaturePreprocessingParams>)JavaConversions$.MODULE$.deprecated$u0020mapAsScalaMap(rppp.per_feature));
        PredictionMLTask.PredictionType predictionType = coreParams.prediction_type;
        PredictionMLTask.PredictionType predictionType2 = PredictionMLTask.PredictionType.REGRESSION;
        // Decompiled Scala `var data`, boxed so that the closure below can reassign it.
        // The fourth argument is true only for REGRESSION models; its exact meaning is defined
        // in CoercionHandler (NOTE(review): confirm there).
        ObjectRef data = ObjectRef.create(CoercionHandler$.MODULE$.coerceAndRemap(normalized, (Map<String, FeaturePreprocessingParams>)JavaConversions$.MODULE$.deprecated$u0020mapAsScalaMap(rppp.per_feature), (Option<IndexedSeq<String>>)None$.MODULE$, !(predictionType != null ? !predictionType.equals(predictionType2) : predictionType2 != null), false));
        MLFlowUtils.ModelPartitionMode modelPartitionMode = scoringDef.modelPartitionMode;
        MLFlowUtils.ModelPartitionMode modelPartitionMode2 = MLFlowUtils.ModelPartitionMode.PARTITIONED_DISPATCH;
        if (!(modelPartitionMode != null ? !modelPartitionMode.equals(modelPartitionMode2) : modelPartitionMode2 != null)) {
            Predef$.MODULE$.require(coreParams.partitionedModel.isEnabled(), (Function0 & Serializable & scala.Serializable)() -> "Partition dispatch scoring needs partitioned model");
            JavaConversions$.MODULE$.deprecated$u0020asScalaBuffer(coreParams.partitionedModel.dimensionNames).foreach((Function1 & Serializable & scala.Serializable)dimension -> {
                // Re-add each partition dimension under its plain name, taken from the
                // "__dku_in_" copy in the original (pre-coercion) `shaken` dataset.
                data.elem = ((Dataset)data.elem).withColumn(dimension, shaken.apply(new StringBuilder(9).append("__dku_in_").append(dimension).toString()));
                return BoxedUnit.UNIT;
            });
            nil$ = JavaConversions$.MODULE$.collectionAsScalaIterable((Collection)coreParams.partitionedModel.dimensionNames).toSeq();
        } else {
            nil$ = Nil$.MODULE$;
        }
        // Partition dimension names in dispatch mode, empty otherwise. The `Nil$` static type is
        // a decompiler artifact; the value is a Seq.
        Nil$ partCols = nil$;
        // Feature columns are the per-feature entries whose role is INPUT (order follows the
        // per_feature map's iteration order), followed by the partition columns.
        Seq featureCols = (Seq)((TraversableOnce)JavaConversions$.MODULE$.deprecated$u0020mapAsScalaMap(rppp.per_feature).collect((PartialFunction)new scala.Serializable(){
            public static final long serialVersionUID = 0L;

            public final <A1 extends Tuple2<String, FeaturePreprocessingParams>, B1> B1 applyOrElse(A1 x1, Function1<A1, B1> function1) {
                A1 A1 = x1;
                FeaturePreprocessingParams.Role role = ((FeaturePreprocessingParams)A1._2()).role;
                FeaturePreprocessingParams.Role role2 = FeaturePreprocessingParams.Role.INPUT;
                if (!(role != null ? !role.equals(role2) : role2 != null)) {
                    return (B1)A1._1();
                }
                return (B1)function1.apply(x1);
            }

            public final boolean isDefinedAt(Tuple2<String, FeaturePreprocessingParams> x1) {
                Tuple2<String, FeaturePreprocessingParams> tuple2 = x1;
                FeaturePreprocessingParams.Role role = ((FeaturePreprocessingParams)tuple2._2()).role;
                FeaturePreprocessingParams.Role role2 = FeaturePreprocessingParams.Role.INPUT;
                return !(role != null ? !role.equals(role2) : role2 != null);
            }
        }, Iterable$.MODULE$.canBuildFrom())).toSeq().$plus$plus((GenTraversableOnce)partCols, Seq$.MODULE$.canBuildFrom());
        File resources = new File(modelFolder);
        List columnsOut = outputDsSchema.getColumns();
        // Forced threshold when overridden, otherwise null. Only the binary probabilistic
        // pipeline below receives it.
        Double threshold = scoringDef.desc.overrideModelSpecifiedThreshold ? Predef$.MODULE$.double2Double(scoringDef.desc.forcedClassifierThreshold) : null;
        URL resourcesURL = resources.toURI().toURL();
        PredictionMLTask.PredictionType predictionType3 = coreParams.prediction_type;
        if (PredictionMLTask.PredictionType.REGRESSION.equals(predictionType3)) {
            return this.doScoreDataFrame((Dataset<Row>)((Dataset)data.elem), (Pipeline)Build.regressionPipeline((URL)resourcesURL, partitionUrls, (boolean)false), (Seq<String>)featureCols, (Seq<SchemaColumn>)JavaConversions$.MODULE$.deprecated$u0020asScalaBuffer(columnsOut));
        }
        if (PredictionMLTask.PredictionType.BINARY_CLASSIFICATION.equals(predictionType3) && rpmp.algorithm.meta.hasProbabilities(rpmp)) {
            return this.doScoreDataFrame((Dataset<Row>)((Dataset)data.elem), (Pipeline)Build.binaryProbabilisticPipeline((URL)resourcesURL, (Double)threshold, partitionUrls, (boolean)false), (Seq<String>)featureCols, (Seq<SchemaColumn>)JavaConversions$.MODULE$.deprecated$u0020asScalaBuffer(columnsOut));
        }
        if (PredictionMLTask.PredictionType.MULTICLASS.equals(predictionType3) && rpmp.algorithm.meta.hasProbabilities(rpmp)) {
            return this.doScoreDataFrame((Dataset<Row>)((Dataset)data.elem), (Pipeline)Build.multiclassProbabilisticPipeline((URL)resourcesURL, partitionUrls, (boolean)false), (Seq<String>)featureCols, (Seq<SchemaColumn>)JavaConversions$.MODULE$.deprecated$u0020asScalaBuffer(columnsOut));
        }
        // Classification without probabilities, either binary or multiclass.
        if (PredictionMLTask.PredictionType.BINARY_CLASSIFICATION.equals(predictionType3) ? true : PredictionMLTask.PredictionType.MULTICLASS.equals(predictionType3)) {
            return this.doScoreDataFrame((Dataset<Row>)((Dataset)data.elem), (Pipeline)Build.nonProbabilisticClassificationPipeline((URL)resourcesURL, partitionUrls, (boolean)false), (Seq<String>)featureCols, (Seq<SchemaColumn>)JavaConversions$.MODULE$.deprecated$u0020asScalaBuffer(columnsOut));
        }
        throw new IllegalArgumentException(new StringBuilder(29).append("Unsupported prediction type: ").append(coreParams.prediction_type).toString());
    }

    /** Private constructor of the Scala object; publishes the singleton into MODULE$. */
    private ScoringUtils$() {
        MODULE$ = this;
    }
}

