/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.models;

import com.dataiku.scoring.Try;
import com.dataiku.scoring.linalg.SparseVector;
import com.dataiku.scoring.linalg.Vector;
import com.dataiku.scoring.models.DecisionTreeClassifier;
import com.dataiku.scoring.models.ProbabilisticClassifier;
import com.dataiku.scoring.util.MathUtils;
import java.util.HashMap;
import java.util.Objects;

/**
 * Ensemble of {@link DecisionTreeClassifier}s.
 *
 * <p>Predictions are combined according to an {@link EnsemblingPolicy}:
 * <ul>
 *   <li>{@link EnsemblingPolicy#AVERAGE}: the {@code double[]} labels of the leaves reached in
 *       each tree are summed and divided by the number of trees; the predicted class is the
 *       argmax of that average.</li>
 *   <li>{@link EnsemblingPolicy#VOTE}: each tree casts one vote for the class returned by its
 *       {@code predictUnsafe}; the predicted class is the argmax of the vote counts.</li>
 * </ul>
 */
public class ForestClassifier
implements ProbabilisticClassifier {
    private static final long serialVersionUID = 1L;
    private final EnsemblingPolicy policy;
    private final DecisionTreeClassifier[] trees;
    private final int defaultVote;
    private final int numClasses;

    /**
     * Creates a forest that predicts by majority vote ({@link EnsemblingPolicy#VOTE}).
     *
     * @param trees the member trees; must be non-null and non-empty
     * @param defaultVote value exposed through {@link #getDefaultVote()}
     * @throws NullPointerException if {@code trees} is null
     * @throws IllegalArgumentException if {@code trees} is empty
     */
    public ForestClassifier(DecisionTreeClassifier[] trees, int defaultVote) {
        this(EnsemblingPolicy.VOTE, trees, defaultVote);
    }

    /**
     * Creates a forest that predicts by averaging leaf labels ({@link EnsemblingPolicy#AVERAGE}).
     * {@link #getDefaultVote()} returns {@code -1} for such a forest.
     *
     * @param trees the member trees; must be non-null and non-empty
     * @throws NullPointerException if {@code trees} is null
     * @throws IllegalArgumentException if {@code trees} is empty
     */
    public ForestClassifier(DecisionTreeClassifier[] trees) {
        this(EnsemblingPolicy.AVERAGE, trees, -1);
    }

    /** Shared initialisation for both public constructors. */
    private ForestClassifier(EnsemblingPolicy policy, DecisionTreeClassifier[] trees, int defaultVote) {
        this.numClasses = inferNumClasses(trees);
        this.policy = policy;
        this.trees = trees;
        this.defaultVote = defaultVote;
    }

    /**
     * Determines the number of classes from the label length of a leaf of the first tree.
     *
     * @throws NullPointerException if {@code trees} is null
     * @throws IllegalArgumentException if {@code trees} is empty
     */
    private static int inferNumClasses(DecisionTreeClassifier[] trees) {
        Objects.requireNonNull(trees, "trees");
        if (trees.length == 0) {
            throw new IllegalArgumentException("A forest needs at least one tree");
        }
        // Route an empty sparse vector down the first tree; any leaf will do, since only the
        // length of its label array is read.
        // NOTE(review): assumes every leaf of every tree carries a double[] label of this same
        // length — confirm against the model export format.
        Vector probe = new SparseVector(new HashMap<Integer, Double>(), Integer.MAX_VALUE);
        return ((double[]) trees[0].getTerminalNode(probe).label).length;
    }

    /** @return the policy used by {@link #predict(Vector)} to combine the trees */
    public EnsemblingPolicy getPolicy() {
        return this.policy;
    }

    /**
     * @return the member trees. This is the internal array, not a copy; callers must not
     *     modify it.
     */
    public DecisionTreeClassifier[] getTrees() {
        return this.trees;
    }

    @Override
    public int getNumClasses() {
        return this.numClasses;
    }

    /**
     * @return the value passed to the voting constructor, or {@code -1} for an averaging forest
     */
    public int getDefaultVote() {
        return this.defaultVote;
    }

    /**
     * Majority vote: counts how many trees predict each class and returns the argmax.
     * NOTE(review): {@code defaultVote} is not consulted here; ties are resolved however
     * {@link MathUtils#argmax} resolves them — confirm this is intended.
     */
    private int vote(Vector v) {
        double[] votes = new double[this.numClasses];
        for (DecisionTreeClassifier tree : this.trees) {
            votes[tree.predictUnsafe(v)] += 1.0;
        }
        return MathUtils.argmax(votes);
    }

    /**
     * Predicts the class index of {@code v} according to {@link #getPolicy()}.
     *
     * @throws UnsupportedOperationException if the policy is not handled
     */
    @Override
    public Try<Integer> predict(Vector v) {
        switch (this.policy) {
            case AVERAGE: {
                return Try.success(MathUtils.argmax(this.probasUnsafe(v)));
            }
            case VOTE: {
                return Try.success(this.vote(v));
            }
        }
        throw new UnsupportedOperationException();
    }

    /**
     * Averages, element-wise, the {@code double[]} labels of the leaves that {@code v} reaches
     * in every tree.
     *
     * @return a fresh array of length {@link #getNumClasses()}
     */
    double[] probasUnsafe(Vector v) {
        double[] sums = new double[this.numClasses];
        for (DecisionTreeClassifier tree : this.trees) {
            double[] leafLabel = (double[]) tree.getTerminalNode(v).label;
            for (int j = 0; j < leafLabel.length; ++j) {
                sums[j] += leafLabel[j];
            }
        }
        double treeCount = this.trees.length;
        for (int i = 0; i < sums.length; ++i) {
            sums[i] /= treeCount;
        }
        return sums;
    }

    /** Returns the averaged leaf labels for {@code v}, regardless of {@link #getPolicy()}. */
    @Override
    public Try<double[]> probabilities(Vector v) {
        return Try.success(this.probasUnsafe(v));
    }

    /** Always fails: this model exposes no decision function. */
    @Override
    public Try<double[]> decisionFunction(Vector v) {
        return Try.failure("No decision function in Random Forest Classifier");
    }

    /** Delegates to the first tree; the other trees are not consulted. */
    @Override
    public boolean expectsProcessedFeaturesAsDoubles() {
        return this.trees[0].expectsProcessedFeaturesAsDoubles();
    }

    /** How the predictions of the member trees are combined. */
    public static enum EnsemblingPolicy {
        AVERAGE,
        VOTE;

    }
}

