/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.models;

import com.dataiku.scoring.Try;
import com.dataiku.scoring.linalg.Vector;
import com.dataiku.scoring.models.DecisionTreeModel;
import com.dataiku.scoring.models.DecisionTreeRegressor;
import com.dataiku.scoring.models.Regressor;
import com.dataiku.scoring.util.MathUtils;

/**
 * Regressor backed by an ensemble of gradient-boosted decision trees.
 *
 * <p>For an input vector, the raw score {@code p} is the sum of every tree's prediction. The
 * score is then mapped to the output, checking the flags in this order:
 * <ul>
 *   <li>gamma regression: {@code baseline * exp(shrinkage * p)}</li>
 *   <li>logistic regression (single-precision path only): {@code sigmoid32(shrinkage * p)};
 *       the baseline is not used</li>
 *   <li>otherwise: {@code baseline + shrinkage * p}</li>
 * </ul>
 *
 * <p>When the first tree is of variant {@code XGBOOST}, summation and the final transform are
 * carried out in {@code float} arithmetic; otherwise {@code double} is used throughout.
 */
public class GradientBoostingRegressor
implements Regressor {
    private static final long serialVersionUID = 0L;
    /** Constant term added to (or multiplied with, for gamma) the shrunk tree sum. */
    private final double baseline;
    /** Factor applied to the summed tree predictions. */
    private final double shrinkage;
    /** The ensemble members; every tree contributes to the raw score. */
    private final DecisionTreeRegressor[] trees;
    /** Number of trees; kept as a field (rather than derived) to preserve the serialized form. */
    private final int size;
    /** If true, the output is {@code baseline * exp(shrinkage * p)}. */
    private final boolean gammaRegression;
    /** If true (and gamma is not), the single-precision path outputs {@code sigmoid32(shrinkage * p)}. */
    private final boolean logisticRegression;
    /** True when the first tree is an XGBoost tree; selects float arithmetic in {@link #predict}. */
    private final boolean expectsSinglePrecisionOperations;

    /**
     * Creates a gradient-boosting regressor.
     *
     * @param baseline constant term of the model
     * @param shrinkage factor applied to the summed tree predictions
     * @param trees the ensemble; must be non-null and contain at least one tree. The array is
     *     stored as-is (not copied).
     * @param gammaRegression whether to apply the exponential (gamma) output transform
     * @param logisticRegression whether to apply the sigmoid output transform (single-precision
     *     path only)
     * @throws NullPointerException if {@code trees} is null
     * @throws IllegalArgumentException if {@code trees} is empty
     */
    public GradientBoostingRegressor(double baseline, double shrinkage, DecisionTreeRegressor[] trees, boolean gammaRegression, boolean logisticRegression) {
        if (trees == null) {
            throw new NullPointerException("trees");
        }
        if (trees.length == 0) {
            // The first tree determines the arithmetic precision and feature expectations,
            // so an empty ensemble cannot be configured.
            throw new IllegalArgumentException("GradientBoostingRegressor requires at least one tree");
        }
        this.baseline = baseline;
        this.shrinkage = shrinkage;
        this.trees = trees;
        this.size = trees.length;
        this.gammaRegression = gammaRegression;
        this.logisticRegression = logisticRegression;
        // Enum constants are singletons, so identity comparison is exact and null-safe.
        this.expectsSinglePrecisionOperations = trees[0].variant == DecisionTreeModel.TreeVariant.XGBOOST;
    }

    /** Returns the model's constant term. */
    public double getBaseline() {
        return this.baseline;
    }

    /** Returns the factor applied to the summed tree predictions. */
    public double getShrinkage() {
        return this.shrinkage;
    }

    /**
     * Returns the ensemble's trees.
     *
     * <p>This is the internal array, not a copy; modifying it modifies the model.
     */
    public DecisionTreeRegressor[] getTrees() {
        return this.trees;
    }

    /**
     * Scores a vector with the ensemble.
     *
     * @param v the input features
     * @return a successful {@code Try} holding the prediction
     */
    @Override
    public Try<Double> predict(Vector v) {
        if (this.expectsSinglePrecisionOperations) {
            return this.predictSinglePrecision(v);
        }
        return this.predictDoublePrecision(v);
    }

    /** Scores {@code v} with every sum and transform rounded to {@code float}. */
    private Try<Double> predictSinglePrecision(Vector v) {
        float p = 0.0f;
        for (int i = 0; i < this.size; ++i) {
            // Round each tree's output to float before accumulating, matching float32 scoring.
            p += (float)this.trees[i].predictUnsafe(v);
        }
        if (this.gammaRegression) {
            return Try.success(Double.valueOf((float)this.baseline * (float)Math.exp((float)this.shrinkage * p)));
        }
        if (this.logisticRegression) {
            return Try.success(Double.valueOf(MathUtils.sigmoid32((float)this.shrinkage * p)));
        }
        return Try.success(Double.valueOf((float)this.baseline + (float)this.shrinkage * p));
    }

    /** Scores {@code v} using {@code double} arithmetic throughout. */
    private Try<Double> predictDoublePrecision(Vector v) {
        double p = 0.0;
        for (int i = 0; i < this.size; ++i) {
            p += this.trees[i].predictUnsafe(v);
        }
        if (this.gammaRegression) {
            return Try.success(this.baseline * Math.exp(this.shrinkage * p));
        }
        // NOTE(review): logisticRegression is not consulted on this path, unlike the
        // single-precision path -- confirm this is intended for non-XGBoost tree variants.
        return Try.success(this.baseline + this.shrinkage * p);
    }

    /** Delegates to the first tree; assumes all trees in the ensemble agree on this. */
    @Override
    public boolean expectsProcessedFeaturesAsDoubles() {
        return this.trees[0].expectsProcessedFeaturesAsDoubles();
    }
}

