/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scoring.exports.pmml.models;

import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.scoring.exports.pmml.XML;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLThresholdFunctions;
import com.dataiku.scoring.pipelines.ClassificationPipeline;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class PMMLModel {
    /**
     * Mining schema of the model (the list of mining fields).
     * Tagged with {@code @XML.Element}; presumably emitted as a child element by the
     * project's XML serializer (serializer not visible here — confirm).
     */
    @XML.Element
    public PMMLMiningSchema MiningSchema;
    /**
     * Output section of the model (the list of output fields).
     * Tagged with {@code @XML.Element}, like {@link #MiningSchema}.
     */
    @XML.Element
    public PMMLOutput Output;

    public static class PMMLOutput {
        /**
         * Output fields of this section, in the order they were added.
         * Left {@code null} by the implicit constructor; the static factory methods of
         * this class assign a fresh mutable list.
         */
        @XML.Element
        public List<PMMLOutputField> OutputField;

        /**
         * Builds the output section for a plain regression: a single continuous,
         * double-typed {@code predictedValue} field named {@code "prediction"}.
         *
         * <p>NOTE(review): the field's {@code value} attribute is the literal string
         * {@code "className"}, whereas {@link #regression(String)} passes an actual class
         * name — confirm the literal is intended.
         *
         * @return a new output holding exactly one field
         */
        public static PMMLOutput regression() {
            PMMLOutput result = new PMMLOutput();
            result.OutputField = new ArrayList<PMMLOutputField>(1);
            result.OutputField.add(new PMMLOutputField("prediction", "predictedValue", "continuous", "double", "className"));
            return result;
        }

        /**
         * Builds the output section for a per-class regression: a single continuous,
         * double-typed {@code predictedValue} field named {@code "prediction_" + className}
         * whose {@code value} attribute is {@code className}.
         *
         * @param className class name used both as the field-name suffix and as the
         *                  field's {@code value} attribute
         * @return a new output holding exactly one field
         */
        public static PMMLOutput regression(String className) {
            String fieldName = "prediction_" + className;
            PMMLOutput result = new PMMLOutput();
            result.OutputField = new ArrayList<PMMLOutputField>(1);
            result.OutputField.add(new PMMLOutputField(fieldName, "predictedValue", "continuous", "double", className));
            return result;
        }

        /**
         * Builds the output section for a binary classifier: a thresholded categorical
         * {@code "prediction"} field (feature {@code "decision"}) followed by one
         * {@code "proba_<class>"} probability field per entry of {@code classNames}.
         *
         * <p>{@code threshold} and {@code sigmoidNormalization} are forwarded unchanged to
         * {@link PMMLThresholdOutputField#from}; their exact effect is defined there.
         *
         * @param classNames           class labels, one probability field is emitted per label
         * @param threshold            decision threshold forwarded to the threshold function
         * @param sigmoidNormalization flag forwarded to the threshold function
         * @return a new output holding {@code 1 + classNames.length} fields
         */
        public static PMMLOutput binaryClassification(String[] classNames, double threshold, boolean sigmoidNormalization) {
            PMMLOutputField decisionTemplate = new PMMLOutputField("prediction", "decision", "categorical", "string", "");
            PMMLThresholdOutputField prediction = PMMLThresholdOutputField.from(classNames, threshold, sigmoidNormalization, decisionTemplate);

            PMMLOutput out = new PMMLOutput();
            // One decision field plus one probability field per class: size the list accordingly
            // instead of the former hard-coded capacity of 1.
            out.OutputField = new ArrayList<PMMLOutputField>(1 + classNames.length);
            out.OutputField.add(prediction);
            PMMLOutput.addOutputProbaForEachClass(classNames, out);
            return out;
        }

        /**
         * Builds the output section for a multiclass classifier: a categorical
         * {@code predictedValue} field named {@code "prediction"} followed by one
         * {@code "proba_<class>"} probability field per entry of {@code classes}.
         *
         * @param classes class labels, one probability field is emitted per label
         * @return a new output holding {@code 1 + classes.length} fields
         */
        public static PMMLOutput multiclassClassification(String[] classes) {
            PMMLOutput out = new PMMLOutput();
            // Prediction field plus one probability field per class: size the list accordingly
            // instead of the former hard-coded capacity of 1.
            out.OutputField = new ArrayList<PMMLOutputField>(1 + classes.length);
            out.OutputField.add(new PMMLOutputField("prediction", "predictedValue", "categorical", "string", ""));
            PMMLOutput.addOutputProbaForEachClass(classes, out);
            return out;
        }

        /**
         * Appends one continuous, double-typed {@code probability} field per class to
         * {@code out.OutputField}. Each field is named {@code "proba_" + className} and
         * carries the class name as its {@code value} attribute.
         *
         * @param classNames class labels, processed in array order
         * @param out        output whose (already initialised) field list is appended to
         */
        private static void addOutputProbaForEachClass(String[] classNames, PMMLOutput out) {
            for (int i = 0; i < classNames.length; i++) {
                String label = classNames[i];
                out.OutputField.add(new PMMLOutputField("proba_" + label, "probability", "continuous", "double", label));
            }
        }

        /**
         * Builds an output section with one segment-bound field per class. For each class
         * {@code c} a continuous, double-typed {@code predictedValue} field named
         * {@code "prediction_" + c} is created, with {@code c} used both as its
         * {@code value} attribute and as its {@code segmentId}.
         *
         * @param classes class labels, processed in array order
         * @return a new output holding {@code classes.length} fields
         */
        public static PMMLOutput pseudoProbabilisticWithSegmentId(String[] classes) {
            PMMLOutput out = new PMMLOutput();
            // Exactly one field per class is added: size the list accordingly instead of the
            // former hard-coded capacity of 1.
            out.OutputField = new ArrayList<PMMLOutputField>(classes.length);
            for (String segmentClass : classes) {
                PMMLOutputField base = new PMMLOutputField("prediction_" + segmentClass, "predictedValue", "continuous", "double", segmentClass);
                out.OutputField.add(new PMMLOutputFieldWithSegment(base, segmentClass));
            }
            return out;
        }

        /**
         * Builds an output section exposing only the second class ({@code classes[1]}) as a
         * continuous, double-typed {@code predictedValue} field named
         * {@code "prediction_" + classes[1]}, with that class as the {@code value} attribute.
         *
         * @param classes class labels; must contain at least two entries
         * @return a new output holding exactly one field
         * @throws IllegalArgumentException if {@code classes} is {@code null} or has fewer
         *                                  than two entries
         */
        public static PMMLOutput outputOnlyClassOnePredictedValue(String[] classes) {
            // Fail with an explicit message instead of a bare NPE / ArrayIndexOutOfBoundsException.
            if (classes == null || classes.length < 2) {
                throw new IllegalArgumentException("At least two classes are required to output class one, got "
                        + (classes == null ? "null" : String.valueOf(classes.length)));
            }
            String classOne = classes[1];
            PMMLOutput out = new PMMLOutput();
            out.OutputField = new ArrayList<PMMLOutputField>(1);
            out.OutputField.add(new PMMLOutputField("prediction_" + classOne, "predictedValue", "continuous", "double", classOne));
            return out;
        }

        /**
         * Builds an output section exposing only the second class ({@code classes[1]}) as a
         * continuous, double-typed {@code probability} field. The field is named
         * {@code "prediction_" + classes[1]} (note the {@code prediction_} prefix even though
         * the feature is {@code probability}) and carries that class as its {@code value}.
         *
         * @param classes class labels; must contain at least two entries
         * @return a new output holding exactly one field
         * @throws IllegalArgumentException if {@code classes} is {@code null} or has fewer
         *                                  than two entries
         */
        public static PMMLOutput outputOnlyClassOneProbability(String[] classes) {
            // Fail with an explicit message instead of a bare NPE / ArrayIndexOutOfBoundsException.
            if (classes == null || classes.length < 2) {
                throw new IllegalArgumentException("At least two classes are required to output class one, got "
                        + (classes == null ? "null" : String.valueOf(classes.length)));
            }
            String classOne = classes[1];
            PMMLOutput out = new PMMLOutput();
            out.OutputField = new ArrayList<PMMLOutputField>(1);
            out.OutputField.add(new PMMLOutputField("prediction_" + classOne, "probability", "continuous", "double", classOne));
            return out;
        }

        /**
         * A single output field described by five string attributes
         * ({@code name}, {@code feature}, {@code optype}, {@code dataType}, {@code value}),
         * each tagged with {@code @XML.Attribute}.
         */
        public static class PMMLOutputField {
            /** Name of the output field. */
            @XML.Attribute
            public String name;
            /** Kind of result exposed; values used in this file: "predictedValue", "probability", "decision". */
            @XML.Attribute
            public String feature;
            /** Operational type; values used in this file: "continuous", "categorical". */
            @XML.Attribute
            public String optype;
            /** Data type; values used in this file: "double", "string". */
            @XML.Attribute
            public String dataType;
            /** Value attribute; the factories in this file pass a class name, an empty string or a literal. */
            @XML.Attribute
            public String value;

            /**
             * Creates a field from its five attributes, stored as given (no validation).
             */
            public PMMLOutputField(String name, String feature, String optype, String dataType, String value) {
                this.name = name;
                this.feature = feature;
                this.optype = optype;
                this.dataType = dataType;
                this.value = value;
            }

            /**
             * Copy constructor: duplicates the five attributes of {@code outputField}.
             *
             * @param outputField field to copy; must not be {@code null}
             */
            public PMMLOutputField(PMMLOutputField outputField) {
                this(outputField.name, outputField.feature, outputField.optype, outputField.dataType, outputField.value);
            }
        }

        /**
         * Output field whose feature is always {@code "decision"} and which carries a nested
         * {@code Apply} element holding a {@link PMMLThresholdFunctions.PMMLApplyIfGreaterThan}.
         */
        public static class PMMLThresholdOutputField extends PMMLOutputField {
            /** Threshold function attached to this field; tagged with {@code @XML.Element}. */
            @XML.Element
            public PMMLThresholdFunctions.PMMLApplyIfGreaterThan Apply;

            /**
             * Copies the attributes of {@code outputField}, then forces {@code feature} to
             * {@code "decision"} and attaches {@code apply}.
             *
             * @param outputField template field whose attributes are copied
             * @param apply       threshold function to attach
             */
            public PMMLThresholdOutputField(PMMLOutputField outputField, PMMLThresholdFunctions.PMMLApplyIfGreaterThan apply) {
                super(outputField);
                this.feature = "decision";
                this.Apply = apply;
            }

            /**
             * Creates a threshold field from {@code outputField}, with the threshold function
             * obtained from {@code PMMLThresholdFunctions.PMMLApplyIfGreaterThan.from(classNames, threshold, sigmoidNormalization)}.
             *
             * @param classNames           class labels forwarded to the threshold function factory
             * @param threshold            threshold forwarded to the threshold function factory
             * @param sigmoidNormalization flag forwarded to the threshold function factory
             * @param outputField          template field whose attributes are copied
             * @return the new threshold output field
             */
            public static PMMLThresholdOutputField from(String[] classNames, double threshold, boolean sigmoidNormalization, PMMLOutputField outputField) {
                PMMLThresholdFunctions.PMMLApplyIfGreaterThan thresholdApply =
                        PMMLThresholdFunctions.PMMLApplyIfGreaterThan.from(classNames, threshold, sigmoidNormalization);
                return new PMMLThresholdOutputField(outputField, thresholdApply);
            }
        }

        /**
         * Output field that additionally carries a {@code segmentId} attribute.
         */
        public static class PMMLOutputFieldWithSegment extends PMMLOutputField {
            /** Identifier of the segment this field refers to; tagged with {@code @XML.Attribute}. */
            @XML.Attribute
            public String segmentId;

            /**
             * Copies the attributes of {@code outputField} and records {@code segmentId}.
             *
             * @param outputField template field whose attributes are copied
             * @param segmentId   segment identifier stored as given
             */
            public PMMLOutputFieldWithSegment(PMMLOutputField outputField, String segmentId) {
                super(outputField);
                this.segmentId = segmentId;
            }
        }
    }

    public static class PMMLMiningSchema {
        /**
         * Mining fields of this schema. Holds the very list instance passed to the
         * constructor (no copy is made). Tagged with {@code @XML.Element}.
         */
        @XML.Element
        public List<PMMLMiningField> MiningField;

        /**
         * Private constructor; instances are obtained through the static factories.
         * The list is stored by reference, not copied.
         *
         * @param miningField mining fields of the schema
         */
        private PMMLMiningSchema(List<PMMLMiningField> miningField) {
            this.MiningField = miningField;
        }

        /**
         * Wraps an existing list of mining fields into a schema. The list is stored by
         * reference, so later changes to it are visible through the returned schema.
         *
         * @param fields mining fields of the schema
         * @return a new schema backed by {@code fields}
         */
        public static PMMLMiningSchema fromFields(List<PMMLMiningField> fields) {
            PMMLMiningSchema schema = new PMMLMiningSchema(fields);
            return schema;
        }

        /**
         * Builds a mining schema containing one mining field per feature of
         * {@code rppp.per_feature} whose role is {@code FeaturePreprocessingParams.Role.INPUT}.
         * Fields are added in the iteration order of {@code rppp.per_feature}.
         *
         * <p>The pipeline is only consulted for the number of preprocessing output columns,
         * which serves as the initial capacity of the field list.
         *
         * @param pipe classification pipeline; its preprocessing output columns size the list
         * @param rppp resolved preprocessing parameters providing the per-feature roles
         * @return a new schema listing the input features
         */
        public static PMMLMiningSchema createFrom(ClassificationPipeline pipe, ResolvedPredictionPreprocessingParams rppp) {
            // Kept from the original: the column count is used purely as a capacity hint.
            String[] cols = pipe.getPreprocessing().getOutputColumns();
            List<PMMLMiningField> fields = new ArrayList<PMMLMiningField>(cols.length);
            // Wildcard entry type instead of a raw Map.Entry; the casts mirror the original ones.
            for (Map.Entry<?, ?> entry : rppp.per_feature.entrySet()) {
                FeaturePreprocessingParams params = (FeaturePreprocessingParams) entry.getValue();
                if (params.role == FeaturePreprocessingParams.Role.INPUT) {
                    fields.add(new PMMLMiningField((String) entry.getKey()));
                }
            }
            return PMMLMiningSchema.fromFields(fields);
        }

        /**
         * A single mining field: a feature name plus three attributes that are fixed for
         * every instance ({@code usageType="active"}, {@code missingValueTreatment="asIs"},
         * {@code invalidValueTreatment="asIs"}). All four are tagged with {@code @XML.Attribute}.
         */
        public static class PMMLMiningField {
            /** Name of the feature this mining field refers to. */
            @XML.Attribute
            public String name;
            /** Always {@code "active"}. */
            @XML.Attribute
            public final String usageType = "active";
            /** Always {@code "asIs"}. */
            @XML.Attribute
            public final String missingValueTreatment = "asIs";
            /** Always {@code "asIs"}. */
            @XML.Attribute
            public final String invalidValueTreatment = "asIs";

            /**
             * Creates a mining field for the given feature name.
             *
             * @param name feature name, stored as given
             */
            public PMMLMiningField(String name) {
                this.name = name;
            }
        }
    }
}

