/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction;

import com.dataiku.dip.analysis.model.core.CustomMetricResult;
import com.dataiku.dip.analysis.model.core.CustomMetricSuccess;
import com.dataiku.dip.analysis.model.core.MultiCutCustomMetricSuccess;
import com.dataiku.dip.analysis.model.prediction.PredictionModelPerf;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;

/**
 * Performance report for a binary classification model.
 *
 * <p>Holds metrics evaluated at each candidate decision threshold ("cut", see {@link CutData}),
 * threshold-independent metrics ({@link ThresholdIndependentMetrics}), and the data sets stored in
 * the visualization, calibration and density fields below.
 */
public class BinaryClassificationModelPerf
extends PredictionModelPerf {
    /** Per-threshold metrics; may be null, in which case no per-cut metrics are reported. */
    public CutData perCutData;
    public ProbaDistribData probaDistribData;
    public LiftVizData liftVizData;
    public List<List<RocVizBin>> rocVizData;
    public List<PrVizData> prVizData;
    /** Density data keyed by a string per class (presumably the class label — TODO confirm). */
    public Map<String, PerClassPredictedProbabilityDensityData> densityData = new HashMap<>();
    public List<CalibrationBin> calibrationData;
    /** Threshold-independent metrics; may be null, in which case they are not reported. */
    public ThresholdIndependentMetrics tiMetrics = new ThresholdIndependentMetrics();
    /** Threshold at which {@link #getMetricMap(boolean)} reads the per-cut metrics. */
    public double optimalThreshold;
    /** Threshold used by {@link #thresholdIndex(Double)} when it is given {@code null}. */
    public double usedThreshold;
    public double[] probaPercentiles;

    /**
     * Maps a threshold to an index into the parallel arrays of {@link #perCutData}.
     *
     * <p>If {@code threshold} is at or above the last cut value, the last index is returned.
     * Otherwise the method returns the first index {@code i}, scanning from 0, with
     * {@code threshold <= cut[i]}. Such an index always exists because the threshold is below the
     * last cut. For a cut array sorted in ascending order (presumably the usual case — not enforced
     * here), both rules amount to "first cut not below the threshold, capped at the last index".
     *
     * @param threshold the threshold to look up, or {@code null} to use {@link #usedThreshold}
     * @return an index in {@code [0, cut.length - 1]}
     * @throws NullPointerException if {@link #perCutData} or its {@code cut} array is null
     * @throws ArrayIndexOutOfBoundsException if the {@code cut} array is empty
     */
    public int thresholdIndex(@Nullable Double threshold) {
        double target = threshold != null ? threshold : this.usedThreshold;
        double[] cuts = this.perCutData.cut;
        int lastIdx = cuts.length - 1;
        if (target >= cuts[lastIdx]) {
            return lastIdx;
        }
        int idx = 0;
        while (idx < lastIdx && target > cuts[idx]) {
            idx++;
        }
        return idx;
    }

    /**
     * Collects this model's scalar metrics into a map keyed by metric name.
     *
     * <p>The per-cut metrics are precision, recall, accuracy, f1, mcc, hammingLoss and every
     * successful multi-cut custom metric. They are read at the index that
     * {@link #thresholdIndex(Double)} returns for {@link #optimalThreshold}. They are skipped when
     * {@link #perCutData} is null or has no {@code cut} values.
     *
     * <p>Threshold-independent metrics and their successful custom metrics are added whenever
     * {@link #tiMetrics} is non-null. When {@code withStdMetrics} is true, each metric's
     * {@code "...std"} companion value (presumably a standard deviation) is also added, under the
     * metric name suffixed with {@code "std"}.
     *
     * @param withStdMetrics whether to also include the {@code "std"}-suffixed entries
     * @return a new mutable map from metric name to value
     */
    @Override
    public Map<String, Double> getMetricMap(boolean withStdMetrics) {
        Map<String, Double> metricMap = new HashMap<>();
        // thresholdIndex dereferences perCutData.cut, so only compute the index once we know a
        // non-empty cut array exists; previously a null perCutData caused an NPE here before the
        // null check below it could run, and tiMetrics were never reported.
        if (this.hasUsableCutData()) {
            int cutIdx = this.thresholdIndex(this.optimalThreshold);
            this.addPerCutMetrics(metricMap, cutIdx, withStdMetrics);
        }
        if (this.tiMetrics != null) {
            this.addThresholdIndependentMetrics(metricMap, withStdMetrics);
        }
        return metricMap;
    }

    /** Returns true when {@link #perCutData} exists and has at least one cut value to index into. */
    private boolean hasUsableCutData() {
        return this.perCutData != null && this.perCutData.cut != null && this.perCutData.cut.length > 0;
    }

    /** Adds the per-cut metrics of {@link #perCutData}, read at {@code cutIdx}, to {@code metricMap}. */
    private void addPerCutMetrics(Map<String, Double> metricMap, int cutIdx, boolean withStdMetrics) {
        CutData cutData = this.perCutData;
        this.addMetricToMapIfPossible(metricMap, "precision", cutData.precision, cutIdx);
        this.addMetricToMapIfPossible(metricMap, "recall", cutData.recall, cutIdx);
        this.addMetricToMapIfPossible(metricMap, "accuracy", cutData.accuracy, cutIdx);
        this.addMetricToMapIfPossible(metricMap, "f1", cutData.f1, cutIdx);
        this.addMetricToMapIfPossible(metricMap, "mcc", cutData.mcc, cutIdx);
        this.addMetricToMapIfPossible(metricMap, "hammingLoss", cutData.hammingLoss, cutIdx);
        if (withStdMetrics) {
            this.addMetricToMapIfPossible(metricMap, "precisionstd", cutData.precisionstd, cutIdx);
            this.addMetricToMapIfPossible(metricMap, "recallstd", cutData.recallstd, cutIdx);
            this.addMetricToMapIfPossible(metricMap, "accuracystd", cutData.accuracystd, cutIdx);
            this.addMetricToMapIfPossible(metricMap, "f1std", cutData.f1std, cutIdx);
            this.addMetricToMapIfPossible(metricMap, "mccstd", cutData.mccstd, cutIdx);
            this.addMetricToMapIfPossible(metricMap, "hammingLossstd", cutData.hammingLossstd, cutIdx);
        }
        if (cutData.customMetricsResults != null) {
            for (CustomMetricResult result : cutData.customMetricsResults) {
                if (!result.didSucceed) {
                    continue;
                }
                // Successful per-cut custom results are expected to be multi-cut successes.
                MultiCutCustomMetricSuccess success = (MultiCutCustomMetricSuccess) result;
                this.addMetricToMapIfPossible(metricMap, success.metric.name, success.values, cutIdx);
                if (withStdMetrics) {
                    this.addMetricToMapIfPossible(metricMap, success.metric.name + "std", success.valuesstd, cutIdx);
                }
            }
        }
    }

    /** Adds the metrics of {@link #tiMetrics} (which must be non-null) to {@code metricMap}. */
    private void addThresholdIndependentMetrics(Map<String, Double> metricMap, boolean withStdMetrics) {
        ThresholdIndependentMetrics ti = this.tiMetrics;
        metricMap.put("auc", ti.auc);
        metricMap.put("logLoss", ti.logLoss);
        metricMap.put("lift", ti.lift);
        metricMap.put("calibrationLoss", ti.calibrationLoss);
        metricMap.put("averagePrecision", ti.averagePrecision);
        if (withStdMetrics) {
            metricMap.put("aucstd", ti.aucstd);
            metricMap.put("logLossstd", ti.logLossstd);
            metricMap.put("liftstd", ti.liftstd);
            metricMap.put("calibrationLossstd", ti.calibrationLossstd);
            metricMap.put("averagePrecisionstd", ti.averagePrecisionstd);
        }
        if (ti.customMetricsResults != null) {
            for (CustomMetricResult result : ti.customMetricsResults) {
                if (!result.didSucceed) {
                    continue;
                }
                // Successful threshold-independent custom results are expected to be single-value successes.
                CustomMetricSuccess success = (CustomMetricSuccess) result;
                metricMap.put(success.metric.name, success.value);
                if (withStdMetrics) {
                    metricMap.put(success.metric.name + "std", success.valuestd);
                }
            }
        }
    }

    /**
     * Puts {@code perCutMetric[cutIdx]} into {@code metricMap} under {@code metricName}.
     * Does nothing if the array is null or {@code cutIdx} is outside its bounds.
     */
    private void addMetricToMapIfPossible(Map<String, Double> metricMap, String metricName, double[] perCutMetric, int cutIdx) {
        if (perCutMetric != null && cutIdx >= 0 && cutIdx < perCutMetric.length) {
            metricMap.put(metricName, perCutMetric[cutIdx]);
        }
    }

    /** Metrics that do not depend on the decision threshold, with their "std" companions. */
    public static class ThresholdIndependentMetrics {
        public double auc;
        public double aucstd;
        public double logLoss;
        public double logLossstd;
        public double lift;
        public double liftstd;
        public double calibrationLoss;
        public double calibrationLossstd;
        public double averagePrecision;
        public double averagePrecisionstd;
        public double customScore;
        public double customScorestd;
        public CustomMetricResult[] customMetricsResults;
    }

    /**
     * Metrics evaluated at each candidate threshold in {@link #cut}. The per-cut arrays are indexed
     * with the same index derived from {@code cut} (see {@code getMetricMap}).
     */
    public static class CutData {
        public double[] cut;
        public float[] tp;
        public float[] tn;
        public float[] fp;
        public float[] fn;
        public double[] precision;
        public double[] recall;
        public double[] accuracy;
        public double[] f1;
        public double[] mcc;
        public double[] hammingLoss;
        public double[] customScore;
        public double[] customScorestd;
        public CustomMetricResult[] customMetricsResults;
        public PredictionModelPerf.AssertionsMetrics[] assertionsMetrics;
        public PredictionModelPerf.OverridesMetrics[] overridesMetrics;
        public double[] precisionstd;
        public double[] recallstd;
        public double[] accuracystd;
        public double[] f1std;
        public double[] mccstd;
        public double[] hammingLossstd;
    }

    /** One calibration bin with fields {@code n}, {@code x} and {@code y}. */
    public static class CalibrationBin {
        public double n;
        public double x;
        public double y;

        public CalibrationBin() {
        }

        public CalibrationBin(double n, double x, double y) {
            this.n = n;
            this.x = x;
            this.y = y;
        }
    }

    /** Predicted-probability density data for one class, split by whether the actual class matches. */
    public static class PerClassPredictedProbabilityDensityData {
        public double[] actualIsThisClass;
        public double[] actualIsNotThisClass;
        public double actualIsThisClassMedian;
        public double actualIsNotThisClassMedian;
    }

    /** A list of precision-recall points together with a positive rate. */
    public static class PrVizData {
        public List<PrVizBin> bins;
        public double positiveRate;

        public PrVizData() {
        }

        public PrVizData(List<PrVizBin> bins, double positiveRate) {
            this.bins = bins;
            this.positiveRate = positiveRate;
        }
    }

    /** A precision-recall point: {@code x} is recall, {@code y} is precision, {@code p} is the threshold. */
    public static class PrVizBin {
        public double x;
        public double y;
        public double p;

        public PrVizBin() {
        }

        public PrVizBin(double threshold, double recall, double precision) {
            this.x = recall;
            this.y = precision;
            this.p = threshold;
        }
    }

    /** A ROC point: {@code x} holds the {@code fp} value, {@code y} the {@code tp} value, {@code p} the probability. */
    public static class RocVizBin {
        public double x;
        public double y;
        public double p;

        public RocVizBin() {
        }

        public RocVizBin(double proba, double fp, double tp) {
            this.x = fp;
            this.y = tp;
            this.p = proba;
        }
    }

    /** Lift chart data: wizard summary, per-bin values and cumulative lift per fold. */
    public static class LiftVizData {
        public LiftWizard wizard;
        public List<LiftBin> bins;
        public List<List<CumLift>> folds;
    }

    public static class LiftWizard {
        public double positives;
        public double total;
    }

    public static class CumLift {
        public double cum_size;
        public double cum_lift;
    }

    public static class LiftBin {
        public double percentile_idx;
        public double bin_min;
        public double bin_max;
        public double cum_size;
        public double bin_lift;
        public double cum_lift;
        public double bin_pos_prop;
    }

    /** Probability distribution data: bin edges/positions and one distribution array per row. */
    public static class ProbaDistribData {
        public double[] bins;
        public double[][] probaDistribs;
    }
}

