/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scoring.exports.pmml.models.classification.binary;

import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.scoring.exports.pmml.XML;
import com.dataiku.dip.scoring.exports.pmml.helpers.NodeRescaler;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLSegmentations;
import com.dataiku.dip.scoring.exports.pmml.helpers.PMMLTransformationDictionary;
import com.dataiku.dip.scoring.exports.pmml.models.PMMLModel;
import com.dataiku.dip.scoring.exports.pmml.models.classification.binary.PMMLBinaryClassifier;
import com.dataiku.dip.scoring.exports.pmml.models.regression.PMMLTreeRegressor;
import com.dataiku.scoring.models.Classifier;
import com.dataiku.scoring.models.DecisionTreeModel;
import com.dataiku.scoring.models.DecisionTreeRegressor;
import com.dataiku.scoring.models.GradientBoostingClassifier;
import com.dataiku.scoring.pipelines.BinaryProbabilisticPipeline;
import com.dataiku.scoring.pipelines.ClassificationPipeline;
import com.dataiku.scoring.pipelines.Pipeline;
import java.util.ArrayList;
import java.util.List;

@XML.Named(name="MiningModel")
public class PMMLGradientBoostingBinaryClassifier
extends PMMLBinaryClassifier {
    public PMMLGradientBoostingBinaryClassifier(BinaryProbabilisticPipeline pipe, ResolvedPredictionPreprocessingParams rppp) {
        super(pipe, rppp, true);
        this.setModel(new PMMLBinaryTreeModel(pipe, rppp));
    }

    @Override
    public void enrichTransformationDictionaryFromModel(PMMLTransformationDictionary dictionary, String[] colNames, Classifier model) {
        dictionary.addDerivedFieldsForTrees((DecisionTreeModel<T>[][])((GradientBoostingClassifier)model).getTrees(), colNames);
    }

    public static class PMMLBinaryTreeModel
    extends PMMLBinaryClassifier.PMMLBinaryModel {
        @XML.Element
        List<PMMLSegmentations.PMMLSumSegmentation> Segmentation;
        @XML.Attribute
        final String functionName = "regression";

        public PMMLBinaryTreeModel(BinaryProbabilisticPipeline pipe, ResolvedPredictionPreprocessingParams rppp) {
            super((ClassificationPipeline)pipe, rppp);
            this.Output = PMMLModel.PMMLOutput.outputOnlyClassOnePredictedValue(pipe.getClasses());
            ArrayList<PMMLSegmentations.PMMLSumSegmentation> segmentations = new ArrayList<PMMLSegmentations.PMMLSumSegmentation>();
            PMMLSegmentations.PMMLSumSegmentation segments = new PMMLSegmentations.PMMLSumSegmentation();
            segments.Segment = new ArrayList<PMMLSegmentations.PMMLTreeSegment>();
            GradientBoostingClassifier gbt = (GradientBoostingClassifier)pipe.getModel();
            DecisionTreeRegressor[][] trees = gbt.getTrees();
            for (int j = 0; j < trees.length; ++j) {
                double baseline = j == 0 ? gbt.getBaseline()[0] : 0.0;
                NodeRescaler nodeRescaler = new NodeRescaler(baseline, gbt.getShrinkage());
                DecisionTreeRegressor originalRegressor = trees[j][0];
                DecisionTreeModel.Node<Double> rescaledTree = nodeRescaler.rescale((DecisionTreeModel.Node<Double>)originalRegressor.getRoot());
                DecisionTreeRegressor rescaledRegressor = new DecisionTreeRegressor(rescaledTree, originalRegressor.variant);
                PMMLTreeRegressor pmmlTreeRegressor = new PMMLTreeRegressor((Pipeline)pipe, rppp, rescaledRegressor, pipe.getClasses()[1], Integer.toString(j));
                segments.Segment.add(new PMMLSegmentations.PMMLTreeSegment(pmmlTreeRegressor));
            }
            segmentations.add(segments);
            this.Segmentation = segmentations;
        }
    }
}

