/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.ModelDataUtilsService;
import com.dataiku.dip.analysis.ml.prediction.overrides.FormulaOverridesOutcomeComputer;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.clustering.ClusteringMLTask;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.analysis.model.prediction.ClassificationModelIntrinsicPerf;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.prediction.overrides.MLOverridesParams;
import com.dataiku.dip.analysis.model.preprocessing.PredictionPreprocessingParams;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemRow;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.futures.FutureProgress;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.shaker.SampleBuilder;
import com.dataiku.dip.shaker.filter.FilterRequest;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.shaker.server.DataService;
import com.dataiku.dip.shaker.server.MemScriptRunner;
import com.dataiku.dip.shaker.services.TypeInferrer2;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.warnings.WarningsContext;
import com.dataiku.scoring.models.overrides.MLOverridesParamsBase;
import com.dataiku.scoring.pipelines.ClassificationResult;
import com.dataiku.scoring.pipelines.OverrideInfo;
import com.dataiku.scoring.pipelines.Result;
import com.dataiku.scoring.pipelines.overrides.OverridesOutcomeComputer;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class PredictedDataService {
    // Injected collaborator; this class uses it to apply filters/sorts to a loaded table
    // and to compute the detailed analysis of a single column.
    @Autowired
    private ModelDataUtilsService modelDataUtilsService;
    // Type inferrer run on every freshly loaded predicted-data table (see processFullAuto call).
    // NOTE(review): package-private and non-final as decompiled — confirm whether that is intentional.
    TypeInferrer2 inferer = new TypeInferrer2();
    // Shared logger for this service, registered under the "dku.prediction" name.
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.prediction");

    /**
     * Builds the identifier of the predicted-data sample for a model and a sampling selection.
     *
     * <p>The identifier is the plain concatenation (no separators are inserted) of:
     * <ol>
     *   <li>{@code fmi.toString()};</li>
     *   <li>{@code "-cut="} followed by the {@code activeClassifierThreshold} read from
     *       {@code user_meta.json} — only when the head task is a binary-classification prediction
     *       task whose {@code iperf.json} has {@code probaAware} set;</li>
     *   <li>the MD5 hex digest of the JSON form of {@code sampling.selection};</li>
     *   <li>{@code sampling._refreshTrigger}.</li>
     * </ol>
     *
     * @param fmi      model whose head task and model files are inspected
     * @param sampling selection settings; its {@code selection} and {@code _refreshTrigger} are read
     * @return the sample identifier
     * @throws IOException declared by this method; presumably raised when the model files cannot
     *                     be read — confirm against {@code FullModelId}
     */
    public static String getSampleId(FullModelId fmi, SerializedShakerScript.RefreshableStreamableSelection sampling) throws IOException {
        logger.info("Compute sample id for " + fmi.toString() + " SAMPLING=  " + JSON.json(sampling));
        // Typed as String (the decompiled Object-typed accumulator could not be returned as String).
        String sampleId = fmi.toString();
        MLTask task = fmi.getHeadMLTask();
        if (task.taskType == MLTask.MLTaskType.PREDICTION) {
            PredictionMLTask pmlTask = (PredictionMLTask) task;
            // The threshold is part of the id only for probability-aware binary classifiers.
            if (pmlTask.predictionType == PredictionMLTask.PredictionType.BINARY_CLASSIFICATION
                    && fmi.parseModelFile("iperf.json", ClassificationModelIntrinsicPerf.class).probaAware) {
                sampleId = sampleId + "-cut=" + fmi.parseModelFile("user_meta.json", ModelUserMeta.class).activeClassifierThreshold;
            }
        }
        sampleId = sampleId + DigestUtils.md5Hex(JSON.json(sampling.selection)) + sampling._refreshTrigger;
        return sampleId;
    }

    /**
     * Builds the identifier of the predicted-data sample for a model and exploration sampling settings.
     *
     * <p>The identifier is the plain concatenation (no separators are inserted) of:
     * <ol>
     *   <li>{@code fmi.toString()};</li>
     *   <li>{@code "-cut="} followed by the {@code activeClassifierThreshold} read from
     *       {@code user_meta.json} — only when the head task is a binary-classification prediction
     *       task whose {@code iperf.json} has {@code probaAware} set;</li>
     *   <li>the MD5 hex digest of the JSON form of {@code sampling.selection};</li>
     *   <li>{@code sampling._refreshTrigger}.</li>
     * </ol>
     *
     * @param fmi      model whose head task and model files are inspected
     * @param sampling exploration sampling settings; its {@code selection} and
     *                 {@code _refreshTrigger} are read
     * @return the sample identifier
     * @throws IOException declared by this method; presumably raised when the model files cannot
     *                     be read — confirm against {@code FullModelId}
     */
    public static String getSampleId(FullModelId fmi, SerializedShakerScript.ShakerExplorationSampleSettings sampling) throws IOException {
        logger.info("Compute sample id for " + fmi.toString() + " SAMPLING=  " + JSON.json(sampling));
        // Typed as String (the decompiled Object-typed accumulator could not be returned as String).
        String sampleId = fmi.toString();
        MLTask task = fmi.getHeadMLTask();
        if (task.taskType == MLTask.MLTaskType.PREDICTION) {
            PredictionMLTask pmlTask = (PredictionMLTask) task;
            // The threshold is part of the id only for probability-aware binary classifiers.
            if (pmlTask.predictionType == PredictionMLTask.PredictionType.BINARY_CLASSIFICATION
                    && fmi.parseModelFile("iperf.json", ClassificationModelIntrinsicPerf.class).probaAware) {
                sampleId = sampleId + "-cut=" + fmi.parseModelFile("user_meta.json", ModelUserMeta.class).activeClassifierThreshold;
            }
        }
        sampleId = sampleId + DigestUtils.md5Hex(JSON.json(sampling.selection)) + sampling._refreshTrigger;
        return sampleId;
    }

    /**
     * Loads the predicted data for {@code fmi} and applies {@code filters} plus the script's sorting,
     * using the script's exploration sampling and the default derived-columns computer.
     *
     * @param fmi     model whose predicted data is loaded
     * @param script  supplies {@code explorationSampling} and {@code sorting}
     * @param filters filter request forwarded to the five-argument overload
     * @return the loaded table together with its report
     * @throws Exception propagated from the five-argument overload
     */
    public synchronized MemScriptRunner.TableWithReport getUncachedFiltered_NT(FullModelId fmi, SerializedShakerScript script, FilterRequest filters) throws Exception {
        SerializedShakerScript.ShakerExplorationSampleSettings sampleSettings = script.explorationSampling;
        DerivedColumnsComputer defaultComputer = this.getDefaultDerivedColumnsComputer();
        return this.getUncachedFiltered_NT(fmi, sampleSettings, filters, defaultComputer, script.sorting);
    }

    /**
     * Loads the unfiltered predicted data for {@code fmi} using the default derived-columns computer.
     *
     * @param fmi      model whose predicted data is loaded
     * @param sampling exploration sampling settings forwarded to the three-argument overload
     * @return the loaded table together with its report
     * @throws Exception propagated from the three-argument overload
     */
    public synchronized MemScriptRunner.TableWithReport getUncachedUnfiltered_NOTRANSACTION(FullModelId fmi, SerializedShakerScript.ShakerExplorationSampleSettings sampling) throws Exception {
        DerivedColumnsComputer defaultComputer = this.getDefaultDerivedColumnsComputer();
        return this.getUncachedUnfiltered_NOTRANSACTION(fmi, sampling, defaultComputer);
    }

    /**
     * Loads the unfiltered predicted data, then lets {@code modelDataUtilsService} apply the given
     * filters and sorting to it in place.
     *
     * @param fmi                    model whose predicted data is loaded
     * @param explorationSampling    sampling settings used to load the data
     * @param filters                filters handed to {@code applyFiltersAndSorts}
     * @param derivedColumnsComputer computer run on the freshly read table
     * @param sorting                sort specification handed to {@code applyFiltersAndSorts}
     * @return the same report object that was loaded, after filtering/sorting was applied to it
     * @throws Exception propagated from loading or from {@code applyFiltersAndSorts}
     */
    public synchronized MemScriptRunner.TableWithReport getUncachedFiltered_NT(FullModelId fmi, SerializedShakerScript.ShakerExplorationSampleSettings explorationSampling, FilterRequest filters, DerivedColumnsComputer derivedColumnsComputer, List<SerializedShakerScript.TableSorting> sorting) throws Exception {
        MemScriptRunner.TableWithReport loaded =
                this.getUncachedUnfiltered_NOTRANSACTION(fmi, explorationSampling, derivedColumnsComputer);
        this.modelDataUtilsService.applyFiltersAndSorts(loaded, filters, sorting);
        return loaded;
    }

    /**
     * Creates the default derived-columns computer.
     *
     * <p>The returned computer dispatches on {@code task.taskType}: prediction tasks go through
     * {@code computePredictionDerivedColumns}; clustering tasks first load {@code user_meta.json}
     * and then go through {@code computeClusteringDerivedColumns}. Other task types are left untouched.
     *
     * @return a new computer bound to this service instance
     */
    public DerivedColumnsComputer getDefaultDerivedColumnsComputer() {
        return (table, task, fmi) -> {
            switch (task.taskType) {
                case PREDICTION:
                    this.computePredictionDerivedColumns(table, (PredictionMLTask) task, fmi);
                    break;
                case CLUSTERING:
                    // The user meta is read before the task is cast, as in the original ordering.
                    ModelUserMeta userMeta = fmi.parseModelFile("user_meta.json", ModelUserMeta.class);
                    this.computeClusteringDerivedColumns(table, userMeta, (ClusteringMLTask) task);
                    break;
                default:
                    break;
            }
        };
    }

    /*
     * Unable to fully structure code
     */
    public synchronized MemScriptRunner.TableWithReport getUncachedUnfiltered_NOTRANSACTION(FullModelId fmi, SerializedShakerScript.ShakerExplorationSampleSettings sampling, DerivedColumnsComputer computeDerivedPredictedColumns) throws Exception {
        state = FutureProgress.pushAutoCloseableState((String)"Computing", (double)4.0, (FutureProgressState.StateUnit)FutureProgressState.StateUnit.NONE);
        try {
            sampleId = PredictedDataService.getSampleId(fmi, sampling);
            PredictedDataService.logger.info((Object)("PredictedData disk sample id is " + sampleId));
            task = fmi.getHeadMLTask();
            splitDesc = fmi.getSplitDesc();
            before = System.currentTimeMillis();
            ret = new MemScriptRunner.TableWithReport();
            warningsContext = new WarningsContext();
            sampleToRead = SampleBuilder.getPredictedDataSampleMeta(fmi, sampleId);
            FutureProgress.updateState((double)1.0);
            if (sampleToRead != null) {
                PredictedDataService.logger.info((Object)"Disk sample cache hit");
            } else {
                PredictedDataService.logger.info((Object)"Disk sample cache miss");
                SampleBuilder.clearPredictedDataSamples(fmi);
                buildingState = FutureProgress.pushAutoCloseableState((String)"Building sample");
                try {
                    switch (2.$SwitchMap$com$dataiku$dip$analysis$model$MLTask$MLTaskType[task.taskType.ordinal()]) {
                        case 2: {
                            SampleBuilder.buildPredictedSampleForClustering(fmi, sampleId, splitDesc);
                            ** break;
lbl23:
                            // 1 sources

                            break;
                        }
                        case 1: {
                            if (task instanceof PredictionMLTask.TimeseriesForecastingMLTask) {
                                SampleBuilder.buildPredictedSampleForTimeseriesForecast(fmi, sampleId);
                                ** break;
lbl28:
                                // 1 sources

                            } else {
                                SampleBuilder.buildPredictedSampleForPrediction(fmi, sampleId, splitDesc);
                            }
                            break;
                        }
                        ** default:
lbl32:
                        // 1 sources

                        break;
                    }
                }
                finally {
                    if (buildingState != null) {
                        buildingState.close();
                    }
                }
                sampleToRead = SampleBuilder.getPredictedDataSampleMeta(fmi, sampleId);
            }
            if (!PredictedDataService.$assertionsDisabled && sampleToRead == null) {
                throw new AssertionError();
            }
            FutureProgress.updateState((double)2.0);
            PredictedDataService.logger.info((Object)("Opening sample " + sampleId));
            ret.usedSample = sampleToRead;
            readingState = FutureProgress.pushAutoCloseableState((String)"Reading sample");
            try {
                ret.table = SampleBuilder.readPredictedSample(fmi, sampleToRead.id);
                computeDerivedPredictedColumns.compute(ret.table, task, fmi);
                ret.initialRows = ret.table.nrows();
                ret.initialCols = ret.table.ncols();
                PredictedDataService.logger.info((Object)("Reading sample done, read " + ret.table.nrows() + " rows"));
            }
            finally {
                if (readingState != null) {
                    readingState.close();
                }
            }
            FutureProgress.updateState((double)3.0);
            ret.table.compact();
            FutureProgress.updateState((double)4.0);
            PredictedDataService.logger.info((Object)("Serialized warnings " + String.valueOf(warningsContext) + " -> " + JSON.json((Object)warningsContext.getOutput())));
            ret.warnings = warningsContext.getOutput();
            inferringState = FutureProgress.pushAutoCloseableState((String)"Detecting types");
            try {
                infererCacheKey = fmi.toString() + JSON.json((Object)sampling);
                this.inferer.processFullAuto(infererCacheKey, ret.table);
            }
            finally {
                if (inferringState != null) {
                    inferringState.close();
                }
            }
            FutureProgress.updateState((double)5.0);
            inferDone = System.currentTimeMillis();
            PredictedDataService.logger.info((Object)("PredictedDataService done time =  " + (inferDone - before)));
            var15_23 = ret;
            return var15_23;
        }
        finally {
            if (state != null) {
                state.close();
            }
        }
    }

    /**
     * Computes the detailed analysis of one column of the model's unfiltered predicted data.
     *
     * @param fmi                 model whose predicted data is loaded
     * @param ss                  script; its {@code explorationSampling} selects the sample and the
     *                            script itself is forwarded to the analysis
     * @param column              name of the column to analyse
     * @param alphanumMaxResults  forwarded unchanged to {@code modelDataUtilsService}
     * @return the analysis produced by {@code modelDataUtilsService.getDetailedColumnAnalysis}
     * @throws Exception propagated from loading the data or from the analysis
     */
    public DataService.ColumnDetailedAnalysis getDetailedColumnAnalysis(FullModelId fmi, SerializedShakerScript ss, String column, int alphanumMaxResults) throws Exception {
        MemScriptRunner.TableWithReport loaded = this.getUncachedUnfiltered_NOTRANSACTION(fmi, ss.explorationSampling);
        MemTable predictedData = loaded.table;
        return this.modelDataUtilsService.getDetailedColumnAnalysis(predictedData, ss, column, alphanumMaxResults);
    }

    /**
     * Recomputes the derived prediction columns of {@code table} for classification tasks, then
     * moves the {@code fold_id} column (when present and not deleted) to the end of the table.
     *
     * <p>Binary classification and multiclass are handled; other prediction types only get the
     * {@code fold_id} move.
     *
     * @param table           in-memory predicted data, modified in place
     * @param trainedWithTask task supplying the prediction type, target variable and (for analysis
     *                        models) the preprocessing params
     * @param fmi             model whose files, type and head task are read
     * @throws IOException declared by the model-file accessors used here
     */
    private void computePredictionDerivedColumns(MemTable table, PredictionMLTask trainedWithTask, FullModelId fmi) throws IOException {
        switch (trainedWithTask.predictionType) {
            case BINARY_CLASSIFICATION:
                this.refreshBinaryClassificationColumns(table, trainedWithTask, fmi);
                break;
            case MULTICLASS:
                this.refreshMulticlassColumns(table, trainedWithTask, fmi);
                break;
            default:
                break;
        }
        if (table.hasNonDeletedColumn("fold_id")) {
            table.moveAtEnd("fold_id");
        }
    }

    /**
     * Re-derives prediction, correctness and cost-matrix gain per row from the positive-class
     * probability and the active threshold, applying overrides when the model defines some.
     * Does nothing beyond logging when {@code iperf.json} is not {@code probaAware}.
     */
    private void refreshBinaryClassificationColumns(MemTable table, PredictionMLTask trainedWithTask, FullModelId fmi) throws IOException {
        logger.info("Update prediction !");
        PredictionMLTask.ClassicalPredictionMLTask headTask = (PredictionMLTask.ClassicalPredictionMLTask) fmi.getHeadMLTask();
        if (!fmi.parseModelFile("iperf.json", ClassificationModelIntrinsicPerf.class).probaAware) {
            return;
        }
        FormulaOverridesOutcomeComputer.RowFormulaOutcomeComputer overrides = this.initOutcomeComputer(fmi, table);
        ModelUserMeta userMeta = fmi.parseModelFile("user_meta.json", ModelUserMeta.class);

        // Source labels of the mapped classes 1 (positive) and 0 (negative).
        final String positiveValue;
        final String negativeValue;
        if (FullModelId.Type.ANALYSIS.equals(fmi.getType())) {
            PredictionPreprocessingParams params = trainedWithTask.getPreprocessingParams();
            positiveValue = params.getSourceValueForMapped(1);
            negativeValue = params.getSourceValueForMapped(0);
        } else if (FullModelId.Type.SAVED.equals(fmi.getType())) {
            ResolvedPredictionPreprocessingParams resolved = (ResolvedPredictionPreprocessingParams) fmi.getResolvedPreprocessingParams();
            positiveValue = resolved.getSourceValueForMapped(1);
            negativeValue = resolved.getSourceValueForMapped(0);
        } else {
            throw new IllegalArgumentException("Invalid FMI type: " + fmi.getType().name());
        }

        // Column lookups are kept in the original order.
        MemColumn positiveProbaCol = table.column("proba_" + positiveValue);
        MemColumn negativeProbaCol = table.column("proba_" + negativeValue);
        MemColumn targetCol = table.column(trainedWithTask.targetVariable);
        MemColumn predictionCol = table.column("prediction");
        MemColumn correctCol = table.column("prediction_correct");
        MemColumn cmgCol = table.column("costmatrix_gain");
        final boolean withOverrides = overrides != null;
        MemColumn overrideCol = null;
        MemColumn uncertaintyCol = null;
        if (withOverrides) {
            overrideCol = table.column("override");
            uncertaintyCol = table.column("prediction_uncertainty");
        }

        for (MemRow row : table.rows) {
            double positiveProba = row.getAsDoubleOrNaN(positiveProbaCol);
            if (Double.isNaN(positiveProba)) {
                // Rows without a usable positive probability are left untouched.
                continue;
            }
            boolean predictedPositive = positiveProba > userMeta.activeClassifierThreshold;
            row.put((Column) predictionCol, predictedPositive ? positiveValue : negativeValue);

            if (withOverrides) {
                row.put((Column) uncertaintyCol, 1.0 - Math.max(positiveProba, 1.0 - positiveProba));
                OverridesOutcomeComputer.OutcomeCandidate candidate = overrides.getOutcomeCandidate(row);
                double[] probas = new double[]{row.getAsDoubleOrNaN(negativeProbaCol), row.getAsDoubleOrNaN(positiveProbaCol)};
                ClassificationResult rawResult = new ClassificationResult(row.get(predictionCol), probas, null);
                if (candidate == null) {
                    row.put((Column) overrideCol, OverrideInfo.noMatch().toJson());
                } else {
                    MLOverridesParamsBase.MLOverride.Outcome outcome = candidate.outcome;
                    String[] classes = new String[]{negativeValue, positiveValue};
                    ClassificationResult.RawResult raw = new ClassificationResult.RawResult(rawResult, classes);
                    if (outcome.type == MLOverridesParamsBase.MLOverride.Outcome.Type.DECLINED) {
                        // Declined rows lose prediction and probabilities and skip the
                        // correctness / cost-matrix update below.
                        OverrideInfo declined = OverrideInfo.declined(candidate.overrideName, (Result.RawResult) raw);
                        row.put((Column) predictionCol, null);
                        row.put((Column) positiveProbaCol, null);
                        row.put((Column) negativeProbaCol, null);
                        row.put((Column) overrideCol, declined.toJson());
                        continue;
                    }
                    Boolean predictionChanged = !Objects.equals(outcome.category, rawResult.getPrediction());
                    OverrideInfo applied = new OverrideInfo(candidate.overrideName, predictionChanged, (Result.RawResult) raw);
                    row.put((Column) overrideCol, applied.toJson());
                    row.put((Column) predictionCol, outcome.category);
                    // An enforced category gets probability 1 for its class and 0 for the other.
                    if (positiveValue.equals(outcome.category)) {
                        row.put((Column) positiveProbaCol, 1.0);
                        row.put((Column) negativeProbaCol, 0.0);
                        predictedPositive = true;
                    } else {
                        row.put((Column) positiveProbaCol, 0.0);
                        row.put((Column) negativeProbaCol, 1.0);
                        predictedPositive = false;
                    }
                }
            }

            String actual = row.get(targetCol);
            // A blank target is never counted as correct.
            boolean correct = !StringUtils.isBlank(actual)
                    && actual.equals(predictedPositive ? positiveValue : negativeValue);
            row.put((Column) correctCol, correct);
            row.put((Column) cmgCol, this.costMatrixGain(headTask, predictedPositive, correct));
        }

        if (withOverrides) {
            table.deleteColumn("prediction_uncertainty");
        }
    }

    /** Picks the cost-matrix weight (tp/fp/tn/fn gain) matching the predicted side and its correctness. */
    private double costMatrixGain(PredictionMLTask.ClassicalPredictionMLTask headTask, boolean predictedPositive, boolean correct) {
        if (predictedPositive) {
            if (correct) {
                return headTask.modeling.metrics.costMatrixWeights.tpGain;
            }
            return headTask.modeling.metrics.costMatrixWeights.fpGain;
        }
        if (correct) {
            return headTask.modeling.metrics.costMatrixWeights.tnGain;
        }
        return headTask.modeling.metrics.costMatrixWeights.fnGain;
    }

    /**
     * Fills the actual/predicted class-id columns from the target forward map and flags each row
     * as correct when both ids are equal (or both are unmapped).
     */
    private void refreshMulticlassColumns(MemTable table, PredictionMLTask trainedWithTask, FullModelId fmi) throws IOException {
        logger.info("Update prediction !");
        final Map<String, Integer> classIds;
        if (FullModelId.Type.ANALYSIS.equals(fmi.getType())) {
            PredictionPreprocessingParams params = trainedWithTask.getPreprocessingParams();
            classIds = params.getTargetForwardMap();
        } else if (FullModelId.Type.SAVED.equals(fmi.getType())) {
            ResolvedPredictionPreprocessingParams resolved = (ResolvedPredictionPreprocessingParams) fmi.getResolvedPreprocessingParams();
            classIds = resolved.getTargetForwardMap();
        } else {
            throw new IllegalArgumentException("Invalid FMI type: " + fmi.getType().name());
        }

        MemColumn targetCol = table.column(trainedWithTask.targetVariable);
        MemColumn predictionCol = table.column("prediction");
        MemColumn actualIdCol = table.column("actual_class_id");
        MemColumn predictedIdCol = table.column("predicted_class_id");
        MemColumn correctCol = table.column("prediction_correct");

        for (MemRow row : table.rows) {
            String actualLabel = row.get(targetCol);
            String predictedLabel = row.get(predictionCol);
            Integer actualId = classIds.get(actualLabel);
            if (actualId != null) {
                row.put((Column) actualIdCol, actualId);
            }
            Integer predictedId = classIds.get(predictedLabel);
            if (predictedId != null) {
                row.put((Column) predictedIdCol, predictedId);
            }
            if (actualId != null) {
                row.put((Column) correctCol, actualId.equals(predictedId));
            } else {
                row.put((Column) correctCol, predictedId == null);
            }
        }
    }

    /**
     * Creates and initialises the row-level overrides outcome computer for a model.
     *
     * @param fmi   model whose overrides params are loaded
     * @param table table handed to the outcome computer's constructor
     * @return an initialised computer, or {@code null} when the model has no overrides or when an
     *         {@link IOException} occurred (in which case a warning is logged and overrides are ignored)
     */
    private FormulaOverridesOutcomeComputer.RowFormulaOutcomeComputer initOutcomeComputer(FullModelId fmi, MemTable table) {
        try {
            MLOverridesParams params = fmi.getOverridesParams();
            if (!params.hasOverrides()) {
                return null;
            }
            FormulaOverridesOutcomeComputer.RowFormulaOutcomeComputer computer =
                    new FormulaOverridesOutcomeComputer.RowFormulaOutcomeComputer(params, table);
            computer.init();
            return computer;
        } catch (IOException ioe) {
            // Best effort: a failure to load overrides must not prevent showing predicted data.
            logger.warn("Could not load overrides, ignoring them", ioe);
            return null;
        }
    }

    /**
     * Replaces raw cluster labels by their user-facing names and fills the {@code cluster_id} column.
     *
     * <p>Nothing is rewritten when the task's isolation forest is enabled. Rows whose label is blank
     * or has no entry in {@code mum.clusterMetas} are skipped. For the others the label becomes the
     * cluster meta's name; the id is {@code "-1"} for {@code cluster_outliers}, and otherwise the
     * raw label with every {@code "cluster_"} occurrence removed when it starts with that prefix.
     *
     * @param table   in-memory predicted data, modified in place
     * @param mum     user meta providing {@code clusterMetas} keyed by raw label
     * @param cmlTask clustering task; only its isolation-forest flag is read
     */
    private void computeClusteringDerivedColumns(MemTable table, ModelUserMeta mum, ClusteringMLTask cmlTask) {
        // Looked up before the isolation-forest check, as in the original ordering.
        MemColumn labelCol = table.column("cluster_labels");
        if (cmlTask.modeling.isolation_forest.enabled) {
            return;
        }
        MemColumn idCol = table.column("cluster_id");
        for (MemRow row : table.rows) {
            String rawLabel = row.get(labelCol);
            if (StringUtils.isBlank(rawLabel)) {
                continue;
            }
            ModelUserMeta.ClusterMeta clusterMeta = mum.clusterMetas.get(rawLabel);
            if (clusterMeta == null) {
                continue;
            }
            row.put((Column) labelCol, clusterMeta.name);
            if (rawLabel.equals("cluster_outliers")) {
                row.put((Column) idCol, "-1");
            } else if (rawLabel.startsWith("cluster_")) {
                row.put((Column) idCol, rawLabel.replace("cluster_", ""));
            }
        }
    }

    /**
     * Callback run on a freshly read predicted-data table so that derived columns can be
     * (re)computed before the table is returned to callers.
     */
    interface DerivedColumnsComputer {
        /**
         * Computes derived columns on {@code table} in place.
         *
         * @param table in-memory predicted data
         * @param task  head ML task of the model
         * @param fmi   model the data belongs to
         * @throws IOException when an implementation fails to read what it needs
         */
        void compute(MemTable table, MLTask task, FullModelId fmi) throws IOException;
    }
}

