/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.labeling.score;

import com.dataiku.dip.labeling.Annotation;
import com.dataiku.dip.labeling.LabelingAnswer;
import com.dataiku.dip.labeling.LabelingClassicalMetrics;
import com.dataiku.dip.labeling.VerifiedLabel;
import com.dataiku.dip.labeling.classification.ClassificationAnnotation;
import com.dataiku.dip.labeling.score.LabelingClassicalScore;
import com.dataiku.dip.labeling.score.LabelingScoreComputer;
import java.util.List;

/**
 * Base score computer that evaluates a labeling answer with classical
 * true-positive / false-positive / false-negative counts and derives
 * precision, recall and F1 from them.
 *
 * <p>Subclasses decide how a {@link VerifiedLabel} is broken down into
 * annotation pairs via {@link #getPairsFromVerifiedLabel(VerifiedLabel)}.
 *
 * @param <P> the annotation pair type produced by the subclass
 */
public abstract class LabelingClassicalScoreComputer<P extends VerifiedLabel.AnnotationPair>
extends LabelingScoreComputer<LabelingClassicalScore, LabelingClassicalMetrics> {

    /**
     * Tells whether a pair that references an annotation must be counted as a
     * false positive without looking at the annotation's category.
     *
     * <p>The default implementation flags pairs that have no verified category.
     * Subclasses may override this.
     *
     * @param annotationPair the pair to inspect
     * @return {@code true} if the pair is a false positive
     */
    public boolean isFalsePositive(P annotationPair) {
        return annotationPair.verifiedCategory == null;
    }

    /**
     * Extracts the annotation pairs to score from a verified label.
     *
     * @param verifiedLabel the verified label of an answer (never {@code null} when called from {@link #score})
     * @return the pairs to classify as tp / fp / fn
     */
    public abstract List<P> getPairsFromVerifiedLabel(VerifiedLabel verifiedLabel);

    /**
     * Scores one answer against its verified label.
     *
     * <p>For each pair returned by {@link #getPairsFromVerifiedLabel}:
     * <ul>
     *   <li>no annotation index: false negative;</li>
     *   <li>{@link #isFalsePositive} is true: false positive;</li>
     *   <li>the referenced annotation's category equals the verified category: true positive;</li>
     *   <li>otherwise: false positive.</li>
     * </ul>
     *
     * @param answer the answer to score
     * @return a freshly created score, left untouched when the answer has no verified label
     */
    @Override
    public LabelingClassicalScore score(LabelingAnswer answer) {
        LabelingClassicalScore score = new LabelingClassicalScore();
        VerifiedLabel verifiedLabel = answer.verifiedLabel;
        if (verifiedLabel == null) {
            // Nothing to compare against.
            return score;
        }
        // Iterate as P (not the raw bound) so the pair can be handed to isFalsePositive(P).
        for (P pair : this.getPairsFromVerifiedLabel(verifiedLabel)) {
            if (pair.annotationIdx == null) {
                ++score.fn;
            } else if (this.isFalsePositive(pair)) {
                ++score.fp;
            } else if (this.matchesVerifiedCategory(answer, pair)) {
                ++score.tp;
            } else {
                ++score.fp;
            }
        }
        return score;
    }

    /**
     * Returns whether the annotation referenced by {@code pair} has the pair's verified category.
     * A missing verified category never matches. This guards subclasses that override
     * {@link #isFalsePositive} without checking for {@code null}.
     */
    private boolean matchesVerifiedCategory(LabelingAnswer answer, P pair) {
        if (pair.verifiedCategory == null) {
            return false;
        }
        // NOTE(review): assumes every annotation scored here is a ClassificationAnnotation.
        ClassificationAnnotation annotation =
                (ClassificationAnnotation) answer.label.annotations.get(pair.annotationIdx);
        return pair.verifiedCategory.equals(annotation.category);
    }

    /**
     * Derives precision, recall and F1 from accumulated counts.
     *
     * <p>Precision is set only when {@code tp + fp > 0}, and recall only when
     * {@code tp + fn > 0}. F1 is set only when both are set and not both zero.
     * Otherwise the field is left as initialised by {@code new LabelingClassicalMetrics()}.
     *
     * @param score the accumulated counts
     * @return the derived metrics
     */
    @Override
    public LabelingClassicalMetrics getMetrics(LabelingClassicalScore score) {
        LabelingClassicalMetrics metrics = new LabelingClassicalMetrics();
        // Widen to double before summing so large counts cannot overflow.
        double tp = score.tp;
        double fp = score.fp;
        double fn = score.fn;
        if (tp + fp > 0) {
            metrics.precision = tp / (tp + fp);
        }
        if (tp + fn > 0) {
            metrics.recall = tp / (tp + fn);
        }
        if (metrics.precision != null && metrics.recall != null) {
            double precision = metrics.precision;
            double recall = metrics.recall;
            // Avoid 0/0 when both precision and recall are zero.
            if (precision != 0.0 || recall != 0.0) {
                metrics.f1Score = 2.0 * precision * recall / (precision + recall);
            }
        }
        return metrics;
    }

    /**
     * Adds the counts of {@code newScore} into {@code originalScore} in place.
     *
     * @param originalScore the accumulator, mutated
     * @param newScore the counts to add, not modified
     */
    @Override
    public void incrementScore(LabelingClassicalScore originalScore, LabelingClassicalScore newScore) {
        originalScore.fn += newScore.fn;
        originalScore.fp += newScore.fp;
        originalScore.tp += newScore.tp;
    }

    /**
     * Creates a new score to accumulate into.
     *
     * @return {@code new LabelingClassicalScore()}
     */
    @Override
    public LabelingClassicalScore initScore() {
        return new LabelingClassicalScore();
    }
}

