/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.guess;

import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemRow;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;

public abstract class OutcomeStats {
    /** Row counts across all outcome bins, split by control / treated / individual treatment value. */
    StatByTreatedState<Integer> totalRowCounts = new StatByTreatedState<>(0, 0);
    /** One row-count holder per outcome bin; the position in the list is the bin index. */
    List<StatByTreatedState<Integer>> rowCountsByOutcomeBins = new ArrayList<>();

    /**
     * Creates the holder with one zero-initialised row counter per outcome bin.
     *
     * @param nbOutcomeBins number of per-bin counters to allocate; a non-positive value allocates none
     */
    public OutcomeStats(int nbOutcomeBins) {
        for (int remaining = nbOutcomeBins; remaining > 0; --remaining) {
            this.rowCountsByOutcomeBins.add(new StatByTreatedState<>(0, 0));
        }
    }

    /**
     * Records one row in both the overall counters and the counters of the given outcome bin.
     *
     * <p>The per-treatment-value counter is always incremented; in addition either the
     * {@code control} or the {@code treated} counter is incremented depending on {@code isControl}.
     *
     * @param treatmentVal treatment value of the row, used as key of the per-treatment counters
     * @param isControl whether the row counts as control ({@code true}) or treated ({@code false})
     * @param outcomeBinIdx index of the outcome bin the row falls into
     * @throws IndexOutOfBoundsException if {@code outcomeBinIdx} is not a valid bin index; in that
     *         case no counter has been modified
     */
    protected void incrementCounts(String treatmentVal, boolean isControl, int outcomeBinIdx) {
        // Resolve the bin first so that an invalid index fails before any counter is touched;
        // otherwise the totals would be incremented without the matching per-bin update.
        StatByTreatedState<Integer> binCounts = this.rowCountsByOutcomeBins.get(outcomeBinIdx);
        StatByTreatedState<Integer> totalCounts = this.totalRowCounts;

        totalCounts.perTreatment.merge(treatmentVal, 1, Integer::sum);
        binCounts.perTreatment.merge(treatmentVal, 1, Integer::sum);

        if (isControl) {
            totalCounts.control = totalCounts.control + 1;
            binCounts.control = binCounts.control + 1;
        } else {
            totalCounts.treated = totalCounts.treated + 1;
            binCounts.treated = binCounts.treated + 1;
        }
    }

    /**
     * Mutable holder for one statistic broken down three ways: its value for control rows,
     * its value for treated rows, and its value for each individual treatment value.
     *
     * @param <T> type of the statistic
     */
    public static class StatByTreatedState<T> {
        /** Value of the statistic for control rows. */
        public T control;
        /** Value of the statistic for non-control (treated) rows. */
        public T treated;
        /** Value of the statistic keyed by treatment value; starts out empty. */
        public Map<String, T> perTreatment = new HashMap<>();

        /**
         * @param control initial value for the control group
         * @param treated initial value for the treated group
         */
        public StatByTreatedState(T control, T treated) {
            this.treated = treated;
            this.control = control;
        }
    }

    /**
     * Outcome statistics for a numeric outcome: rows with a finite outcome are sorted by outcome
     * value and split into four consecutive bins of (nearly) equal size, and row counts and
     * outcome sums are accumulated per control / treated / treatment value.
     */
    public static class Regression
    extends OutcomeStats {
        /** Number of consecutive, equally populated outcome bins. */
        private static final int NB_INTERQUARTILE_INTERVALS = 4;

        /**
         * Outcome values delimiting the bins: index 0 is the smallest valid outcome and index
         * {@code i + 1} is the largest outcome of bin {@code i}. All entries stay at 0.0 when
         * there is no row with a finite outcome.
         */
        double[] outcomeIntervalBounds = new double[NB_INTERQUARTILE_INTERVALS + 1];
        /** Sum of outcome values per control / treated / treatment value. */
        StatByTreatedState<Double> outcomeSums = new StatByTreatedState<Double>(0.0, 0.0);

        public Regression() {
            super(NB_INTERQUARTILE_INTERVALS);
        }

        /** Adds one row's outcome to the per-treatment sum and to the control or treated sum. */
        private void updateOutcomeSum(String treatmentVal, boolean isControl, double rowOutcome) {
            StatByTreatedState<Double> sums = this.outcomeSums;
            sums.perTreatment.put(treatmentVal, sums.perTreatment.getOrDefault(treatmentVal, 0.0) + rowOutcome);
            if (isControl) {
                sums.control = sums.control + rowOutcome;
            } else {
                sums.treated = sums.treated + rowOutcome;
            }
        }

        /**
         * Computes row counts and outcome sums over four equally populated outcome bins.
         *
         * <p>Rows whose outcome is NaN or infinite are ignored. Blank treatment values are
         * normalised to the empty string.
         *
         * @param table in-memory table holding the rows to analyse
         * @param task causal prediction task providing the outcome ({@code targetVariable}) and
         *        treatment ({@code treatmentVariable}) column names
         * @param controlValue treatment value identifying control rows; must not be {@code null}
         *        because it is compared with {@code equals}
         * @param dropMissingTreatmentValues whether rows with a blank treatment value are skipped,
         *        unless the blank value is itself the control value
         * @return the accumulated statistics; counters are all zero and the bounds are left at 0.0
         *         when no row has a finite outcome
         */
        public static Regression getOutcomeStatsByInterquartileInterval(MemTable table, PredictionMLTask.CausalPredictionMLTask task, String controlValue, boolean dropMissingTreatmentValues) {
            MemColumn outcomeCol = table.column(task.targetVariable);
            MemColumn treatmentCol = table.column(task.treatmentVariable);
            List<MemRow> rowsSortedByOutcome = table.rows.stream()
                    .filter(row -> Double.isFinite(row.getAsDoubleOrNaN(outcomeCol)))
                    .sorted(Comparator.comparingDouble(row -> row.getAsDoubleOrNaN(outcomeCol)))
                    .collect(Collectors.toList());
            double[] validSortedOutcomeValues = rowsSortedByOutcome.stream().mapToDouble(row -> row.getAsDoubleOrNaN(outcomeCol)).toArray();
            Regression outcomeStats = new Regression();
            if (validSortedOutcomeValues.length == 0) {
                // Nothing to bin: reading the first/last value of a bin would be out of bounds.
                return outcomeStats;
            }

            double nbRowsPerInterquartileRange = (double)validSortedOutcomeValues.length / NB_INTERQUARTILE_INTERVALS;
            outcomeStats.outcomeIntervalBounds[0] = validSortedOutcomeValues[0];
            for (int intervalIdx = 0; intervalIdx < NB_INTERQUARTILE_INTERVALS; ++intervalIdx) {
                int lowerRowIdx = (int)((double)intervalIdx * nbRowsPerInterquartileRange);
                int upperRowIdx = (int)((double)(intervalIdx + 1) * nbRowsPerInterquartileRange);
                // With fewer rows than bins the leading bins are empty (upperRowIdx == 0);
                // clamp so such a bin reuses the smallest value as its upper bound.
                outcomeStats.outcomeIntervalBounds[intervalIdx + 1] = validSortedOutcomeValues[Math.max(upperRowIdx - 1, 0)];
                for (int rowIdx = lowerRowIdx; rowIdx < upperRowIdx; ++rowIdx) {
                    MemRow row = rowsSortedByOutcome.get(rowIdx);
                    String treatmentVal = StringUtils.defaultIfBlank(row.get(treatmentCol), "");
                    boolean isControl = controlValue.equals(treatmentVal);
                    if ("".equals(treatmentVal) && !isControl && dropMissingTreatmentValues) {
                        continue;
                    }
                    outcomeStats.incrementCounts(treatmentVal, isControl, intervalIdx);
                    // The array was extracted from the same sorted list, so indices line up.
                    outcomeStats.updateOutcomeSum(treatmentVal, isControl, validSortedOutcomeValues[rowIdx]);
                }
            }
            return outcomeStats;
        }
    }

    /**
     * Outcome statistics for a two-class outcome: row counts are accumulated per outcome class
     * and per control / treated / treatment value.
     */
    public static class BinaryClassification
    extends OutcomeStats {
        /** Outcome class labels; the position of a label in this list is its outcome bin index. */
        List<String> outcomeClasses = new ArrayList<String>();

        public BinaryClassification() {
            super(2);
        }

        /**
         * Counts the rows of {@code table} per outcome class.
         *
         * <p>The outcome classes are the {@code sourceValue}s of the task's target remapping, in
         * remapping order. Blank treatment values are normalised to the empty string. A row is
         * skipped when its outcome is blank, or when its outcome value has no corresponding
         * outcome bin (not present in the remapping, or beyond the two allocated bins).
         *
         * @param table in-memory table holding the rows to analyse
         * @param task causal prediction task providing the outcome and treatment column names and
         *        the target remapping
         * @param controlValue treatment value identifying control rows; must not be {@code null}
         *        because it is compared with {@code equals}
         * @param dropMissingTreatmentValues whether rows with a blank treatment value are skipped,
         *        unless the blank value is itself the control value
         * @return the accumulated row counts
         */
        public static BinaryClassification getOutcomeStatsPerClass(MemTable table, PredictionMLTask.CausalPredictionMLTask task, String controlValue, boolean dropMissingTreatmentValues) {
            MemColumn outcomeCol = table.column(task.targetVariable);
            MemColumn treatmentCol = table.column(task.treatmentVariable);
            BinaryClassification outcomeStats = new BinaryClassification();
            outcomeStats.outcomeClasses = task.preprocessing.target_remapping.stream().map(mapping -> mapping.sourceValue).collect(Collectors.toList());
            int nbOutcomeBins = outcomeStats.rowCountsByOutcomeBins.size();
            for (int i = 0; i < table.nrows(); ++i) {
                MemRow row = table.rows.get(i);
                String treatmentVal = StringUtils.defaultIfBlank(row.get(treatmentCol), "");
                boolean isControl = controlValue.equals(treatmentVal);
                if ("".equals(treatmentVal) && !isControl && dropMissingTreatmentValues) {
                    continue;
                }
                String outcomeVal = row.get(outcomeCol);
                if (StringUtils.isBlank(outcomeVal)) {
                    continue;
                }
                int outcomeBinIdx = outcomeStats.outcomeClasses.indexOf(outcomeVal);
                if (outcomeBinIdx < 0 || outcomeBinIdx >= nbOutcomeBins) {
                    // No bin exists for this outcome value; counting it would throw mid-scan.
                    continue;
                }
                outcomeStats.incrementCounts(treatmentVal, isControl, outcomeBinIdx);
            }
            return outcomeStats;
        }
    }
}

