/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.models;

import com.dataiku.scoring.Try;
import com.dataiku.scoring.linalg.Vector;
import com.dataiku.scoring.models.DecisionTreeModel;
import com.dataiku.scoring.models.DecisionTreeRegressor;
import com.dataiku.scoring.models.ProbabilisticClassifier;
import com.dataiku.scoring.util.MathUtils;
import java.util.Objects;

/**
 * Gradient-boosted tree ensemble exposed as a {@link ProbabilisticClassifier}.
 *
 * <p>The model is a sequence of boosting stages; {@code trees[stage][k]} is the regression tree
 * contributing to raw score column {@code k} at that stage. There is one score column per entry
 * of {@code baseline}:
 *
 * <ul>
 *   <li>one column: binary model. The single column is the raw score (logit) of class 1, and
 *       probabilities are obtained with a sigmoid;
 *   <li>several columns: one raw score per class, turned into probabilities with a softmax.
 * </ul>
 *
 * <p>The raw score of column {@code k} is
 * {@code baseline[k] + shrinkage * sum over stages of trees[stage][k](v)}.
 * When the first tree is an XGBoost tree, all arithmetic is carried out in 32-bit floats
 * (presumably to match XGBoost's own float32 predictions — confirm).
 */
public class GradientBoostingClassifier implements ProbabilisticClassifier {
    private static final long serialVersionUID = 0L;
    /** Initial raw score per score column. */
    private final double[] baseline;
    /** Learning rate applied to the summed tree outputs. */
    private final double shrinkage;
    /** {@code trees[stage][column]}: one regression tree per score column per boosting stage. */
    private final DecisionTreeRegressor[][] trees;
    /** True when scoring must use 32-bit float arithmetic (XGBoost-trained trees). */
    public final boolean expects32BitFloat;

    /**
     * Creates a boosted ensemble.
     *
     * @param baseline  initial raw score for each score column; its length is the number of
     *                  score columns (1 for a binary model)
     * @param shrinkage learning rate applied to the summed tree outputs
     * @param trees     {@code trees[stage][column]}; every stage must hold exactly
     *                  {@code baseline.length} trees
     * @throws NullPointerException     if {@code baseline}, {@code trees} or any stage is null
     * @throws IllegalArgumentException if there are no stages or no score columns, or if a
     *                                  stage's length differs from {@code baseline.length}
     */
    public GradientBoostingClassifier(double[] baseline, double shrinkage, DecisionTreeRegressor[][] trees) {
        Objects.requireNonNull(baseline, "baseline");
        Objects.requireNonNull(trees, "trees");
        if (trees.length == 0) {
            throw new IllegalArgumentException("At least one boosting stage is required");
        }
        if (baseline.length == 0) {
            throw new IllegalArgumentException("Baseline must contain at least one score column");
        }
        // Validate every stage, not only the first: a shorter stage would otherwise crash at
        // scoring time and a longer one would have its extra trees silently ignored.
        for (int stage = 0; stage < trees.length; ++stage) {
            DecisionTreeRegressor[] row = Objects.requireNonNull(trees[stage], "trees[" + stage + "]");
            if (row.length != baseline.length) {
                throw new IllegalArgumentException("Baseline and trees had different lengths (stage "
                        + stage + " has " + row.length + " trees, baseline has " + baseline.length + " entries)");
            }
        }
        this.baseline = baseline;
        this.shrinkage = shrinkage;
        this.trees = trees;
        this.expects32BitFloat = trees[0][0].variant == DecisionTreeModel.TreeVariant.XGBOOST;
    }

    /** Returns the per-column baseline scores (the internal array, not a copy). */
    public double[] getBaseline() {
        return this.baseline;
    }

    /** Returns the learning rate applied to the summed tree outputs. */
    public double getShrinkage() {
        return this.shrinkage;
    }

    /** Returns the {@code [stage][column]} tree matrix (the internal array, not a copy). */
    public DecisionTreeRegressor[][] getTrees() {
        return this.trees;
    }

    /**
     * Returns the number of classes: 2 for a single-column (binary) model, otherwise one class
     * per score column.
     */
    @Override
    public int getNumClasses() {
        return this.trees[0].length == 1 ? 2 : this.trees[0].length;
    }

    /**
     * Computes raw (pre-link) scores for {@code v}.
     *
     * <p>For a single-column model the result is {@code {0.0, logit}} so that index 1 holds the
     * positive-class score. Otherwise one score per column is returned. The binary layout is
     * chosen from the number of score columns, not from {@link #getNumClasses()}. That method
     * also reports 2 for a two-column model, whose second column must not be dropped.
     */
    @Override
    public Try<double[]> decisionFunction(Vector v) {
        boolean singleColumn = this.baseline.length == 1;
        double[] scores;
        if (this.expects32BitFloat) {
            scores = singleColumn ? this.binaryScores32(v) : this.multiColumnScores32(v);
        } else {
            scores = singleColumn ? this.binaryScores64(v) : this.multiColumnScores64(v);
        }
        return Try.success(scores);
    }

    /** Single-column raw score in float32 arithmetic, laid out as {@code {0.0, logit}}. */
    private double[] binaryScores32(Vector v) {
        float sum = 0.0f;
        for (DecisionTreeRegressor[] stage : this.trees) {
            sum += (float) stage[0].predictUnsafe(v);
        }
        float logit = (float) this.baseline[0] + (float) this.shrinkage * sum;
        return new double[] {0.0, logit};
    }

    /** Per-column raw scores in float32 arithmetic. */
    private double[] multiColumnScores32(Vector v) {
        int numColumns = this.baseline.length;
        // Accumulate in float so every intermediate sum is rounded to 32 bits.
        float[] sums = new float[numColumns];
        for (DecisionTreeRegressor[] stage : this.trees) {
            for (int k = 0; k < numColumns; ++k) {
                sums[k] += (float) stage[k].predictUnsafe(v);
            }
        }
        double[] scores = new double[numColumns];
        for (int k = 0; k < numColumns; ++k) {
            scores[k] = (float) this.baseline[k] + (float) this.shrinkage * sums[k];
        }
        return scores;
    }

    /** Single-column raw score in double arithmetic, laid out as {@code {0.0, logit}}. */
    private double[] binaryScores64(Vector v) {
        double sum = 0.0;
        for (DecisionTreeRegressor[] stage : this.trees) {
            sum += stage[0].predictUnsafe(v);
        }
        return new double[] {0.0, this.baseline[0] + this.shrinkage * sum};
    }

    /** Per-column raw scores in double arithmetic. */
    private double[] multiColumnScores64(Vector v) {
        int numColumns = this.baseline.length;
        double[] scores = new double[numColumns];
        for (DecisionTreeRegressor[] stage : this.trees) {
            for (int k = 0; k < numColumns; ++k) {
                scores[k] += stage[k].predictUnsafe(v);
            }
        }
        for (int k = 0; k < numColumns; ++k) {
            scores[k] = this.baseline[k] + scores[k] * this.shrinkage;
        }
        return scores;
    }

    /**
     * Returns class probabilities for {@code v}: a sigmoid of the logit for single-column models,
     * otherwise a softmax over the per-column scores. The float32 variants of both are used when
     * {@link #expects32BitFloat} is set. A failure from {@link #decisionFunction} is propagated.
     */
    @Override
    public Try<double[]> probabilities(Vector v) {
        Try<double[]> decision = this.decisionFunction(v);
        if (decision.isError()) {
            return Try.failure(decision.getMessage());
        }
        double[] scores = decision.get();
        // Must use the same single-column test as decisionFunction so the {0.0, logit} layout
        // is only ever interpreted by the sigmoid path.
        if (this.baseline.length == 1) {
            if (this.expects32BitFloat) {
                float p = MathUtils.sigmoid32(scores[1]);
                return Try.success(new double[] {1.0f - p, p});
            }
            double p = MathUtils.sigmoid(scores[1]);
            return Try.success(new double[] {1.0 - p, p});
        }
        return Try.success(this.expects32BitFloat ? MathUtils.softmax32(scores) : MathUtils.softmax(scores));
    }

    /** Returns the index of the most probable class, or a failure if probabilities failed. */
    @Override
    public Try<Integer> predict(Vector v) {
        Try<double[]> p = this.probabilities(v);
        if (p.isError()) {
            return Try.failure(p.getMessage());
        }
        return Try.success(MathUtils.argmax(p.get()));
    }

    /** Delegates to the first tree of the ensemble. */
    @Override
    public boolean expectsProcessedFeaturesAsDoubles() {
        return this.trees[0][0].expectsProcessedFeaturesAsDoubles();
    }
}

