/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scoring.exports.pmml;

import com.dataiku.dip.scoring.exports.pmml.XML;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLConstants;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLDerivedField;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLThresholdFunctions;
import com.dataiku.scoring.pipelines.CategoricalEncoder;
import com.dataiku.scoring.pipelines.Dummifier;
import com.dataiku.scoring.pipelines.ImputeWithValue;
import com.dataiku.scoring.pipelines.Normalization;
import com.dataiku.scoring.pipelines.PreprocessingPipeline;
import com.dataiku.scoring.pipelines.Processor;
import com.dataiku.scoring.pipelines.Rescaler;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class PMMLPreprocessing {
    /**
     * Builds every PMML derived field needed to reproduce the preprocessing pipeline.
     *
     * <p>Numerical fields are listed first, followed by categorical fields.
     *
     * @param pipe the preprocessing pipeline to translate
     * @param actions per-column normalization actions, forwarded to the numerical handling
     * @param castToFloat forwarded to the derived fields that support a float cast
     * @return a new mutable list holding numerical then categorical derived fields
     * @throws IllegalArgumentException if the pipeline uses a combination of processors
     *         that the numerical or categorical handling rejects
     */
    public static List<PMMLDerivedField> getDerivedFields(PreprocessingPipeline pipe, Map<String, Normalization.Action> actions, boolean castToFloat) {
        List<PMMLDerivedField> numerical = PMMLPreprocessing.numericalFeaturesHandling(pipe, actions, castToFloat);
        List<PMMLDerivedField> categorical = PMMLPreprocessing.categoricalFeaturesHandling(pipe, castToFloat);
        List<PMMLDerivedField> derived = new ArrayList<>(numerical.size() + categorical.size());
        derived.addAll(numerical);
        derived.addAll(categorical);
        return derived;
    }

    /**
     * Returns the stages of the pipeline that are instances of the requested processor type,
     * in pipeline order.
     *
     * @param pipe the pipeline whose stages are scanned
     * @param clazz the processor type to keep
     * @param <T> the processor type
     * @return a new mutable list of matching stages (empty when none match)
     */
    private static <T extends Processor> List<T> getProcessors(PreprocessingPipeline pipe, Class<T> clazz) {
        List<T> result = new ArrayList<>();
        for (Processor p : pipe.getStages()) {
            if (clazz.isInstance(p)) {
                // Class.cast keeps the list type-safe without an unchecked cast.
                result.add(clazz.cast(p));
            }
        }
        return result;
    }

    /**
     * Maps pipeline output column names to the names used in the PMML document.
     *
     * <p>A column is renamed with
     * {@code PMMLDerivedField.NumericalDerivedField.getTransformedName} when it is listed
     * by any {@link Rescaler} stage or is a key of any {@link ImputeWithValue} mapping;
     * every other column name is returned unchanged.
     *
     * @param outputColumns the pipeline output column names
     * @param pipe the preprocessing pipeline
     * @return a new array of the same length as {@code outputColumns}
     */
    public static String[] normalizedOutputColumns(String[] outputColumns, PreprocessingPipeline pipe) {
        // Union of rescaled columns and imputed columns: both get the transformed name.
        Set<String> transformed = new HashSet<>();
        List<Rescaler> rescalers = PMMLPreprocessing.getProcessors(pipe, Rescaler.class);
        for (Rescaler rescaler : rescalers) {
            for (String column : rescaler.getColumns()) {
                transformed.add(column);
            }
        }
        List<ImputeWithValue> imputers = PMMLPreprocessing.getProcessors(pipe, ImputeWithValue.class);
        for (ImputeWithValue imputer : imputers) {
            for (String column : imputer.getColumnMapping().keySet()) {
                transformed.add(column);
            }
        }

        String[] normalized = new String[outputColumns.length];
        for (int idx = 0; idx < outputColumns.length; idx++) {
            String name = outputColumns[idx];
            if (transformed.contains(name)) {
                normalized[idx] = PMMLDerivedField.NumericalDerivedField.getTransformedName(name);
            } else {
                normalized[idx] = name;
            }
        }
        return normalized;
    }

    /**
     * Collects, for every categorical column handled by a {@link Dummifier} or a
     * {@link CategoricalEncoder}, the list of category values known to the pipeline.
     *
     * <p>Dummifier columns are registered first; a column also handled by a categorical
     * encoder is then overwritten with the encoder's encoding keys.
     *
     * @param pipe the preprocessing pipeline
     * @return a new map from column name to a fresh list of its category values
     */
    public static Map<String, List<String>> extractCategoryLevels(PreprocessingPipeline pipe) {
        Map<String, List<String>> levelsByColumn = new HashMap<>();
        List<Dummifier> dummifiers = PMMLPreprocessing.getProcessors(pipe, Dummifier.class);
        List<CategoricalEncoder> encoders = PMMLPreprocessing.getProcessors(pipe, CategoricalEncoder.class);

        for (Dummifier dummifier : dummifiers) {
            String[] columns = dummifier.getColumns();
            List<?> levels = dummifier.getLevels();
            for (int idx = 0; idx < columns.length; idx++) {
                // levels.get(idx) holds the category values of columns[idx].
                @SuppressWarnings("unchecked")
                Collection<String> columnLevels = (Collection<String>) levels.get(idx);
                levelsByColumn.put(columns[idx], new ArrayList<>(columnLevels));
            }
        }

        for (CategoricalEncoder encoder : encoders) {
            String[] columns = encoder.getColumns();
            List<?> encodings = encoder.getEncodings();
            for (int idx = 0; idx < columns.length; idx++) {
                // The keys of the encoding map are the known category values.
                @SuppressWarnings("unchecked")
                Map<String, ?> encoding = (Map<String, ?>) encodings.get(idx);
                levelsByColumn.put(columns[idx], new ArrayList<>(encoding.keySet()));
            }
        }
        return levelsByColumn;
    }

    /**
     * Produces the derived fields for numerical features.
     *
     * <p>Only the first {@link Rescaler} and the first {@link ImputeWithValue} stage of the
     * pipeline are taken into account.
     *
     * @param pipe the preprocessing pipeline
     * @param actions per-column normalization actions, used when rescaling
     * @param castToFloat forwarded to the rescaled fields
     * @return the numerical derived fields; an empty list when the pipeline has neither
     *         a rescaler nor an imputer
     * @throws IllegalArgumentException if the pipeline rescales without any imputation stage
     */
    private static List<PMMLDerivedField> numericalFeaturesHandling(PreprocessingPipeline pipe, Map<String, Normalization.Action> actions, boolean castToFloat) {
        List<Rescaler> rescalers = PMMLPreprocessing.getProcessors(pipe, Rescaler.class);
        List<ImputeWithValue> imputers = PMMLPreprocessing.getProcessors(pipe, ImputeWithValue.class);

        if (!rescalers.isEmpty()) {
            if (imputers.isEmpty()) {
                throw new IllegalArgumentException("Found rescaled features without imputation. Not supported for PMML.");
            }
            return PMMLPreprocessing.imputeAndRescale(imputers.get(0), rescalers.get(0), actions, castToFloat);
        }

        // No rescaling: at most plain imputation remains.
        if (imputers.isEmpty()) {
            return Lists.newArrayList();
        }
        return PMMLPreprocessing.allNumericalImputes(imputers.get(0));
    }

    /**
     * Creates one {@link ImputedNumerical} field per column whose imputation value is a
     * {@link Double}; columns with any other kind of value are skipped.
     *
     * @param imputer the imputation stage
     * @return a new list of imputed numerical fields
     */
    private static List<PMMLDerivedField> allNumericalImputes(ImputeWithValue imputer) {
        List<PMMLDerivedField> imputedFields = new ArrayList<>();
        Map<String, ?> columnMapping = imputer.getColumnMapping();
        for (Map.Entry<String, ?> entry : columnMapping.entrySet()) {
            Object replacement = entry.getValue();
            if (replacement instanceof Double) {
                imputedFields.add(new ImputedNumerical(entry.getKey(), (Double) replacement));
            }
        }
        return imputedFields;
    }

    /**
     * Creates the derived fields for a pipeline that both imputes and rescales numerical
     * features.
     *
     * <p>Each rescaled column yields a {@link Rescaled} field built from the rescaler's
     * shift and inverse scale and the column's imputation value. Columns that are imputed
     * with a {@link Double} but not rescaled yield an {@link ImputedNumerical} field.
     *
     * @param impute the imputation stage
     * @param rescaler the rescaling stage
     * @param actions per-column normalization actions; a column mapped to
     *        {@code Normalization.Action.TO_EPOCH} gets its shift adjusted
     * @param castToFloat forwarded to the {@link Rescaled} fields
     * @return a new list with the rescaled fields first, then the imputed-only fields
     * @throws IllegalArgumentException if a rescaled column has no imputation value, or
     *         its imputation value is not a {@link Double}
     */
    private static List<PMMLDerivedField> imputeAndRescale(ImputeWithValue impute, Rescaler rescaler, Map<String, Normalization.Action> actions, boolean castToFloat) {
        List<PMMLDerivedField> fields = new ArrayList<>();
        Set<String> rescaledNames = new HashSet<>();
        Map<String, ?> imputeMapping = impute.getColumnMapping();
        String[] rescalerCols = rescaler.getColumns();
        double[] invscales = rescaler.getInv_scales();
        double[] shifts = rescaler.getShifts();
        for (int i = 0; i < rescalerCols.length; ++i) {
            String column = rescalerCols[i];
            if (!imputeMapping.containsKey(column)) {
                throw new IllegalArgumentException("Found rescaled feature '" + column + "' without imputation. Not supported for PMML.");
            }
            Object imputed = imputeMapping.get(column);
            if (!(imputed instanceof Double)) {
                // Fail with an explicit message instead of a bare NullPointerException or
                // ClassCastException when unboxing the imputation value.
                throw new IllegalArgumentException("Found rescaled feature '" + column + "' with non-numerical imputation value '" + imputed + "'. Not supported for PMML.");
            }
            double shift = shifts[i];
            // Constant-first equals is null-safe when the column has no (or a null) action.
            if (Normalization.Action.TO_EPOCH.equals(actions.get(column))) {
                // 2.2089888E9 == 2208988800, the number of seconds between 1900-01-01 and
                // 1970-01-01. NOTE(review): presumably converts between a 1900-based and a
                // Unix-epoch timestamp -- confirm against the Normalization implementation.
                shift -= 2.2089888E9;
            }
            fields.add(new Rescaled(column, shift, invscales[i], (Double) imputed, castToFloat));
            rescaledNames.add(column);
        }
        for (Map.Entry<String, ?> e : imputeMapping.entrySet()) {
            String col = e.getKey();
            Object val = e.getValue();
            if (!(val instanceof Double) || rescaledNames.contains(col)) {
                continue;
            }
            fields.add(new ImputedNumerical(col, (Double) val));
        }
        return fields;
    }

    /**
     * Produces the derived fields for categorical features.
     *
     * <p>Only the first {@link ImputeWithValue}, {@link Dummifier} and
     * {@link CategoricalEncoder} stage of the pipeline are taken into account.
     *
     * @param pipe the preprocessing pipeline
     * @param castToFloat forwarded to the mapped-value fields
     * @return the categorical derived fields; an empty list when the pipeline has no
     *         imputation and no categorical processing
     * @throws IllegalArgumentException if categorical features are dummified or encoded
     *         without any imputation stage
     */
    private static List<PMMLDerivedField> categoricalFeaturesHandling(PreprocessingPipeline pipe, boolean castToFloat) {
        List<ImputeWithValue> imputers = PMMLPreprocessing.getProcessors(pipe, ImputeWithValue.class);
        List<CategoricalEncoder> encoders = PMMLPreprocessing.getProcessors(pipe, CategoricalEncoder.class);
        List<Dummifier> dummifiers = PMMLPreprocessing.getProcessors(pipe, Dummifier.class);
        boolean hasCategoricalProcessing = !dummifiers.isEmpty() || !encoders.isEmpty();

        if (imputers.isEmpty()) {
            if (hasCategoricalProcessing) {
                throw new IllegalArgumentException("Found preprocessed categorical features without imputation. Not supported for PMML.");
            }
            return Lists.newArrayList();
        }

        ImputeWithValue imputer = imputers.get(0);
        if (!hasCategoricalProcessing) {
            return PMMLPreprocessing.allCategoricalImputes(imputer);
        }

        Dummifier dummifier = null;
        if (!dummifiers.isEmpty()) {
            dummifier = dummifiers.get(0);
        }
        CategoricalEncoder encoder = null;
        if (!encoders.isEmpty()) {
            encoder = encoders.get(0);
        }
        return PMMLPreprocessing.categoricalHelper(imputer, dummifier, encoder, castToFloat);
    }

    /**
     * Creates the derived fields for dummified and/or encoded categorical features.
     *
     * <p>For each dummified column:
     * <ul>
     *   <li>without an imputation value, an {@code "N/A"} {@link Dummy} is emitted that maps
     *       missing input to 1;</li>
     *   <li>with an imputation value, a {@link MappedValues} field for the imputed level
     *       (missing input maps to 1) and one for {@code "N/A"} (always 0) are emitted;</li>
     *   <li>when the column is flagged "with others", an {@code "__Others__"} field is
     *       emitted that is 0 for every known level and 1 otherwise; missing input maps
     *       to 1 only when the imputation value is not a known level;</li>
     *   <li>every known level other than the imputed one gets a {@link Dummy} mapping
     *       missing input to 0.</li>
     * </ul>
     *
     * <p>For each encoded column, one {@link MappedValues} field is emitted per output
     * dimension, using the encoder's defaults for unknown categories and the encoding of
     * the imputation value (or the defaults if it has none) for missing input.
     *
     * @param impute the imputation stage
     * @param dummifier the dummifier stage, or {@code null} if the pipeline has none
     * @param categoricalEncoder the encoder stage, or {@code null} if the pipeline has none
     * @param castToFloat forwarded to the {@link MappedValues} fields
     * @return a new list of derived fields, dummifier fields first
     * @throws IllegalArgumentException if an encoded column has no imputation value, or if
     *         a column imputed with a {@link String} is neither dummified nor encoded
     */
    private static List<PMMLDerivedField> categoricalHelper(ImputeWithValue impute, Dummifier dummifier, CategoricalEncoder categoricalEncoder, boolean castToFloat) {
        Map<String, ?> imputeMapping = impute.getColumnMapping();
        List<PMMLDerivedField> fields = new ArrayList<>();
        Set<String> processed = new HashSet<>();

        if (dummifier != null) {
            String[] cols = dummifier.getColumns();
            List<?> levels = dummifier.getLevels();
            boolean[] withOthers = dummifier.getWithOthers();
            for (int i = 0; i < cols.length; ++i) {
                String col = cols[i];
                Object imputed = imputeMapping.get(col);
                if (imputed == null) {
                    fields.add(new Dummy(col, "N/A", 1.0));
                } else {
                    String imputedLevel = (String) imputed;
                    Map<String, Double> imputedMapping = new HashMap<>();
                    imputedMapping.put(imputedLevel, 1.0);
                    fields.add(new MappedValues(Dummifier.dummifyName(col, imputedLevel), col, imputedMapping, 0.0, 1.0, castToFloat));
                    // Placeholder row: the N/A indicator is 0 for every input once imputation applies.
                    Map<String, Double> mappingNa = new HashMap<>();
                    mappingNa.put("__dummy", 0.0);
                    fields.add(new MappedValues(Dummifier.dummifyName(col, "N/A"), col, mappingNa, 0.0, 0.0, castToFloat));
                }
                boolean hasOthers = withOthers[i];
                @SuppressWarnings("unchecked")
                Set<String> colLevels = (Set<String>) levels.get(i);
                if (hasOthers) {
                    Map<String, Double> othersMapping = new HashMap<>();
                    for (String lev : colLevels) {
                        othersMapping.put(lev, 0.0);
                    }
                    // Missing input counts as "others" only if its imputed level is unknown.
                    double othersMissing = imputed != null && othersMapping.get(imputed) == null ? 1.0 : 0.0;
                    fields.add(new MappedValues(Dummifier.dummifyName(col, "__Others__"), col, othersMapping, 1.0, othersMissing, castToFloat));
                }
                for (String level : colLevels) {
                    // The imputed level already has its MappedValues field above.
                    if (imputed != null && ((String) imputed).equals(level)) {
                        continue;
                    }
                    fields.add(new Dummy(col, level, 0.0));
                }
                processed.add(col);
            }
        }

        if (categoricalEncoder != null) {
            String[] cols = categoricalEncoder.getColumns();
            String[][] output = categoricalEncoder.getOutputNames();
            List<?> mappings = categoricalEncoder.getEncodings();
            double[][] defaults = categoricalEncoder.getDefaults();
            for (int i = 0; i < cols.length; ++i) {
                String feature = cols[i];
                Object imputeValue = imputeMapping.get(feature);
                if (imputeValue == null) {
                    throw new IllegalArgumentException("No imputed value for target-encoded feature (" + feature + "), unsupported for PMML.");
                }
                @SuppressWarnings("unchecked")
                Map<String, double[]> mapping = (Map<String, double[]>) mappings.get(i);
                // Missing input behaves like the imputed category, or like an unknown one.
                double[] missingVals = mapping.getOrDefault(imputeValue, defaults[i]);
                for (int j = 0; j < missingVals.length; ++j) {
                    Map<String, Double> localMapping = new HashMap<>();
                    for (Map.Entry<String, double[]> encoding : mapping.entrySet()) {
                        localMapping.put(encoding.getKey(), encoding.getValue()[j]);
                    }
                    fields.add(new MappedValues(output[i][j], feature, localMapping, defaults[i][j], missingVals[j], castToFloat));
                }
                processed.add(feature);
            }
        }

        // Every String-imputed column must have been dummified or encoded above.
        for (Map.Entry<String, ?> e : imputeMapping.entrySet()) {
            if (e.getValue() instanceof String && !processed.contains(e.getKey())) {
                throw new IllegalArgumentException("Found imputed feature without processing : " + e.getKey());
            }
        }
        return fields;
    }

    /**
     * Creates one {@link ImputedCategorical} field per column whose imputation value is a
     * {@link String}; columns with any other kind of value are skipped.
     *
     * @param imputer the imputation stage
     * @return a new list of imputed categorical fields
     */
    private static List<PMMLDerivedField> allCategoricalImputes(ImputeWithValue imputer) {
        List<PMMLDerivedField> imputedFields = new ArrayList<>();
        Map<String, ?> columnMapping = imputer.getColumnMapping();
        for (Map.Entry<String, ?> entry : columnMapping.entrySet()) {
            Object replacement = entry.getValue();
            if (replacement instanceof String) {
                imputedFields.add(new ImputedCategorical(entry.getKey(), (String) replacement));
            }
        }
        return imputedFields;
    }

    /**
     * Numerical derived field that references a source column and carries the value to
     * substitute when that column is missing.
     *
     * <p>The derived field is named with {@code getTransformedName(name)}.
     */
    public static class ImputedNumerical
    extends PMMLDerivedField.NumericalDerivedField {
        /** Reference to the source column, including its missing-value replacement. */
        @XML.Element
        public PMMLDerivedField.PMMLFieldRefWithMapMissingTo FieldRef = new PMMLDerivedField.PMMLFieldRefWithMapMissingTo();

        /**
         * @param name the source column name
         * @param mapMissingTo the replacement value, stored in its {@code double} string form
         */
        public ImputedNumerical(String name, double mapMissingTo) {
            super(getTransformedName(name));
            FieldRef.field = name;
            FieldRef.mapMissingTo = Double.toString(mapMissingTo);
        }
    }

    /**
     * Numerical derived field combining imputation and rescaling.
     *
     * <p>The field holds a multiplication whose operands are a subtraction built from the
     * source column and {@code offset}, and the constant {@code invScale}. The value used
     * for missing input is the imputation value passed through the same formula,
     * {@code (mapMissingTo - offset) * invScale}.
     */
    public static class Rescaled
    extends PMMLDerivedField.NumericalDerivedField {
        /** Outer multiplication element. */
        @XML.Element
        public PMMLThresholdFunctions.PMMLApplyMultiplication Apply = new PMMLThresholdFunctions.PMMLApplyMultiplication();

        /**
         * @param name the source column name
         * @param offset the shift used to build the subtraction
         * @param invScale the factor used as the multiplication constant
         * @param mapMissingTo the raw (un-rescaled) imputation value
         * @param castToFloat forwarded to the parent derived field
         */
        public Rescaled(String name, double offset, double invScale, double mapMissingTo, boolean castToFloat) {
            super(getTransformedName(name), castToFloat);
            // Rescale the imputation value exactly like a regular input value.
            double rescaledMissing = (mapMissingTo - offset) * invScale;
            Apply.Apply = PMMLThresholdFunctions.PMMLApplySubtraction.from(name, offset);
            Apply.Constant = PMMLConstants.PMMLDoubleConstant.from(invScale);
            Apply.mapMissingTo = rescaledMissing;
        }
    }

    /**
     * Numerical derived field describing one dummy column through a {@code NormDiscrete}
     * element with the {@code "indicator"} method.
     *
     * <p>The field is named with {@code Dummifier.dummifyName(inputCol, level)}.
     */
    public static class Dummy
    extends PMMLDerivedField.NumericalDerivedField {
        /** Indicator description for the (column, level) pair. */
        @XML.Element
        public PMMLNormDiscrete NormDiscrete = new PMMLNormDiscrete();

        /**
         * @param inputCol the source categorical column
         * @param level the category value this dummy stands for
         * @param mapMissingTo the value to use when the source column is missing
         */
        public Dummy(String inputCol, String level, double mapMissingTo) {
            super(Dummifier.dummifyName(inputCol, level));
            PMMLNormDiscrete indicator = this.NormDiscrete;
            indicator.field = inputCol;
            indicator.value = level;
            indicator.mapMissingTo = mapMissingTo;
        }

        /** Attribute holder for the {@code NormDiscrete} element. */
        public static class PMMLNormDiscrete {
            /** Source column name. */
            @XML.Attribute
            public String field;
            /** Category value tested by the indicator. */
            @XML.Attribute
            public String value;
            /** Output used when the source column is missing. */
            @XML.Attribute
            public double mapMissingTo;
            /** Normalization method; always {@code "indicator"} here. */
            @XML.Attribute
            public String method = "indicator";
        }
    }

    /**
     * Numerical derived field that maps each category of a source column to a number
     * through a {@code MapValues} element backed by an inline table.
     */
    static class MappedValues
    extends PMMLDerivedField.NumericalDerivedField {
        /** The category-to-value mapping. */
        @XML.Element
        PMMLMapValues MapValues;

        /**
         * @param outputName the name of the derived field
         * @param col the source column
         * @param mapping the category-to-value table
         * @param defaultValue the value stored as {@code defaultValue}
         * @param mapMissingTo the value stored as {@code mapMissingTo}
         * @param castToFloat when true the data type is {@code "float"} instead of {@code "double"}
         */
        public MappedValues(String outputName, String col, Map<String, Double> mapping, double defaultValue, double mapMissingTo, boolean castToFloat) {
            super(outputName);
            PMMLMapValues mapValues = new PMMLMapValues(col, mapping, defaultValue, mapMissingTo, castToFloat);
            this.MapValues = mapValues;
        }

        /** Holder for the {@code MapValues} element. */
        static class PMMLMapValues {
            /** Name of the inline-table column holding the output; always {@code "value"}. */
            @XML.Attribute
            String outputColumn = "value";
            /** Output data type: {@code "double"}, or {@code "float"} when casting to float. */
            @XML.Attribute
            String dataType = "double";
            /** Output used when the source column is missing. */
            @XML.Attribute
            double mapMissingTo;
            /** Default output value. */
            @XML.Attribute
            double defaultValue;
            /** Binds the source column to the inline-table {@code category} column. */
            @XML.Element
            PMMLFieldColumnPair FieldColumnPair;
            /** The category-to-value rows. */
            @XML.Element
            PMMLInlineTable InlineTable;

            public PMMLMapValues(String col, Map<String, Double> mapping, double defaultValue, double mapMissingTo, boolean castToFloat) {
                if (castToFloat) {
                    dataType = "float";
                }
                this.mapMissingTo = mapMissingTo;
                this.defaultValue = defaultValue;

                PMMLFieldColumnPair pair = new PMMLFieldColumnPair();
                pair.field = col;
                FieldColumnPair = pair;

                InlineTable = new PMMLInlineTable(mapping);
            }

            /** Holder for the {@code FieldColumnPair} element. */
            static class PMMLFieldColumnPair {
                /** Source column name. */
                @XML.Attribute
                String field;
                /** Inline-table column matched against the source; always {@code "category"}. */
                @XML.Attribute
                String column = "category";

                PMMLFieldColumnPair() {
                }
            }

            /** Inline table with one row per mapping entry. */
            static class PMMLInlineTable {
                /** Rows in the iteration order of the mapping passed to the constructor. */
                @XML.Element
                List<Row> row = new ArrayList<Row>();

                public PMMLInlineTable(Map<String, Double> mapping) {
                    mapping.forEach((category, value) -> {
                        Row entry = new Row();
                        entry.category = category;
                        entry.value = value;
                        row.add(entry);
                    });
                }

                /** One (category, value) row of the inline table. */
                static class Row {
                    @XML.Element
                    String category;
                    @XML.Element
                    Double value;

                    Row() {
                    }
                }
            }
        }
    }

    /**
     * Categorical derived field that references a source column and carries the category
     * to substitute when that column is missing.
     *
     * <p>Unlike {@code ImputedNumerical}, the derived field keeps the source column name.
     */
    public static class ImputedCategorical
    extends PMMLDerivedField.CategoricalDerivedField {
        /** Reference to the source column, including its missing-value replacement. */
        @XML.Element
        public PMMLDerivedField.PMMLFieldRefWithMapMissingTo FieldRef = new PMMLDerivedField.PMMLFieldRefWithMapMissingTo();

        /**
         * @param name the source column name, also used as the derived field name
         * @param mapMissingTo the replacement category
         */
        public ImputedCategorical(String name, String mapMissingTo) {
            super(name);
            PMMLDerivedField.PMMLFieldRefWithMapMissingTo ref = this.FieldRef;
            ref.field = name;
            ref.mapMissingTo = mapMissingTo;
        }
    }
}

