/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.docgen.extractor;

import com.dataiku.dip.analysis.docgen.extractor.CrossValidationStrategyExtractor;
import com.dataiku.dip.analysis.docgen.extractor.ModelExtractor;
import com.dataiku.dip.analysis.model.ModelDetailsBase;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.TabularPredictionModelDetails;
import com.jayway.jsonpath.DocumentContext;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Builds a two-column table describing the hyperparameter search settings and the
 * cross-validation settings of a tabular prediction model.
 *
 * <p>Each row is a {@code [label, value]} pair. Rows whose value is the empty string
 * ("Search strategy", "Search parameters", "Cross-validation") act as section headers.
 */
public class HyperparameterCrossValidationStrategyTableExtractor
implements ModelExtractor<List<List<String>>> {
    /** Text of the single-cell table returned when no grid search parameters are available. */
    private static final String NO_STRATEGY_MESSAGE = "No hyperparameter cross validation strategy.";

    /** Value rendered for limits (iterations, search time) configured as 0. */
    private static final String NO_LIMIT = "0 (no limit)";

    /**
     * Extracts the hyperparameter search / cross-validation table for {@code model}.
     *
     * @param documentContext passed through to {@link CrossValidationStrategyExtractor}
     * @param model the model whose settings are reported; only
     *     {@link TabularPredictionModelDetails} with non-null grid search params yield a full table
     * @return the table rows; a single-cell table holding {@value #NO_STRATEGY_MESSAGE} when the
     *     model is not tabular or has no grid search parameters
     * @throws IOException propagated from {@link CrossValidationStrategyExtractor#extract}
     */
    @Override
    public List<List<String>> extract(DocumentContext documentContext, ModelDetailsBase model) throws IOException {
        if (!(model instanceof TabularPredictionModelDetails)) {
            return Collections.singletonList(Collections.singletonList(NO_STRATEGY_MESSAGE));
        }
        TabularPredictionModelDetails predictionModelDetails = (TabularPredictionModelDetails) model;
        if (predictionModelDetails.modeling == null || predictionModelDetails.modeling.grid_search_params == null) {
            return Collections.singletonList(Collections.singletonList(NO_STRATEGY_MESSAGE));
        }
        PredictionModelingParams.GridSearchParams params = predictionModelDetails.modeling.grid_search_params;
        List<List<String>> output = new ArrayList<>();

        addSearchStrategyRows(params, output);
        output.add(Arrays.asList("Max search time", formatLimit(params.timeout)));
        output.add(Arrays.asList("Parallelism", Integer.toString(params.nJobs)));

        output.add(Arrays.asList("Cross-validation", ""));
        output.add(Arrays.asList("Cross-validation strategy", new CrossValidationStrategyExtractor().extract(documentContext, model)));
        addCrossValidationRows(params, predictionModelDetails, output);
        return output;
    }

    /**
     * Appends the strategy-specific rows of the hyperparameter search section.
     * Appends nothing when the strategy is unset, instead of failing with a
     * NullPointerException in the switch.
     */
    private static void addSearchStrategyRows(PredictionModelingParams.GridSearchParams params, List<List<String>> output) {
        if (params.strategy == null) {
            return;
        }
        switch (params.strategy) {
            case GRID:
                output.add(Arrays.asList("Search strategy", ""));
                output.add(Arrays.asList("Strategy", "Grid search"));
                output.add(Arrays.asList("Search parameters", ""));
                output.add(Arrays.asList("Randomize grid search", params.randomized ? "Yes" : "No"));
                // A non-randomized grid still reports the configured seed, labelled as unused.
                output.add(Arrays.asList("Random state (hyperparameter search)",
                        params.randomized ? Integer.toString(params.seed) : params.seed + " (Unused)"));
                // Grid search reads nIter; random/Bayesian search read nIterRandom.
                output.add(Arrays.asList("Max number of iterations", formatLimit(params.nIter)));
                break;
            case BAYESIAN:
                addSeededSearchRows("Bayesian search", params, output);
                break;
            case RANDOM:
                addSeededSearchRows("Random search", params, output);
                break;
            default:
                break;
        }
    }

    /**
     * Appends the rows shared by the random and Bayesian strategies. These always
     * report the seed and use {@code nIterRandom} as the iteration limit.
     */
    private static void addSeededSearchRows(String strategyLabel, PredictionModelingParams.GridSearchParams params, List<List<String>> output) {
        output.add(Arrays.asList("Search strategy", ""));
        output.add(Arrays.asList("Strategy", strategyLabel));
        output.add(Arrays.asList("Search parameters", ""));
        output.add(Arrays.asList("Random state (hyperparameter search)", Integer.toString(params.seed)));
        output.add(Arrays.asList("Max number of iterations", formatLimit(params.nIterRandom)));
    }

    /**
     * Appends the split-mode-specific rows of the cross-validation section.
     * Appends nothing when the mode is unset, instead of failing with a
     * NullPointerException in the switch.
     */
    private static void addCrossValidationRows(PredictionModelingParams.GridSearchParams params,
                                               TabularPredictionModelDetails predictionModelDetails,
                                               List<List<String>> output) {
        if (params.mode == null) {
            return;
        }
        switch (params.mode) {
            case TIME_SERIES_SINGLE_SPLIT:
                output.add(Arrays.asList("Split ratio", Float.toString(params.splitRatio)));
                break;
            case SHUFFLE:
                output.add(Arrays.asList("Split ratio", Float.toString(params.splitRatio)));
                output.add(Arrays.asList("Random state (cross-validation split)", Integer.toString(params.cvSeed)));
                // The stratification flag is only shown for classification prediction types.
                if (isBinaryOrMultiClassification(predictionModelDetails)) {
                    output.add(Arrays.asList("Stratified", params.stratified ? "Yes" : "No"));
                }
                break;
            case TIME_SERIES_KFOLD:
                output.add(Arrays.asList("Number of folds", Integer.toString(params.nFolds)));
                output.add(Arrays.asList("Fold offset", params.foldOffset ? "Yes" : "No"));
                output.add(Arrays.asList("Equal duration train set folds", params.equalDurationFolds ? "Yes" : "No"));
                break;
            case KFOLD:
                output.add(Arrays.asList("Number of folds", Integer.toString(params.nFolds)));
                output.add(Arrays.asList("Random state (cross-validation split)", Integer.toString(params.cvSeed)));
                if (isBinaryOrMultiClassification(predictionModelDetails)) {
                    output.add(Arrays.asList("Stratified", params.stratified ? "Yes" : "No"));
                }
                output.add(Arrays.asList("Grouped", params.grouped ? "Yes" : "No"));
                if (params.grouped) {
                    output.add(Arrays.asList("Group column", params.groupColumnName));
                }
                break;
            case CUSTOM:
                output.add(Arrays.asList("Code", params.code));
                break;
            default:
                break;
        }
    }

    /** Renders a limit value, where 0 is displayed as {@value #NO_LIMIT}. */
    private static String formatLimit(int value) {
        return value == 0 ? NO_LIMIT : Integer.toString(value);
    }

    /**
     * Returns whether the model's resolved prediction type is binary classification,
     * causal binary classification, or multiclass. Returns false when the core params
     * or the prediction type are unavailable.
     */
    private static boolean isBinaryOrMultiClassification(TabularPredictionModelDetails predictionModelDetails) {
        ResolvedPredictionCoreParams coreParams = predictionModelDetails.getCoreParams();
        if (coreParams == null || coreParams.prediction_type == null) {
            return false;
        }
        switch (coreParams.prediction_type) {
            case BINARY_CLASSIFICATION:
            case CAUSAL_BINARY_CLASSIFICATION:
            case MULTICLASS:
                return true;
            default:
                return false;
        }
    }
}

