/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.shared;

import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.preprocessing.CatFeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.NumFeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.TextFeaturePreprocessingParams;
import com.dataiku.dip.analysis.stats.CategoryVariableBasicStats;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.shaker.analysis.NumericalVariableAnalysis;
import com.dataiku.dip.shaker.analysis.NumericalVariableAnalyzer;
import com.dataiku.dip.shaker.types.Date;
import com.dataiku.dip.shaker.types.DateOnly;
import com.dataiku.dip.shaker.types.DatetimeNoTz;
import com.dataiku.dip.shaker.types.FreeText;
import com.dataiku.scoring.pipelines.DatetimeCyclicalEncoder;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
import org.apache.log4j.Logger;

/**
 * Heuristics that guess a default {@link FeaturePreprocessingParams} for a single column of a
 * {@link MemTable}: numerical, text or categorical handling, including automatic rejection of
 * constant, mostly-missing and identifier-like columns.
 */
public class FeatureGuessUtils {
    /**
     * Share of the row count that a column's distinct-value count must reach for an id-named
     * column to be treated as an identifier (used by {@link #isID} and the categorical sparsity
     * checks).
     */
    private static final double ID_CARDINALITY_RATIO = 0.95;

    /**
     * Threshold at or above which a column is rejected as mostly missing. For categorical columns
     * it is compared to the missing-row share; for numerical columns to
     * {@code analysis.missing + analysis.bad}, presumably also fractions of rows (TODO confirm).
     */
    private static final double MISSING_REJECT_RATIO = 0.95;

    /** Clustering tasks reject categorical columns with more distinct values than this. */
    private static final int CLUSTERING_MAX_CARDINALITY = 100;

    /** Cap on dummified categories applied by {@link StandardCategoricalGuesser}. */
    private static final int DEFAULT_MAX_NB_CATEGORIES = 100;

    /** Python preprocessing code injected as custom text handling for the Keras backend. */
    private static final String KERAS_TEXT_PREPROCESSING_CODE = "from dataiku.doctor.deep_learning.preprocessing import TokenizerProcessor\n\n# Defines a processor that tokenizes a text. It computes a vocabulary on all the corpus.\n# Then, each text is converted to a vector representing the sequence of words, where each \n# element represents the index of the corresponding word in the vocabulary. The result is \n# padded with 0 up to the `max_len` in order for all the vectors to have the same length.\n\n#   num_words  - maximum number of words in the vocabulary\n#   max_len    - length of each sequence. If the text is longer,\n#                it will be truncated, and if it is shorter, it will be padded\n#                with 0.\nprocessor = TokenizerProcessor(num_words=10000, max_len=32)";

    /** Returns whether the column, its selected type and that type's descriptor are all non-null. */
    private static boolean hasSelectedType(MemColumn column) {
        return column != null && column.selectedType != null && column.selectedType.type != null;
    }

    /**
     * Returns whether the column name looks like an identifier: it starts or ends with "id",
     * ignoring case. This is a loose heuristic; names such as "paid" or "valid" also match.
     */
    private static boolean hasIdLikeName(String columnName) {
        String lower = columnName.toLowerCase(Locale.ENGLISH);
        return lower.startsWith("id") || lower.endsWith("id");
    }

    /**
     * Returns whether the column's selected type is date-like ({@link Date}, {@link DateOnly} or
     * {@link DatetimeNoTz}, subclasses included).
     *
     * @param column the column to inspect; may be {@code null}
     * @return {@code false} when the column or its selected type is missing
     */
    public static boolean isTemporal(MemColumn column) {
        return hasSelectedType(column)
                && (column.selectedType.type instanceof Date
                        || column.selectedType.type instanceof DateOnly
                        || column.selectedType.type instanceof DatetimeNoTz);
    }

    /**
     * Returns whether the column is numerical: its selected type is a double type and no sampled
     * value failed to parse as that type.
     *
     * @param column the column to inspect; may be {@code null}
     * @return {@code false} when the column or its selected type is missing
     */
    public static boolean isNumerical(MemColumn column) {
        return hasSelectedType(column)
                && column.selectedType.nbFails == 0
                && column.selectedType.type.isDouble();
    }

    /**
     * Returns whether the column's selected type is exactly {@link FreeText}. Subclasses do not
     * match, because the class is compared by identity.
     *
     * @param column the column to inspect; may be {@code null}
     * @return {@code false} when the column or its selected type is missing
     */
    public static boolean isText(MemColumn column) {
        return hasSelectedType(column) && column.selectedType.type.getClass() == FreeText.class;
    }

    /**
     * Guesses preprocessing for one column, dispatching on its detected kind. Numerical columns
     * are checked first, then text; everything else is handled as categorical.
     *
     * @param table the sampled data containing the column
     * @param column the column to guess for
     * @param task the ML task whose backend and task type influence the guess
     * @return the guessed preprocessing parameters, possibly with a REJECT role
     */
    public static FeaturePreprocessingParams guessSingleFeature(MemTable table, MemColumn column, MLTask task) {
        boolean isNumerical = FeatureGuessUtils.isNumerical(column);
        boolean isText = FeatureGuessUtils.isText(column);
        assert (!isNumerical || !isText);
        if (isNumerical) {
            return FeatureGuessUtils.standardNumFeatureGuess(table, column, task.backendType);
        }
        if (isText) {
            return FeatureGuessUtils.standardTextFeatureGuess(column, task.backendType);
        }
        return FeatureGuessUtils.standardNoSparseCatFeatureGuess(table, column, task);
    }

    /**
     * Builds basic categorical statistics for a column by scanning every row of the table.
     * Null values are counted in {@code nbMissing}. {@code cardinality} is the number of distinct
     * non-null values. No other statistic fields are filled.
     *
     * @param table the table to scan
     * @param col the column whose values are read
     * @return the computed statistics
     */
    public static CategoryVariableBasicStats buildStats(MemTable table, MemColumn col) {
        CategoryVariableBasicStats cvbs = new CategoryVariableBasicStats();
        Set<String> distinctValues = new HashSet<>();
        for (int i = 0; i < table.rows.size(); ++i) {
            String v = table.rows.get(i).get(col);
            if (v == null) {
                ++cvbs.nbMissing;
            } else {
                distinctValues.add(v);
            }
        }
        cvbs.cardinality = distinctValues.size();
        return cvbs;
    }

    /**
     * Returns whether a column looks like a row identifier. A column is an identifier when either:
     * its cardinality reaches {@value #ID_CARDINALITY_RATIO} of the row count and its name starts or
     * ends with "id"; or every row has a distinct value. Columns whose type meaning is
     * "DoubleMeaning" are never considered identifiers.
     *
     * @param column the column to inspect
     * @param nrows the number of rows in the sample
     * @param cardinality the number of distinct values in the column
     * @return {@code true} if the column should be treated as an identifier
     */
    public static boolean isID(MemColumn column, int nrows, long cardinality) {
        // Constant-first comparison so a null meaning id is simply "not DoubleMeaning".
        if ("DoubleMeaning".equals(column.selectedType.type.getMeaningId())) {
            return false;
        }
        String columnName = column.getName();
        if ((double) cardinality >= ID_CARDINALITY_RATIO * (double) nrows && hasIdLikeName(columnName)) {
            return true;
        }
        return cardinality == (long) nrows;
    }

    /**
     * Guesses preprocessing for a numerical column. Checks run in this order:
     * <ol>
     *   <li>no analysis result: reject as missing;</li>
     *   <li>a single distinct value: boolean handling if that value is 1.0 and some values are
     *       missing, otherwise reject as zero-variance;</li>
     *   <li>missing plus bad values at or above the threshold: reject as missing;</li>
     *   <li>identifier-like column: reject as identifier;</li>
     *   <li>temporal column with relevant cyclical periods on a Python backend: date handling;</li>
     *   <li>otherwise: standard numerical handling.</li>
     * </ol>
     *
     * @param table the sampled data
     * @param column the numerical column
     * @param backendType the task backend; date handling is only proposed for Python backends
     * @return the guessed numerical preprocessing
     */
    public static NumFeaturePreprocessingParams standardNumFeatureGuess(MemTable table, MemColumn column, MLTask.BackendType backendType) {
        // NOTE(review): meaning of 15 is defined by NumericalVariableAnalyzer - confirm.
        NumericalVariableAnalyzer analyzer = new NumericalVariableAnalyzer(15);
        analyzer.analyse(table, column.getName(), null);
        // NOTE(review): the flag is set after analyse() but before compute(); presumably compute()
        // is where it is consumed - confirm against NumericalVariableAnalyzer.
        analyzer.enableTimePeriodAnalysis = FeatureGuessUtils.isTemporal(column);
        analyzer.compute();
        NumericalVariableAnalysis analysis = analyzer.getOut();
        if (analysis == null) {
            return NumFeaturePreprocessingParams.buildReject(FeaturePreprocessingParams.FeatureHandlingReason.REJECT_MISSING);
        }
        if (analysis.cardinality == 1L) {
            if (analysis.mode == 1.0 && analysis.missing > 0.0) {
                return NumFeaturePreprocessingParams.buildBooleanHandling();
            }
            return NumFeaturePreprocessingParams.buildReject(FeaturePreprocessingParams.FeatureHandlingReason.REJECT_ZERO_VARIANCE);
        }
        if (analysis.missing + analysis.bad >= MISSING_REJECT_RATIO) {
            return NumFeaturePreprocessingParams.buildReject(FeaturePreprocessingParams.FeatureHandlingReason.REJECT_MISSING);
        }
        if (FeatureGuessUtils.isID(column, table.nrows(), analysis.cardinality)) {
            return NumFeaturePreprocessingParams.buildReject(FeaturePreprocessingParams.FeatureHandlingReason.REJECT_IDENTIFIER);
        }
        Set<DatetimeCyclicalEncoder.Period> relevantPeriods = analysis.relevantPeriods;
        // The flag is checked first so relevantPeriods is only read when period analysis ran.
        if (analyzer.enableTimePeriodAnalysis && !relevantPeriods.isEmpty() && backendType.isPythonBased()) {
            NumFeaturePreprocessingParams ret = NumFeaturePreprocessingParams.buildDateHandling();
            ret.datetime_cyclical_periods = relevantPeriods;
            return ret;
        }
        return NumFeaturePreprocessingParams.buildStandardHandling();
    }

    /**
     * Applies the rejection checks shared by categorical guessers. Starts from
     * {@code CatFeaturePreprocessingParams.buildStdInput()} and switches the role to REJECT for the
     * first matching condition:
     * <ol>
     *   <li>a single distinct value (zero variance);</li>
     *   <li>missing rows at or above the threshold;</li>
     *   <li>an id-like name with near-unique values;</li>
     *   <li>all values distinct.</li>
     * </ol>
     * Unlike {@link #isID}, the identifier checks here do not exempt "DoubleMeaning" columns.
     */
    private static CatFeaturePreprocessingParams standardSparsityChecks(MemTable table, MemColumn column, CategoryVariableBasicStats cvbs) {
        CatFeaturePreprocessingParams ret = CatFeaturePreprocessingParams.buildStdInput();
        if (cvbs.cardinality == 1) {
            ret.role = FeaturePreprocessingParams.Role.REJECT;
            ret.autoReason = FeaturePreprocessingParams.FeatureHandlingReason.REJECT_ZERO_VARIANCE;
        } else if ((double) cvbs.nbMissing >= MISSING_REJECT_RATIO * (double) table.nrows()) {
            ret.role = FeaturePreprocessingParams.Role.REJECT;
            ret.autoReason = FeaturePreprocessingParams.FeatureHandlingReason.REJECT_MISSING;
        } else if ((double) cvbs.cardinality >= ID_CARDINALITY_RATIO * (double) table.nrows() && hasIdLikeName(column.getName())) {
            ret.role = FeaturePreprocessingParams.Role.REJECT;
            ret.autoReason = FeaturePreprocessingParams.FeatureHandlingReason.REJECT_IDENTIFIER;
        } else if (cvbs.cardinality == table.nrows()) {
            ret.role = FeaturePreprocessingParams.Role.REJECT;
            ret.autoReason = FeaturePreprocessingParams.FeatureHandlingReason.REJECT_IDENTIFIER;
        }
        return ret;
    }

    /**
     * Guesses categorical preprocessing using {@link StandardCategoricalGuesser}.
     *
     * @param table the sampled data
     * @param column the categorical column
     * @param task the ML task
     * @return the guessed categorical preprocessing
     */
    public static CatFeaturePreprocessingParams standardNoSparseCatFeatureGuess(MemTable table, MemColumn column, MLTask task) {
        return new StandardCategoricalGuesser().guess(table, column, task);
    }

    /**
     * Guesses preprocessing for a free-text column.
     * <ul>
     *   <li>Keras: the column is used as input through custom tokenizer code, and its output is
     *       sent to {@code <name>_preprocessed}.</li>
     *   <li>Other backends: the column is rejected by default. A hashing method is pre-selected
     *       (plain hashing on Spark backends, hashing+SVD otherwise) so it can be enabled
     *       later.</li>
     * </ul>
     *
     * @param column the text column
     * @param backendType the task backend
     * @return the guessed text preprocessing
     */
    public static TextFeaturePreprocessingParams standardTextFeatureGuess(MemColumn column, MLTask.BackendType backendType) {
        TextFeaturePreprocessingParams ret = new TextFeaturePreprocessingParams();
        ret.name = column.getName();
        if (backendType == MLTask.BackendType.KERAS) {
            ret.text_handling = TextFeaturePreprocessingParams.TextHandlingMethod.CUSTOM;
            ret.role = FeaturePreprocessingParams.Role.INPUT;
            ret.sendToInput = ret.name + "_preprocessed";
            ret.customHandlingCode = KERAS_TEXT_PREPROCESSING_CODE;
        } else {
            ret.text_handling = backendType.isSparkBased() ? TextFeaturePreprocessingParams.TextHandlingMethod.TOKENIZE_HASHING : TextFeaturePreprocessingParams.TextHandlingMethod.TOKENIZE_HASHING_SVD;
            ret.role = FeaturePreprocessingParams.Role.REJECT;
            ret.autoReason = FeaturePreprocessingParams.FeatureHandlingReason.REJECT_DEFAULT_TEXT_HANDLING;
        }
        return ret;
    }

    /**
     * Default categorical guesser, applied after the shared sparsity checks.
     * <ul>
     *   <li>Clustering tasks reject columns with too many distinct values.</li>
     *   <li>Otherwise the column is an input. H2O backends use no category handling; other
     *       backends dummify with automatic dummy dropping, clipped to a maximum number of
     *       categories.</li>
     * </ul>
     */
    public static class StandardCategoricalGuesser
    implements FeatureGuesser<CatFeaturePreprocessingParams> {
        @Override
        public CatFeaturePreprocessingParams guess(MemTable table, MemColumn column, MLTask task) {
            Logger.getLogger(this.getClass()).info("ENTERING STANDARD GUESSER");
            CategoryVariableBasicStats cvbs = FeatureGuessUtils.buildStats(table, column);
            CatFeaturePreprocessingParams ret = FeatureGuessUtils.standardSparsityChecks(table, column, cvbs);
            if (ret.role == FeaturePreprocessingParams.Role.REJECT) {
                return ret;
            }
            if (task.taskType == MLTask.MLTaskType.CLUSTERING && cvbs.cardinality > CLUSTERING_MAX_CARDINALITY) {
                ret.role = FeaturePreprocessingParams.Role.REJECT;
                ret.autoReason = FeaturePreprocessingParams.FeatureHandlingReason.REJECT_CARDINALITY;
            } else {
                ret.role = FeaturePreprocessingParams.Role.INPUT;
                if (task.backendType == MLTask.BackendType.H2O) {
                    ret.category_handling = CatFeaturePreprocessingParams.CategoryHandlingMethod.NONE;
                } else {
                    ret.category_handling = CatFeaturePreprocessingParams.CategoryHandlingMethod.DUMMIFY;
                    ret.dummy_drop = CatFeaturePreprocessingParams.DummyDroppingMethod.AUTO;
                    ret.dummy_clip = CatFeaturePreprocessingParams.DummyClippingMethod.MAX_NB_CATEGORIES;
                    ret.max_nb_categories = DEFAULT_MAX_NB_CATEGORIES;
                }
            }
            return ret;
        }
    }

    /**
     * Categorical guesser with configurable cardinality limits, applied after the shared sparsity
     * checks.
     * <ul>
     *   <li>Columns with more than {@code maxToDrop} distinct values are rejected.</li>
     *   <li>Otherwise they are dummified, clipped to {@code maxToDummify} categories.</li>
     * </ul>
     * The role of non-rejected columns is left as produced by
     * {@code CatFeaturePreprocessingParams.buildStdInput()} (presumably INPUT - confirm).
     */
    public static class SparsityLimitGuesser
    implements FeatureGuesser<CatFeaturePreprocessingParams> {
        private final int maxToDummify;
        private final int maxToDrop;

        /**
         * @param maxToDummify maximum number of categories kept when dummifying
         * @param maxToDrop cardinality above which the column is rejected
         */
        public SparsityLimitGuesser(int maxToDummify, int maxToDrop) {
            this.maxToDummify = maxToDummify;
            this.maxToDrop = maxToDrop;
        }

        @Override
        public CatFeaturePreprocessingParams guess(MemTable table, MemColumn column, MLTask task) {
            Logger.getLogger(this.getClass()).info("ENTERING SPARSITY GUESSER");
            CategoryVariableBasicStats cvbs = FeatureGuessUtils.buildStats(table, column);
            CatFeaturePreprocessingParams ret = FeatureGuessUtils.standardSparsityChecks(table, column, cvbs);
            if (ret.role == FeaturePreprocessingParams.Role.REJECT) {
                return ret;
            }
            Logger.getLogger(FeatureGuessUtils.class).info("found feature with : " + cvbs.cardinality + " max : " + this.maxToDummify + " , " + this.maxToDrop);
            if (cvbs.cardinality > this.maxToDrop) {
                ret.role = FeaturePreprocessingParams.Role.REJECT;
                ret.autoReason = FeaturePreprocessingParams.FeatureHandlingReason.REJECT_CARDINALITY;
            } else {
                ret.category_handling = CatFeaturePreprocessingParams.CategoryHandlingMethod.DUMMIFY;
                ret.dummy_drop = CatFeaturePreprocessingParams.DummyDroppingMethod.AUTO;
                ret.dummy_clip = CatFeaturePreprocessingParams.DummyClippingMethod.MAX_NB_CATEGORIES;
                ret.max_nb_categories = this.maxToDummify;
            }
            return ret;
        }
    }

    /**
     * Strategy that guesses preprocessing parameters for a single column.
     *
     * @param <T> the kind of preprocessing parameters produced
     */
    public interface FeatureGuesser<T extends FeaturePreprocessingParams> {
        /**
         * @param table the sampled data containing the column
         * @param column the column to guess for
         * @param task the ML task being configured
         * @return the guessed preprocessing parameters
         */
        T guess(MemTable table, MemColumn column, MLTask task);
    }
}

