/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.docgen.extractor;

import com.dataiku.dip.analysis.docgen.extractor.ModelExtractor;
import com.dataiku.dip.analysis.model.ModelDetailsBase;
import com.dataiku.dip.analysis.model.prediction.BinaryClassificationModelPerf;
import com.dataiku.dip.analysis.model.prediction.ClassicalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.MulticlassModelPerf;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelPerf;
import com.dataiku.dip.analysis.model.prediction.RegressionModelPerf;
import com.dataiku.dip.analysis.model.prediction.assertions.MLAssertionsParams;
import com.dataiku.dip.dataflow.exec.filter.FilterDescUtils;
import com.jayway.jsonpath.DocumentContext;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Builds the ML assertions results table for document generation.
 *
 * <p>The result is a list of rows, each row being a list of cells. The first row holds the
 * column headers (see {@link #buildHeaders}). Each following row describes one assertion from
 * the model's assertion settings (name, criteria, expected class or range, expected valid ratio)
 * followed by the metrics computed for it, when metrics are available.
 */
public class ResultsMLAssertionsTableExtractor implements ModelExtractor<List<List<String>>> {
    /** Cell text shown when the model has no assertions configured. */
    public static final String NO_ASSERTIONS_DEFINED = "No assertions defined in the settings";

    /**
     * Extracts the assertions table for {@code model}.
     *
     * @param documentContext the document context; not used by this extractor
     * @param model the model whose assertions are rendered; must be a
     *     {@link ClassicalPredictionModelDetails}
     * @return the header row, followed either by one row per configured assertion or by a
     *     single row containing {@link #NO_ASSERTIONS_DEFINED} when none are configured
     * @throws IOException if {@code model} is not a {@link ClassicalPredictionModelDetails}, or if
     *     assertion metrics exist but contain no entry for one of the configured assertions
     */
    @Override
    public List<List<String>> extract(DocumentContext documentContext, ModelDetailsBase model) throws IOException {
        if (!(model instanceof ClassicalPredictionModelDetails)) {
            throw new IOException("Ml assertions are only defined for prediction tasks");
        }
        ClassicalPredictionModelDetails modelDetails = (ClassicalPredictionModelDetails) model;
        List<List<String>> assertionsWithMetricsTable = new ArrayList<>();
        assertionsWithMetricsTable.add(this.buildHeaders(modelDetails));
        if (!hasAssertions(modelDetails)) {
            // First cell left blank so the message lands under the "Criteria" column.
            assertionsWithMetricsTable.add(new ArrayList<String>(Arrays.asList("", NO_ASSERTIONS_DEFINED)));
            return assertionsWithMetricsTable;
        }
        PredictionModelPerf.AssertionsMetrics assertionsMetrics = this.getAssertionsMetrics(modelDetails);
        for (MLAssertionsParams.MLAssertion assertion : modelDetails.assertionsParams.assertions) {
            List<String> line = new ArrayList<>();
            this.addAssertionParams(assertion, line);
            this.addAssertionMetrics(assertionsMetrics, line, assertion.name);
            assertionsWithMetricsTable.add(line);
        }
        return assertionsWithMetricsTable;
    }

    /**
     * Returns whether the model has at least one assertion configured. A missing params object
     * and a missing or empty assertion list all count as "no assertions".
     */
    private static boolean hasAssertions(ClassicalPredictionModelDetails modelDetails) {
        return modelDetails.assertionsParams != null
                && modelDetails.assertionsParams.assertions != null
                && !modelDetails.assertionsParams.assertions.isEmpty();
    }

    /**
     * Looks up the assertion metrics computed for the model.
     *
     * <p>For binary classification, the metrics are taken from the per-cut data at the index of
     * the user's active classifier threshold. For multiclass and regression they come from the
     * model's overall performance metrics.
     *
     * @param modelDetails the model to read metrics from
     * @return the assertion metrics, or {@code null} when none are available
     * @throws IllegalArgumentException if the prediction type is not handled
     */
    private PredictionModelPerf.AssertionsMetrics getAssertionsMetrics(ClassicalPredictionModelDetails modelDetails) {
        switch (modelDetails.coreParams.prediction_type) {
            case BINARY_CLASSIFICATION: {
                double activeClassifierThreshold = modelDetails.userMeta.activeClassifierThreshold;
                BinaryClassificationModelPerf modelPerf = (BinaryClassificationModelPerf) modelDetails.perf;
                BinaryClassificationModelPerf.CutData perCutData = modelPerf.perCutData;
                if (perCutData == null || perCutData.assertionsMetrics == null) {
                    return null;
                }
                int cutIndex = modelPerf.thresholdIndex(activeClassifierThreshold);
                // NOTE(review): thresholdIndex's contract is not visible here, so both ends of
                // the array are guarded rather than risking an out-of-bounds access.
                if (cutIndex < 0 || cutIndex >= perCutData.assertionsMetrics.length) {
                    return null;
                }
                return perCutData.assertionsMetrics[cutIndex];
            }
            case MULTICLASS: {
                return ((MulticlassModelPerf) modelDetails.perf).metrics.assertionsMetrics;
            }
            case REGRESSION: {
                return ((RegressionModelPerf) modelDetails.perf).metrics.assertionsMetrics;
            }
        }
        throw new IllegalArgumentException(modelDetails.coreParams.prediction_type.name() + " not handled for document generation");
    }

    /**
     * Appends the metric cells for one assertion to {@code line}.
     *
     * <p>When metrics are available, four cells are appended: matching row count, dropped row
     * count, valid ratio, and result. The valid ratio is multiplied by 100, rounded to two
     * decimals and suffixed with {@code %}, or shown as {@code -} when absent. The result is
     * {@code Pass}, {@code Fail}, or {@code -} when absent. When no metrics are available at all,
     * a single explanatory cell is appended instead.
     *
     * @throws IOException if metrics exist but have no entry named {@code assertionName}
     */
    private void addAssertionMetrics(PredictionModelPerf.AssertionsMetrics assertionsMetrics, List<String> line, String assertionName) throws IOException {
        if (assertionsMetrics == null) {
            line.add("No assertions metrics available");
            return;
        }
        PredictionModelPerf.AssertionMetrics assertionMetrics = assertionsMetrics.getAssertionMetrics(assertionName);
        if (assertionMetrics == null) {
            throw new IOException("Assertion: " + assertionName + " was not found in assertionsMetrics");
        }
        line.add(String.valueOf(assertionMetrics.nbMatchingRows));
        line.add(String.valueOf(assertionMetrics.nbDroppedRows));
        if (assertionMetrics.validRatio != null) {
            line.add((double) Math.round(10000.0 * assertionMetrics.validRatio) / 100.0 + "%");
        } else {
            line.add("-");
        }
        if (assertionMetrics.result == null) {
            line.add("-");
        } else if (assertionMetrics.result) {
            line.add("Pass");
        } else {
            line.add("Fail");
        }
    }

    /**
     * Appends the settings cells for one assertion to {@code line}. These cells are the name,
     * the filter representation, the expected class or {@code min-max} range, and the expected
     * valid ratio as a whole-number percentage.
     */
    private void addAssertionParams(MLAssertionsParams.MLAssertion assertion, List<String> line) {
        line.add(assertion.name);
        line.add(FilterDescUtils.getFilterRepr(assertion.filter));
        if (assertion.assertionCondition.expectedClass != null) {
            line.add(assertion.assertionCondition.expectedClass);
        } else if (assertion.assertionCondition.expectedMinValue != null && assertion.assertionCondition.expectedMaxValue != null) {
            line.add(assertion.assertionCondition.expectedMinValue + "-" + assertion.assertionCondition.expectedMaxValue);
        } else {
            line.add("No assertion condition defined");
        }
        line.add(Math.round(100.0 * assertion.assertionCondition.expectedValidRatio) + "%");
    }

    /**
     * Builds the header row. The third column is "Expected class" for classification tasks and
     * "Expected range" for regression.
     *
     * @throws IllegalArgumentException for any other prediction type
     */
    private List<String> buildHeaders(ClassicalPredictionModelDetails modelDetails) {
        List<String> headers = new ArrayList<>();
        headers.add("Name");
        headers.add("Criteria");
        PredictionMLTask.PredictionType predictionType = modelDetails.coreParams.prediction_type;
        switch (predictionType) {
            case BINARY_CLASSIFICATION:
            case MULTICLASS:
                headers.add("Expected class");
                break;
            case REGRESSION:
                headers.add("Expected range");
                break;
            default:
                throw new IllegalArgumentException("Unsupported prediction type: " + predictionType);
        }
        headers.add("Expected valid ratio");
        headers.add("Rows matching criteria");
        headers.add("Rows dropped by the model");
        headers.add("Valid ratio");
        headers.add("Result");
        return headers;
    }
}

