/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.guess;

import com.dataiku.dip.analysis.ml.prediction.guess.PredictionGuesser;
import com.dataiku.dip.analysis.ml.shared.FeatureGuessUtils;
import com.dataiku.dip.analysis.model.GuessStatus;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.analysis.model.prediction.DeepHubMetricParams;
import com.dataiku.dip.analysis.model.prediction.DeepHubPreTrainModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.PredictionPreprocessingParams;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemRow;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.pivot.backend.model.FilterFacet;
import com.dataiku.dip.shaker.facet.BoundingBoxFaceter;
import com.dataiku.dip.shaker.types.JSONArrayMeaning;
import com.dataiku.dip.utils.DKULogger;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nullable;

/**
 * Guesses settings for DeepHub (image) prediction tasks: split parameters, modeling/metric defaults,
 * the image path column, per-feature roles and the target remapping.
 *
 * <p>The prediction type is fixed for these tasks ({@link #canChangePredictionType()} returns false).
 */
public class DeepHubPredictionGuesser extends PredictionGuesser<PredictionMLTask.DeepHubPredictionMLTask> {
    /**
     * Maximum percentage of rows whose value does not end with a known image extension for the best
     * candidate column to still be auto-selected as the image path column.
     */
    public static final int FAILED_PATH_THRESHOLD_PERCENT = 50;
    /**
     * Percentage of object-detection target rows that could not be parsed above which a message is
     * reported (or, when {@code throwException} is set during remapping, the guess fails).
     */
    private static final long TARGET_PARSING_FAILURE_THRESHOLD_PERCENT = 50L;
    private static final DKULogger logger = DKULogger.getLogger("dku.analysis");

    /** Whether a column that looks like image paths was found (or restored from a previous status). */
    private boolean foundImagePathsColumn;
    /** Rounded percentage of target rows that the bounding-box faceter skipped (object detection only). */
    private long rowsParsingFailuresPercent;

    public DeepHubPredictionGuesser(PredictionMLTask.DeepHubPredictionMLTask task, MemTable table) {
        super(task, table);
    }

    /**
     * Returns the inherited task typed as a DeepHub task. The cast is kept because the declared type of the
     * inherited {@code task} field is not visible from this file.
     */
    private PredictionMLTask.DeepHubPredictionMLTask deepHubTask() {
        return (PredictionMLTask.DeepHubPredictionMLTask) this.task;
    }

    /** Builds the exception thrown when a switch meets a prediction type this guesser does not handle. */
    private static IllegalArgumentException unsupportedPredictionType(PredictionMLTask.PredictionType type) {
        return new IllegalArgumentException("Unsupported prediction type: " + type);
    }

    @Override
    public boolean canChangePredictionType() {
        return false;
    }

    /**
     * Applies a target change without a full re-guess: clears the image path column if it is now the target,
     * then recomputes the target remapping (throwing on failure).
     */
    @Override
    public void changeTargetNoReguess(String previousTarget, @Nullable GuessStatus previousGuessStatus) {
        super.changeTargetNoReguess(previousTarget, previousGuessStatus);
        PredictionMLTask.DeepHubPredictionMLTask mlTask = this.deepHubTask();
        // The target column cannot also serve as the image path column.
        if (Objects.equals(mlTask.pathColumn, mlTask.targetVariable)) {
            mlTask.pathColumn = null;
        }
        this.setTargetRemapping(true);
    }

    /** Intentionally a no-op: the prediction type of a DeepHub task is fixed and is never guessed. */
    @Override
    protected void guessPredictionType(MemColumn targetColumn) {
    }

    /**
     * Runs the generic target checks, then, for object detection, requires the target column's selected
     * meaning to be exactly {@link JSONArrayMeaning}.
     */
    @Override
    protected void checkTargetColumn(boolean throwException) {
        super.checkTargetColumn(throwException);
        PredictionMLTask.DeepHubPredictionMLTask mlTask = this.deepHubTask();
        MemColumn targetCol = this.table.column(mlTask.targetVariable);
        if (targetCol == null || targetCol.selectedType == null) {
            return;
        }
        // Exact class comparison (not instanceof): subclasses of JSONArrayMeaning are not accepted.
        boolean isTargetJSONArray = targetCol.selectedType.type.getClass() == JSONArrayMeaning.class;
        if (mlTask.predictionType == PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_OBJECT_DETECTION && !isTargetJSONArray) {
            this.throwOrAddMessage(String.format("Column '%s' must be of type JSONArray", mlTask.targetVariable), throwException);
        }
    }

    /**
     * Adds DeepHub-specific messages (missing image path column, unparseable object-detection target), then
     * delegates to the parent and records this guesser's state on the returned status.
     */
    @Override
    public GuessStatus checkStatus() {
        if (!this.foundImagePathsColumn) {
            this.messages.add("Could not find a column that looks like image paths. Please select it manually, or modify your data accordingly");
        }
        PredictionMLTask.DeepHubPredictionMLTask mlTask = this.deepHubTask();
        if (mlTask.predictionType == PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_OBJECT_DETECTION
                && this.rowsParsingFailuresPercent > TARGET_PARSING_FAILURE_THRESHOLD_PERCENT) {
            this.messages.add(this.targetParsingFailureMessage(this.table, mlTask.targetVariable));
        }
        GuessStatus guessStatus = super.checkStatus();
        guessStatus.foundImagePathsColumn = this.foundImagePathsColumn;
        guessStatus.rowsParsingFailuresPercent = this.rowsParsingFailuresPercent;
        return guessStatus;
    }

    /**
     * Restores state from a previous guess status, if any. A missing {@code foundImagePathsColumn} is treated
     * as {@code true}; a missing failure percentage as {@code 0}.
     */
    @Override
    public void retrievePreviousGuessStatusBooleans(@Nullable GuessStatus previousGuessStatus) {
        if (previousGuessStatus == null) {
            return;
        }
        this.foundImagePathsColumn = !Boolean.FALSE.equals(previousGuessStatus.foundImagePathsColumn);
        this.rowsParsingFailuresPercent = previousGuessStatus.rowsParsingFailuresPercent == null ? 0L : previousGuessStatus.rowsParsingFailuresPercent;
    }

    /**
     * Resets and guesses every setting for the task's (already fixed) prediction type: split params, modeling
     * and metric defaults, preprocessing, image path column, per-feature roles and, if still empty, the target
     * remapping.
     *
     * @throws IllegalArgumentException if the prediction type is not a supported DeepHub type
     */
    @Override
    protected void guessAllSettingsWithFixedPredictionType(boolean throwException) {
        PredictionMLTask.DeepHubPredictionMLTask mlTask = this.deepHubTask();
        mlTask.splitParams = SplitParams.buildStd();
        this.checkAllFixableSettings(throwException);
        assert mlTask.predictionType != null;
        logger.info("Guessing deephub algorithms");
        mlTask.modeling = DeepHubPreTrainModelingParams.build(mlTask.predictionType);
        mlTask.modeling.modelOptimizationSplitParams = new DeepHubPreTrainModelingParams.DeepHubModelOptimizationSplitParams();
        DeepHubMetricParams metrics = new DeepHubMetricParams();
        mlTask.modeling.metrics = metrics;
        switch (mlTask.predictionType) {
            case DEEP_HUB_IMAGE_OBJECT_DETECTION:
                metrics.evaluationMetric = DeepHubMetricParams.EvaluationMetric.AVERAGE_PRECISION_IOU50;
                metrics.confidenceScoreThresholdOptimMetric = DeepHubMetricParams.ConfidenceScoreThresholdOptimizationMetric.F1;
                break;
            case DEEP_HUB_IMAGE_CLASSIFICATION:
                metrics.evaluationMetric = DeepHubMetricParams.EvaluationMetric.ROC_AUC;
                break;
            default:
                throw unsupportedPredictionType(mlTask.predictionType);
        }
        mlTask.preprocessing = new PredictionPreprocessingParams();
        // The path column must be chosen before per-feature guessing, which assigns it the INPUT role.
        this.guessImagePath();
        mlTask.preprocessing.per_feature = new HashMap<>();
        for (String name : this.table.columns.keySet()) {
            this.guessFeature(name);
        }
        if (mlTask.preprocessing.target_remapping.isEmpty()) {
            this.setTargetRemapping(throwException);
        }
    }

    /**
     * Picks, among all non-target columns, the one with the most values ending with a known image extension.
     * Sets {@link #foundImagePathsColumn}, and sets the task's path column only when the best candidate's
     * failure rate is within {@link #FAILED_PATH_THRESHOLD_PERCENT}.
     *
     * @throws IllegalArgumentException if there is no non-target column to examine
     */
    private void guessImagePath() {
        PredictionMLTask.DeepHubPredictionMLTask mlTask = this.deepHubTask();
        PathStats best = null;
        for (MemColumn column : this.table.columns.values()) {
            if (column.getName().equals(mlTask.targetVariable)) {
                continue;
            }
            logger.info("Guessing feature image path on " + column.getName());
            PathStats candidate = new PathStats(this.table, column, PathStats.KNOWN_IMAGES_EXTENSIONS);
            // Strictly greater: on ties the first column in iteration order is kept.
            if (best == null || candidate.pathCount > best.pathCount) {
                best = candidate;
            }
        }
        if (best == null) {
            throw new IllegalArgumentException("Couldn't guess image paths");
        }
        double nrows = this.table.nrows();
        // With no rows the failure ratio is undefined; report "not found" rather than relying on NaN comparisons.
        this.foundImagePathsColumn = nrows > 0 && 100.0 * best.failedCount / nrows <= FAILED_PATH_THRESHOLD_PERCENT;
        if (this.foundImagePathsColumn) {
            mlTask.pathColumn = best.column.getName();
        }
    }

    /**
     * Recomputes the task's target remapping according to its prediction type.
     *
     * @throws IllegalArgumentException if the prediction type is not a supported DeepHub type
     */
    private void setTargetRemapping(boolean throwException) {
        PredictionMLTask.DeepHubPredictionMLTask mlTask = this.deepHubTask();
        switch (mlTask.predictionType) {
            case DEEP_HUB_IMAGE_OBJECT_DETECTION:
                mlTask.preprocessing.target_remapping = this.guessObjectDetectionRemapping(this.table, mlTask.targetVariable, throwException);
                break;
            case DEEP_HUB_IMAGE_CLASSIFICATION:
                mlTask.preprocessing.target_remapping = this.guessClassificationTargetRemapping(this.table, mlTask.targetVariable, throwException);
                break;
            default:
                throw unsupportedPredictionType(mlTask.predictionType);
        }
    }

    /**
     * Builds the user-facing message for an object-detection target that could not be parsed. Falls back to
     * the target variable name when the column is absent from the table, instead of failing with an NPE.
     */
    private String targetParsingFailureMessage(MemTable table, String targetVariable) {
        MemColumn targetColumn = table.column(targetVariable);
        String targetName = targetColumn != null ? targetColumn.getName() : targetVariable;
        return this.rowsParsingFailuresPercent + "% rows from target column ('" + targetName + "') could not be parsed as an object detection target, did you select the right column ?";
    }

    /**
     * Computes the object-detection remapping: one entry per category reported by a {@link BoundingBoxFaceter}
     * on the target column, indexed in facet order. Also updates {@link #rowsParsingFailuresPercent} from the
     * faceter's skipped-row count.
     *
     * @throws IllegalArgumentException if {@code throwException} is set and too many rows failed to parse
     */
    private List<PredictionPreprocessingParams.MappingValue> guessObjectDetectionRemapping(MemTable table, String targetVariable, boolean throwException) {
        BoundingBoxFaceter faceter = new BoundingBoxFaceter(targetVariable);
        FilterFacet targetCategories = faceter.observeAndCompute(table);
        double nrows = table.nrows();
        // An empty table yields 0% rather than Math.round(NaN).
        this.rowsParsingFailuresPercent = nrows > 0 ? Math.round(100.0 * faceter.getNumSkippedRows() / nrows) : 0L;
        if (throwException && this.rowsParsingFailuresPercent > TARGET_PARSING_FAILURE_THRESHOLD_PERCENT) {
            throw new IllegalArgumentException(this.targetParsingFailureMessage(table, targetVariable));
        }
        List<PredictionPreprocessingParams.MappingValue> targetRemapping = new ArrayList<>(targetCategories.values.size());
        for (int i = 0; i < targetCategories.values.size(); ++i) {
            FilterFacet.Val targetCategory = targetCategories.values.get(i);
            targetRemapping.add(new PredictionPreprocessingParams.MappingValue(targetCategory.label, i, (int) targetCategory.count));
        }
        return targetRemapping;
    }

    /**
     * Guesses preprocessing for one column, then assigns its role: INPUT for the image path column, TARGET for
     * the target column, PROFILING for everything else (the path column takes precedence).
     */
    @Override
    public FeaturePreprocessingParams guessSingleFeature(MemColumn column) {
        FeaturePreprocessingParams fp = FeatureGuessUtils.guessSingleFeature(this.table, column, this.task);
        PredictionMLTask.DeepHubPredictionMLTask mlTask = this.deepHubTask();
        MemColumn targetColumn = this.table.column(mlTask.targetVariable);
        if (mlTask.pathColumn != null && column == this.table.column(mlTask.pathColumn)) {
            fp.role = FeaturePreprocessingParams.Role.INPUT;
        } else if (column == targetColumn) {
            fp.role = FeaturePreprocessingParams.Role.TARGET;
        } else {
            fp.role = FeaturePreprocessingParams.Role.PROFILING;
        }
        return fp;
    }

    /** Counts, for one column, how many row values end with a supported extension and how many do not. */
    static class PathStats {
        private static final Set<String> KNOWN_IMAGES_EXTENSIONS = new HashSet<String>();

        static {
            KNOWN_IMAGES_EXTENSIONS.add("jpg");
            KNOWN_IMAGES_EXTENSIONS.add("jpeg");
            KNOWN_IMAGES_EXTENSIONS.add("png");
        }

        public Column column;
        /** Rows whose value is null or does not end with a supported extension. */
        public int failedCount;
        /** Rows whose value ends with a supported extension. */
        public int pathCount;

        PathStats(MemTable table, MemColumn column, Set<String> supportedExtensions) {
            this.column = column;
            for (int i = 0; i < table.nrows(); ++i) {
                MemRow row = table.rows.get(i);
                String value = row.get(column);
                if (value != null && isSupportedExtension(value, supportedExtensions)) {
                    ++this.pathCount;
                } else {
                    ++this.failedCount;
                }
            }
        }

        /**
         * Case-insensitive suffix match against the given extensions.
         * NOTE(review): no "." separator is required, so e.g. "xjpg" also matches; kept as-is to preserve the
         * existing guessing heuristic.
         */
        private static boolean isSupportedExtension(String s, Set<String> supportedExtensions) {
            // Locale.ROOT keeps the lowercasing independent of the JVM default locale.
            String lower = s.toLowerCase(Locale.ROOT);
            for (String ext : supportedExtensions) {
                if (lower.endsWith(ext)) {
                    return true;
                }
            }
            return false;
        }
    }
}

