/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.pipelines;

import com.dataiku.scoring.models.ProbabilisticClassifier;
import com.dataiku.scoring.pipelines.AbstractClassificationPipeline;
import com.dataiku.scoring.pipelines.AbstractPipeline;
import com.dataiku.scoring.pipelines.ClassificationResult;
import com.dataiku.scoring.pipelines.PreprocessingPipeline;
import com.dataiku.scoring.pipelines.ProbabilisticClassificationPipeline;
import com.dataiku.scoring.pipelines.overrides.OverridesLayerBase;
import java.util.Map;
import java.util.Optional;

/**
 * Base classification pipeline whose model is a {@link ProbabilisticClassifier}.
 *
 * <p>In addition to whatever columns the parent pipeline registers, the constructor appends one
 * {@link ProbabilityColumnComputer} per class label to {@code columnComputers}, each producing a
 * column named {@code "proba_" + label}.
 */
public abstract class AbstractProbabilisticClassificationPipeline
extends AbstractClassificationPipeline<ProbabilisticClassifier, ClassificationResult>
implements ProbabilisticClassificationPipeline {
    private static final long serialVersionUID = 0L;

    /**
     * Builds the pipeline and registers one probability column computer per class label.
     *
     * @param preprocessing  preprocessing pipeline, forwarded to the parent constructor
     * @param model          probabilistic classifier, forwarded to the parent constructor
     * @param classes        class labels; index {@code i} of this array is used as the
     *                       probability index for the {@code i}-th probability column
     * @param overridesLayer overrides layer, forwarded to the parent constructor
     */
    AbstractProbabilisticClassificationPipeline(PreprocessingPipeline preprocessing, ProbabilisticClassifier model, String[] classes, OverridesLayerBase<ClassificationResult> overridesLayer) {
        super(preprocessing, model, classes, overridesLayer);
        // One extra output column per class label, in label order.
        for (int i = 0; i < classes.length; ++i) {
            this.columnComputers.add(new ProbabilityColumnComputer(classes, i));
        }
    }

    /**
     * Delegates to {@link ClassificationResult#remapProbabilities} using this pipeline's class
     * labels (presumably keyed by label — confirm against {@code ClassificationResult}).
     *
     * @param result a scoring result produced by this pipeline
     * @return the map returned by {@code result.remapProbabilities(this.classes)}
     */
    @Override
    public Map<String, Double> remapProbabilities(ClassificationResult result) {
        return result.remapProbabilities(this.classes);
    }

    /**
     * Output column computer that exposes the probability of a single class.
     *
     * <p>The column is named {@link #COLUMN_PREFIX} followed by the class label, and its value is
     * read from {@link ClassificationResult#getProbabilities()} at the configured class index.
     */
    public static class ProbabilityColumnComputer
    extends AbstractPipeline.AbstractColumnComputer<Double, String, ClassificationResult> {
        /** Prefix prepended to the class label to form the output column name. */
        public static final String COLUMN_PREFIX = "proba_";
        private final int classIndex;

        /**
         * @param classes    class labels of the pipeline
         * @param classIndex index of the class whose probability this column reports
         * @throws IllegalArgumentException if {@code classIndex} is not a valid index into
         *                                  {@code classes}
         */
        public ProbabilityColumnComputer(String[] classes, int classIndex) {
            // getColumnName validates classIndex before anything is stored.
            super(ProbabilityColumnComputer.getColumnName(classes, classIndex), Double.class);
            this.classIndex = classIndex;
        }

        /**
         * Returns the probability of this column's class for {@code result}.
         *
         * <p>Assumes the probabilities array is ordered like the pipeline's class labels —
         * NOTE(review): confirm against {@code ClassificationResult}.
         *
         * @param result the scoring result
         * @return empty if the result was declined, otherwise the probability at the configured
         *         class index
         */
        @Override
        public Optional<Double> getOutputValue(ClassificationResult result) {
            if (result.isDeclined()) {
                return Optional.empty();
            }
            return Optional.of(result.getProbabilities()[this.classIndex]);
        }

        /** @return {@code Double.class}; probability columns are always doubles. */
        @Override
        public Class<Double> getOutputClass() {
            return Double.class;
        }

        /**
         * Computes the output column name for the class at {@code classIndex}.
         *
         * @param classes    class labels
         * @param classIndex index of the class, must satisfy {@code 0 <= classIndex < classes.length}
         * @return {@code "proba_" + classes[classIndex]}
         * @throws IllegalArgumentException if {@code classIndex} is negative or not less than
         *                                  {@code classes.length}
         */
        public static String getColumnName(String[] classes, int classIndex) {
            // Reject negative indices too, so callers get a descriptive IllegalArgumentException
            // instead of an ArrayIndexOutOfBoundsException from the array access below.
            if (classIndex < 0 || classIndex >= classes.length) {
                throw new IllegalArgumentException("Illegal class index " + classIndex + " for classes length of : " + classes.length);
            }
            return COLUMN_PREFIX + classes[classIndex];
        }
    }
}

