/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.docgen.extractor;

import com.dataiku.dip.analysis.docgen.extractor.ModelExtractor;
import com.dataiku.dip.analysis.model.ModelDetailsBase;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.analysis.model.clustering.ClusteringModelDetails;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.TabularPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.TimeseriesForecastingModelDetails;
import com.dataiku.dip.datasets.DatasetSelection;
import com.dataiku.dip.datasets.SamplingParam;
import com.jayway.jsonpath.DocumentContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;

/**
 * Extracts the "training and testing strategy" of a model as a table of {@code [label, value]} rows.
 *
 * <p>Tabular prediction models report their train/test policy, the split dataset, time ordering and,
 * for the "split the dataset" policy, sampling, split mode and (K-fold) cross-testing settings.
 * Clustering models only report the sampling of their dataset selection.
 */
public class TrainingAndTestingStrategyExtractor
implements ModelExtractor<List<List<String>>> {

    /** Returns the user-facing label of a train/test policy, or its {@code toString()} if unmapped. */
    private String getTrainTestPolicyName(SplitParams.TrainTestPolicy element) {
        for (TrainTestPolicyName policyName : TrainTestPolicyName.values()) {
            if (policyName.name().equals(element.name())) {
                return policyName.getUserFriendlyName();
            }
        }
        return element.toString();
    }

    /** Returns the user-facing label of a sampling method, or its {@code toString()} if unmapped. */
    private String getSamplingMethodName(SamplingParam.SamplingMethod element) {
        for (SamplingMethodName methodName : SamplingMethodName.values()) {
            if (methodName.name().equals(element.name())) {
                return methodName.getUserFriendlyName();
            }
        }
        return element.toString();
    }

    /**
     * Extracts the strategy table for the given model.
     *
     * @param documentContext not used by this extractor
     * @param model the model to document
     * @return the {@code [label, value]} rows, or a single empty cell when the model is neither a
     *     tabular prediction model nor a clustering model with the required split description
     */
    @Override
    public List<List<String>> extract(DocumentContext documentContext, ModelDetailsBase model) {
        if (model instanceof TabularPredictionModelDetails
                && model.splitDesc != null
                && model.splitDesc.params != null
                && model.splitDesc.params.ttPolicy != null) {
            return this.extractPrediction((TabularPredictionModelDetails) model);
        }
        if (model instanceof ClusteringModelDetails
                && model.splitDesc != null
                && model.splitDesc.cparams != null
                && model.splitDesc.cparams.selection != null) {
            return this.extractSingleDataset(new ArrayList<List<String>>(), model.splitDesc.cparams.selection);
        }
        // Unsupported model type or missing split description: a single empty cell.
        return Collections.singletonList(Collections.singletonList(""));
    }

    /**
     * Builds the train/test strategy rows of a tabular prediction model.
     *
     * <p>{@code model.splitDesc.params.ttPolicy} must be non-null (this is what {@link #extract}
     * checks before calling this method).
     *
     * @param model the prediction model whose split description is documented
     * @return a mutable list of {@code [label, value]} rows
     */
    public List<List<String>> extractPrediction(TabularPredictionModelDetails model) {
        List<List<String>> features = new ArrayList<List<String>>();
        SplitParams.TrainTestPolicy policy = model.splitDesc.params.ttPolicy;
        features.add(row("Policy", this.getTrainTestPolicyName(policy)));

        String dataset = model.splitDesc.params.ssdDatasetSmartName;
        if (dataset != null) {
            // This value is the name of the dataset being split; the sampling method gets its own row.
            features.add(row("Dataset", dataset));
        }

        boolean useTimeOrdering = SplitParams.SplitMode.SORTED.equals(model.splitDesc.params.ssdSplitMode);
        if (!model.splitDesc.params.streamAll) {
            this.addTimeOrderingRows(features, model, useTimeOrdering);
        }

        if (policy == SplitParams.TrainTestPolicy.SPLIT_SINGLE_DATASET) {
            if (model.splitDesc.params.ssdSelection != null) {
                features = this.extractSingleDataset(features, model.splitDesc.params.ssdSelection);
            }
            if (!model.splitDesc.params.streamAll) {
                this.addSplitModeRow(features, model, useTimeOrdering);
            }
            if (model instanceof TimeseriesForecastingModelDetails) {
                this.addTimeseriesCrossTestingRows(features, model);
            } else {
                this.addCrossTestingRows(features, model, useTimeOrdering);
            }
        }
        return features;
    }

    /** Appends the "Use time ordering" row and, when sorted, the time variable and its direction. */
    private void addTimeOrderingRows(List<List<String>> features, TabularPredictionModelDetails model,
                                     boolean useTimeOrdering) {
        if (!useTimeOrdering) {
            features.add(row("Use time ordering", "No"));
            return;
        }
        features.add(row("Use time ordering", "Yes"));
        ResolvedPredictionCoreParams coreParams = model.getCoreParams();
        if (coreParams instanceof ResolvedClassicalPredictionCoreParams) {
            ResolvedClassicalPredictionCoreParams classicalParams = (ResolvedClassicalPredictionCoreParams) coreParams;
            if (classicalParams.time != null) {
                features.add(row("Time variable", classicalParams.time.timeVariable));
                features.add(row("Validate on larger values", yesNo(classicalParams.time.ascending)));
            }
        }
    }

    /** Appends the "Split mode" row; skipped when the split mode is unset and not time-based. */
    private void addSplitModeRow(List<List<String>> features, TabularPredictionModelDetails model,
                                 boolean useTimeOrdering) {
        if (useTimeOrdering) {
            features.add(row("Split mode", "Based on time variable"));
        } else if (model.splitDesc.params.ssdSplitMode != null) {
            features.add(row("Split mode", capitalize(model.splitDesc.params.ssdSplitMode.name())));
        }
    }

    /** Appends K-fold cross-testing rows specific to time series forecasting models. */
    private void addTimeseriesCrossTestingRows(List<List<String>> features, TabularPredictionModelDetails model) {
        if (!model.splitDesc.params.kfold) {
            features.add(row("Use K-fold cross-testing", "No"));
            return;
        }
        features.add(row("Use K-fold cross-testing", "Yes"));
        features.add(row("Number of folds", Integer.toString(model.splitDesc.params.nFolds)));
        if (model.modeling != null && model.modeling.grid_search_params != null) {
            features.add(row("Fold offset", yesNo(model.modeling.grid_search_params.foldOffset)));
            features.add(row("Equal duration train set folds",
                    yesNo(model.modeling.grid_search_params.equalDurationFolds)));
        }
    }

    /** Appends K-fold / train-ratio rows and, for non-time-ordered splits, the split random seed. */
    private void addCrossTestingRows(List<List<String>> features, TabularPredictionModelDetails model,
                                     boolean useTimeOrdering) {
        if (model.splitDesc.params.kfold) {
            features.add(row("Use K-fold cross-testing", "Yes"));
            features.add(row("Number of folds", Integer.toString(model.splitDesc.params.nFolds)));
            if (this.isBinaryOrMultiClassification(model)) {
                features.add(row("Stratified", yesNo(model.splitDesc.params.ssdStratified)));
            }
            features.add(row("Grouped", yesNo(model.splitDesc.params.ssdGrouped)));
            if (model.splitDesc.params.ssdGrouped) {
                features.add(row("Group column", model.splitDesc.params.ssdGroupColumnName));
            }
        } else {
            features.add(row("Use K-fold cross-testing", "No"));
            features.add(row("Train ratio", Double.toString(model.splitDesc.params.ssdTrainingRatio)));
        }
        if (!useTimeOrdering) {
            features.add(row("Random seed", Long.toString(model.splitDesc.params.ssdSeed)));
        }
    }

    /**
     * Appends the sampling rows of a dataset selection to {@code features}.
     *
     * @param features the rows to append to; mutated in place
     * @param selection the dataset selection to describe
     * @return {@code features}, for chaining; left unchanged when the sampling method is unset
     */
    public List<List<String>> extractSingleDataset(List<List<String>> features, DatasetSelection selection) {
        SamplingParam.SamplingMethod samplingMethod = selection.samplingMethod;
        if (samplingMethod == null) {
            // Nothing to describe; a switch on a null enum would throw NullPointerException.
            return features;
        }
        features.add(row("Sampling method", this.getSamplingMethodName(samplingMethod)));
        switch (samplingMethod) {
            case FULL:
                break;
            case HEAD_SEQUENTIAL:
            case TAIL_SEQUENTIAL:
                features.add(row("Record limit", Long.toString(selection.maxRecords)));
                break;
            // Methods driven by a number of records.
            case RANDOM_FIXED_NB:
            case RANDOM_FIXED_NB_EXACT:
            case STRATIFIED_TARGET_NB_EXACT:
                features.add(row("Record limit", Long.toString(selection.maxRecords)));
                addSeedRow(features, selection);
                break;
            // Methods driven by a ratio of the data.
            case RANDOM_FIXED_RATIO:
            case RANDOM_FIXED_RATIO_EXACT:
            case STRATIFIED_TARGET_RATIO_EXACT:
                features.add(row("% to use", Double.toString(selection.targetRatio)));
                addSeedRow(features, selection);
                break;
            case COLUMN_BASED:
                features.add(row("Record limit", Long.toString(selection.maxRecords)));
                features.add(row("Column", selection.column));
                break;
            case CLASS_REBALANCE_TARGET_NB_APPROX:
                features.add(row("Record limit", Long.toString(selection.maxRecords)));
                features.add(row("Column", selection.column));
                addSeedRow(features, selection);
                break;
            case CLASS_REBALANCE_TARGET_RATIO_APPROX:
                features.add(row("% to use", Double.toString(selection.targetRatio)));
                features.add(row("Column", selection.column));
                addSeedRow(features, selection);
                break;
            case COLUMN_ORDERED:
                features.add(row("Record limit", Long.toString(selection.maxRecords)));
                features.add(row("Sorted by", selection.column));
                features.add(row("Ascending order", yesNo(selection.ascending)));
                break;
            default:
                // Other sampling methods only report their name.
                break;
        }
        return features;
    }

    /** Appends a "Random seed" row when the selection defines a seed. */
    private static void addSeedRow(List<List<String>> features, DatasetSelection selection) {
        if (selection.seed != null) {
            features.add(row("Random seed", Long.toString(selection.seed)));
        }
    }

    /** Returns true when the model's prediction type is a (causal) binary or multiclass classification. */
    private boolean isBinaryOrMultiClassification(TabularPredictionModelDetails predictionModelDetails) {
        ResolvedPredictionCoreParams coreParams = predictionModelDetails.getCoreParams();
        if (coreParams == null || coreParams.prediction_type == null) {
            return false;
        }
        switch (coreParams.prediction_type) {
            case BINARY_CLASSIFICATION:
            case CAUSAL_BINARY_CLASSIFICATION:
            case MULTICLASS:
                return true;
            default:
                return false;
        }
    }

    /** Builds a fixed-size {@code [label, value]} row. */
    private static List<String> row(String label, String value) {
        return Arrays.asList(label, value);
    }

    /** Renders a boolean as "Yes"/"No". */
    private static String yesNo(boolean value) {
        return value ? "Yes" : "No";
    }

    /** Upper-cases the first character and lower-cases the rest, independently of the default locale. */
    private static String capitalize(String value) {
        if (value.isEmpty()) {
            return value;
        }
        return value.substring(0, 1).toUpperCase(Locale.ROOT) + value.substring(1).toLowerCase(Locale.ROOT);
    }

    /** User-facing labels of {@link SplitParams.TrainTestPolicy} values, matched by constant name. */
    public static enum TrainTestPolicyName {
        SPLIT_SINGLE_DATASET("Split the dataset"),
        EXPLICIT_FILTERING_SINGLE_DATASET("Explicit extracts from the dataset"),
        EXPLICIT_FILTERING_TWO_DATASETS("Explicit extracts from two datasets");

        private final String userFriendlyName;

        private TrainTestPolicyName(String userFriendlyName) {
            this.userFriendlyName = userFriendlyName;
        }

        public String getUserFriendlyName() {
            return this.userFriendlyName;
        }
    }

    /** User-facing labels of {@link SamplingParam.SamplingMethod} values, matched by constant name. */
    public static enum SamplingMethodName {
        FULL("No sampling (whole data)"),
        HEAD_SEQUENTIAL("First records"),
        TAIL_SEQUENTIAL("Last records"),
        RANDOM_FIXED_NB("Random (approx. nb. records)"),
        RANDOM_FIXED_RATIO("Random (approx. ratio)"),
        COLUMN_BASED("Column values subset (approx. nb. records)"),
        STRATIFIED_TARGET_NB_EXACT("Stratified (nb. records)"),
        STRATIFIED_TARGET_RATIO_EXACT("Stratified (ratio)"),
        CLASS_REBALANCE_TARGET_NB_APPROX("Class rebalance (approx. nb. records)"),
        CLASS_REBALANCE_TARGET_RATIO_APPROX("Class rebalance (approx. ratio)"),
        // Exact variant: no "approx." (same convention as the stratified *_EXACT labels).
        RANDOM_FIXED_NB_EXACT("Random (nb. records)"),
        RANDOM_FIXED_RATIO_EXACT("Random (target ratio of data)"),
        COLUMN_ORDERED("First records sorted by a column");

        private final String userFriendlyName;

        private SamplingMethodName(String userFriendlyName) {
            this.userFriendlyName = userFriendlyName;
        }

        public String getUserFriendlyName() {
            return this.userFriendlyName;
        }
    }
}

