/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.docgen.extractor;

import com.dataiku.dip.analysis.docgen.extractor.ResultsMLOverridesBaseTableExtractor;
import com.dataiku.dip.analysis.model.prediction.BinaryClassificationModelPerf;
import com.dataiku.dip.analysis.model.prediction.ClassicalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.MulticlassModelPerf;
import com.dataiku.dip.analysis.model.prediction.PredictionModelPerf;
import com.dataiku.dip.analysis.model.prediction.RegressionModelPerf;
import com.dataiku.dip.dataflow.exec.filter.FilterDesc;
import com.dataiku.scoring.models.overrides.MLOverridesParamsBase;
import java.util.ArrayList;
import java.util.List;

/**
 * Document-generation extractor that renders a table of ML override metrics for a classical
 * prediction model: a header row followed by one row per configured override.
 */
public class ResultsMLOverridesMetricsTableExtractor
extends ResultsMLOverridesBaseTableExtractor {
    /** Column labels of the overrides metrics table, in display order. */
    private static final String[] HEADER_LABELS = {
        "Name", "Enforce", "Already matched rows", "Matching rows", "Changed rows"
    };

    /**
     * Builds the overrides metrics table.
     *
     * <p>The name suggests the base class calls this only when at least one override exists
     * (NOTE(review): confirm against the base class). Metrics are looked up by position, so this
     * assumes {@code perOverride} is index-aligned with {@code overridesParams.overrides}
     * (TODO confirm).
     *
     * @param modelDetails details of the model whose overrides are reported
     * @return the table rows, with the header row first
     */
    @Override
    protected List<List<String>> buildTableForNonEmptyOverrides(ClassicalPredictionModelDetails modelDetails) {
        PredictionModelPerf.OverridesMetrics metrics = this.getOverridesMetrics(modelDetails);
        List<List<String>> rows = new ArrayList<>();
        rows.add(this.buildHeaders());
        // Running sum of nbMatchingRows over the overrides already emitted, reported on each row.
        int matchedByPrevious = 0;
        int index = 0;
        while (index < modelDetails.overridesParams.overrides.size()) {
            MLOverridesParamsBase.MLOverride rawOverride = (MLOverridesParamsBase.MLOverride)modelDetails.overridesParams.overrides.get(index);
            PredictionModelPerf.OverrideMetrics current = metrics.perOverride.get(index);
            rows.add(this.buildRow(current, (MLOverridesParamsBase.MLOverride<FilterDesc>)rawOverride, matchedByPrevious));
            matchedByPrevious += current.nbMatchingRows;
            index++;
        }
        return rows;
    }

    /**
     * Selects the overrides metrics that match the model's prediction type.
     *
     * <p>For binary classification, the metrics come from the per-cut data. The cut index is the
     * one the model perf maps the user's active classifier threshold to.
     *
     * @param modelDetails details of the model whose metrics are read
     * @return the overrides metrics for the model
     * @throws IllegalArgumentException if the prediction type is not binary classification,
     *     multiclass or regression
     */
    private PredictionModelPerf.OverridesMetrics getOverridesMetrics(ClassicalPredictionModelDetails modelDetails) {
        switch (modelDetails.coreParams.prediction_type) {
            case REGRESSION:
                return ((RegressionModelPerf)modelDetails.perf).metrics.overridesMetrics;
            case MULTICLASS:
                return ((MulticlassModelPerf)modelDetails.perf).metrics.overridesMetrics;
            case BINARY_CLASSIFICATION: {
                // The threshold is read before the perf cast, as in the original ordering.
                double threshold = modelDetails.userMeta.activeClassifierThreshold;
                BinaryClassificationModelPerf binaryPerf = (BinaryClassificationModelPerf)modelDetails.perf;
                return binaryPerf.perCutData.overridesMetrics[binaryPerf.thresholdIndex(threshold)];
            }
            default:
                throw new IllegalArgumentException(modelDetails.coreParams.prediction_type.name() + " not handled for document generation");
        }
    }

    /**
     * Builds the table row for a single override.
     *
     * @param overrideMetrics metrics computed for this override
     * @param override the override definition; its name and outcome fill the first two cells
     * @param nbAlreadyMatchedRows sum of matching rows over the overrides preceding this one
     * @return the row cells, in the same order as {@link #HEADER_LABELS}
     */
    private List<String> buildRow(PredictionModelPerf.OverrideMetrics overrideMetrics, MLOverridesParamsBase.MLOverride<FilterDesc> override, int nbAlreadyMatchedRows) {
        List<String> cells = new ArrayList<>(HEADER_LABELS.length);
        cells.add(override.name);
        cells.add(override.outcome.toString());
        cells.add(String.valueOf(nbAlreadyMatchedRows));
        cells.add(String.valueOf(overrideMetrics.nbMatchingRows));
        cells.add(String.valueOf(overrideMetrics.nbChangedRows));
        return cells;
    }

    /**
     * Builds a fresh, mutable header row for the overrides metrics table.
     *
     * @return a new list containing the column labels
     */
    private List<String> buildHeaders() {
        List<String> headers = new ArrayList<>(HEADER_LABELS.length);
        for (String label : HEADER_LABELS) {
            headers.add(label);
        }
        return headers;
    }
}

