/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.scoring.pipelines.overrides;

import com.dataiku.scoring.models.overrides.MLOverridesParamsBase;
import com.dataiku.scoring.pipelines.OverrideInfo;
import com.dataiku.scoring.pipelines.RegressionResult;
import com.dataiku.scoring.pipelines.overrides.OverridesLayerBase;
import com.dataiku.scoring.pipelines.overrides.OverridesOutcomeComputer;
import com.dataiku.scoring.util.RawObservation;

/**
 * Override layer for regression results.
 *
 * <p>When a row is prepared for override matching, the raw prediction and (when present) its
 * prediction interval are written into the row as extra columns. Supported override outcomes are
 * {@code INTERVAL}, which bounds the prediction as
 * {@code max(minValue, min(maxValue, prediction))}, and {@code DECLINED}. Any other outcome type is
 * rejected.
 */
public class RegressionOverridesLayer extends OverridesLayerBase<RegressionResult> {
    /** Row column receiving the lower bound of the prediction interval. */
    public static final String PREDICTION_INTERVAL_LOWER_COL = "prediction_interval_lower";
    /** Row column receiving the upper bound of the prediction interval. */
    public static final String PREDICTION_INTERVAL_UPPER_COL = "prediction_interval_upper";
    /** Row column receiving the size of the prediction interval. */
    public static final String PREDICTION_INTERVAL_SIZE_COL = "prediction_interval_size";
    /** Row column receiving the interval size divided by the prediction (or null when undefined). */
    public static final String PREDICTION_INTERVAL_RELATIVE_SIZE_COL = "prediction_interval_relative_size";

    /**
     * @param outcomeComputer passed unchanged to {@link OverridesLayerBase}
     */
    public RegressionOverridesLayer(OverridesOutcomeComputer<RawObservation> outcomeComputer) {
        super(outcomeComputer);
    }

    /**
     * Writes the raw prediction, and its prediction interval when one is present, into the row.
     *
     * <p>Always writes the {@code "prediction"} column. If {@code rawResult} carries a prediction
     * interval, also writes its lower bound, upper bound, size, and size relative to the
     * prediction. The relative size is {@code null} when the prediction is null, NaN, or zero.
     *
     * @param originalRow row that receives the extra columns (mutated in place)
     * @param rawResult   regression result whose prediction and interval are copied
     */
    @Override
    void prepareRowForOverride(RawObservation originalRow, RegressionResult rawResult) {
        final Double prediction = (Double) rawResult.getPrediction();
        originalRow.put("prediction", prediction);
        rawResult.getPredictionInterval().ifPresent(interval -> {
            originalRow.put(PREDICTION_INTERVAL_LOWER_COL, interval.lower);
            originalRow.put(PREDICTION_INTERVAL_UPPER_COL, interval.upper);
            originalRow.put(PREDICTION_INTERVAL_SIZE_COL, interval.size);
            // Dividing by a missing, NaN, or zero prediction is meaningless: leave the column null.
            Double relativeSize = null;
            if (prediction != null && !prediction.isNaN() && prediction != 0.0) {
                relativeSize = interval.size / prediction;
            }
            originalRow.put(PREDICTION_INTERVAL_RELATIVE_SIZE_COL, relativeSize);
        });
    }

    /**
     * Applies the matched override's outcome to the raw regression result.
     *
     * @param candidate matched override, providing its outcome and name
     * @param rawResult result before any override is applied
     * @return for {@code INTERVAL}, the result of {@code rawResult.withNewPrediction(...)} with the
     *         bounded prediction; for {@code DECLINED}, {@code RegressionResult.declined()}. In both
     *         cases an {@link OverrideInfo} referencing the raw result is attached.
     * @throws IllegalArgumentException if the outcome type is neither {@code INTERVAL} nor
     *         {@code DECLINED}
     */
    @Override
    RegressionResult applyOverride(OverridesOutcomeComputer.OutcomeCandidate<RawObservation> candidate, RegressionResult rawResult) {
        final MLOverridesParamsBase.MLOverride.Outcome outcome = candidate.outcome;
        // Captured before any override is applied; attached to the override info for every outcome.
        final RegressionResult.RawResult rawResultInfo = new RegressionResult.RawResult(rawResult);
        switch (outcome.type) {
            case INTERVAL: {
                final Double boundedPrediction = Math.max(outcome.minValue,
                        Math.min(outcome.maxValue, (Double) rawResult.getPrediction()));
                final RegressionResult bounded = rawResult.withNewPrediction(boundedPrediction);
                // Double.equals semantics: NaN equals NaN, while 0.0 and -0.0 are distinct.
                final boolean predictionChanged =
                        !((Double) rawResult.getPrediction()).equals(bounded.getPrediction());
                bounded.setOverrideInfo(new OverrideInfo(candidate.overrideName, predictionChanged, rawResultInfo));
                return bounded;
            }
            case DECLINED: {
                final RegressionResult declinedResult = RegressionResult.declined();
                declinedResult.setOverrideInfo(OverrideInfo.declined(candidate.overrideName, rawResultInfo));
                return declinedResult;
            }
            default:
                throw new IllegalArgumentException("Unsupported Outcome Type (" + outcome.type
                        + "). Regression only supports interval or declined override");
        }
    }
}

