/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.ModelLikeId;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.i18n.TranslationService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.warnings.WarningsContext;
import com.dataiku.dss.shadelib.com.google.common.collect.ImmutableMap;
import com.google.gson.JsonDeserializationContext;
import com.google.gson.JsonDeserializer;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class MLDiagnostics {
    public static final String ML_DIAGNOSTICS_FILENAME = "ml_diagnostics.json";
    private static final DiagnosticDefinition DATASET_SANITY_CHECKS = new DiagnosticDefinition("Dataset sanity checks", "ANALYSIS.ML.MLDiagnostics.DATASET_SANITY_CHECKS.DESCRIPTION", "Checks on train and test datasets to ensure reliable performance estimation", "dataset-sanity-checks");
    private static final DiagnosticDefinition MODELING_PARAMETERS = new DiagnosticDefinition("Modeling parameters", "ANALYSIS.ML.MLDiagnostics.MODELING_PARAMETERS.DESCRIPTION", "Checks on modeling parameters with respect to the characteristics of the data", "modeling-parameters");
    private static final DiagnosticDefinition TRAINING_SPEED = new DiagnosticDefinition("Training speed", "ANALYSIS.ML.MLDiagnostics.RUNTIME.DESCRIPTION", "Check to ensure the model training speed is optimal", "training-speed");
    private static final DiagnosticDefinition OVERFIT_DETECTION = new DiagnosticDefinition("Overfit detection", "ANALYSIS.ML.MLDiagnostics.TRAINING_OVERFIT.DESCRIPTION", "Checks on trained model attributes to identify potential overfitting", "overfitting-detection");
    private static final DiagnosticDefinition LEAKAGE_DETECTION = new DiagnosticDefinition("Leakage detection", "ANALYSIS.ML.MLDiagnostics.LEAKAGE_DETECTION.DESCRIPTION", "Checks on performance metrics and feature importances to detect possible data leakage", "leakage-detection");
    private static final DiagnosticDefinition MODEL_CHECK = new DiagnosticDefinition("Model check", "ANALYSIS.ML.MLDiagnostics.MODEL_CHECK.DESCRIPTION", "Check to ensure the model outperforms a baseline dummy classifier model", "model-checks");
    private static final DiagnosticDefinition ML_ASSERTIONS = new DiagnosticDefinition("ML assertions", "ANALYSIS.ML.MLDiagnostics.ML_ASSERTIONS.DESCRIPTION", "Check to ensure the model satisfies the ml assertions", "ml-assertions");
    private static final DiagnosticDefinition ABNORMAL_PREDICTIONS_DETECTION = new DiagnosticDefinition("Abnormal predictions detection", "ANALYSIS.ML.MLDiagnostics.ABNORMAL_PREDICTIONS_DETECTION.DESCRIPTION", "Check to ensure that the model doesn't always predict the same class", "abnormal-predictions-detection");
    private static final DiagnosticDefinition TRAINING_REPRODUCIBILITY = new DiagnosticDefinition("Training reproducibility", "ANALYSIS.ML.MLDiagnostics.TRAINING_REPRODUCIBILITY.DESCRIPTION", "Check to ensure the model training is reproducible", "training-reproducibility");
    private static final DiagnosticDefinition SCORING_DATASET_SANITY_CHECKS = new DiagnosticDefinition("Scoring dataset sanity checks", "ANALYSIS.ML.MLDiagnostics.SCORING_DATASET_SANITY_CHECKS.DESCRIPTION", "Checks on the scoring dataset to ensure reliable scoring", "scoring-dataset-sanity-checks");
    private static final DiagnosticDefinition EVALUATION_DATASET_SANITY_CHECKS = new DiagnosticDefinition("Evaluation dataset sanity checks", "ANALYSIS.ML.MLDiagnostics.EVALUATION_DATASET_SANITY_CHECKS.DESCRIPTION", "Checks on the evaluation dataset to ensure reliable evaluation", "evaluation-dataset-sanity-checks");
    private static final DiagnosticDefinition TIME_SERIES_RESAMPLING_CHECKS = new DiagnosticDefinition("Time series resampling checks", "ANALYSIS.ML.MLDiagnostics.TIME_SERIES_RESAMPLING_CHECKS.DESCRIPTION", "Checks on the time series to ensure they can be resampled", "time-series-resampling-checks");
    private static final DiagnosticDefinition CAUSAL_TREATMENT_CHECKS = new DiagnosticDefinition("Treatment checks", "ANALYSIS.ML.MLDiagnostics.CAUSAL_TREATMENT_CHECKS", "Checks the distribution of the treatment (randomness, positivity) by analysing a propensity model", "causal-prediction-treatment-checks");
    private static final DiagnosticDefinition CAUSAL_PROPENSITY_CHECKS = new DiagnosticDefinition("Propensity checks", "ANALYSIS.ML.MLDiagnostics.CAUSAL_PROPENSITY_CHECKS", "Checks the calibration of the propensity model", "causal-prediction-propensity-model-checks");
    private static final DiagnosticDefinition LLM_EVALUATION_COMPUTATION_CHECKS = new DiagnosticDefinition("Evaluation error", "ANALYSIS.ML.MLDiagnostics.LLM_EVALUATION_COMPUTATION_ERROR", "Checks the full execution of the LLM Evaluation recipe", "llm-evaluation-computation-checks");
    public List<DoctorDiagnostic> diagnostics;
    public static final Map<WarningsContext.WarningType, DiagnosticDefinition> DIAGNOSTICS_TYPES = ImmutableMap.of((Object)WarningsContext.WarningType.ML_DIAGNOSTICS_DATASET_SANITY_CHECKS, (Object)DATASET_SANITY_CHECKS, (Object)WarningsContext.WarningType.ML_DIAGNOSTICS_MODELING_PARAMETERS, (Object)MODELING_PARAMETERS, (Object)WarningsContext.WarningType.ML_DIAGNOSTICS_RUNTIME, (Object)TRAINING_SPEED, (Object)WarningsContext.WarningType.ML_DIAGNOSTICS_TRAINING_OVERFIT, (Object)OVERFIT_DETECTION, (Object)WarningsContext.WarningType.ML_DIAGNOSTICS_LEAKAGE_DETECTION, (Object)LEAKAGE_DETECTION, (Object)WarningsContext.WarningType.ML_DIAGNOSTICS_MODEL_CHECK, (Object)MODEL_CHECK, (Object)WarningsContext.WarningType.ML_DIAGNOSTICS_ML_ASSERTIONS, (Object)ML_ASSERTIONS, (Object)WarningsContext.WarningType.ML_DIAGNOSTICS_ABNORMAL_PREDICTIONS_DETECTION, (Object)ABNORMAL_PREDICTIONS_DETECTION);

    public static Map<WarningsContext.WarningType, DiagnosticDefinition> translateDiagnosticTypes(Map<WarningsContext.WarningType, DiagnosticDefinition> diag, String lang) {
        if (lang == null) {
            return diag;
        }
        LinkedHashMap<WarningsContext.WarningType, DiagnosticDefinition> ret = new LinkedHashMap<WarningsContext.WarningType, DiagnosticDefinition>();
        TranslationService sts = (TranslationService)SpringUtils.getBean(TranslationService.class);
        for (Map.Entry<WarningsContext.WarningType, DiagnosticDefinition> e : diag.entrySet()) {
            DiagnosticDefinition def = (DiagnosticDefinition)JSON.deepCopy((Object)e.getValue());
            def.description = sts.translateNoContext(lang, def.translationId, def.description, new Object[0]);
            ret.put(e.getKey(), def);
        }
        return ret;
    }

    public static Map<WarningsContext.WarningType, DiagnosticDefinition> getMLTaskDiagnosticsTypes(MLTask.MLTaskType mlTaskType, MLTask.BackendType backendType) {
        return MLDiagnostics.getMLTaskDiagnosticsTypes(new MLTaskConfig(mlTaskType, backendType));
    }

    public static Map<WarningsContext.WarningType, DiagnosticDefinition> getMLTaskDiagnosticsTypes(MLTask.MLTaskType mlTaskType, MLTask.BackendType backendType, PredictionMLTask.PredictionType predictionType) {
        return MLDiagnostics.getMLTaskDiagnosticsTypes(new MLTaskConfig(mlTaskType, backendType, predictionType));
    }

    private static Map<WarningsContext.WarningType, DiagnosticDefinition> getMLTaskDiagnosticsTypes(MLTaskConfig mlTaskConfig) {
        LinkedHashMap<WarningsContext.WarningType, DiagnosticDefinition> diagnosticsTypes = new LinkedHashMap<WarningsContext.WarningType, DiagnosticDefinition>();
        for (DiagnosticsTypes diagnosticsType : DiagnosticsTypes.values()) {
            if (!diagnosticsType.supportedMLTaskConfigs.contains(mlTaskConfig)) continue;
            diagnosticsTypes.put(diagnosticsType.warningType, diagnosticsType.diagnosticDefinition);
        }
        return diagnosticsTypes;
    }

    public static List<DoctorDiagnostic> getDiagnosticsSafe(MLDiagnostics mlDiagnostics) {
        if (mlDiagnostics != null) {
            return mlDiagnostics.diagnostics;
        }
        return Collections.emptyList();
    }

    public static Map<WarningsContext.WarningType, Integer> countDiagnostics(FullModelId fmi) throws IOException {
        HashMap<WarningsContext.WarningType, Integer> counts = new HashMap<WarningsContext.WarningType, Integer>();
        MLDiagnostics mlDiagnostics = fmi.getMLDiagnostics();
        for (DoctorDiagnostic diag : MLDiagnostics.getDiagnosticsSafe(mlDiagnostics)) {
            if (!counts.containsKey(diag.type)) {
                counts.put(diag.type, 1);
                continue;
            }
            counts.put(diag.type, (Integer)counts.get(diag.type) + 1);
        }
        return counts;
    }

    public static void mergeIntoWarnings(ModelLikeId mlid, WarningsContext into) throws IOException {
        MLDiagnostics diagnostics = mlid.getMLDiagnostics();
        if (diagnostics != null) {
            diagnostics.mergeIntoWarnings(into);
        }
    }

    public void mergeIntoWarnings(WarningsContext into) {
        for (DoctorDiagnostic diag : this.diagnostics) {
            into.addWarning(diag.type, diag.message, null);
        }
    }

    public Map<WarningsContext.WarningType, List<String>> groupByType() {
        HashMap<WarningsContext.WarningType, List<String>> groupedDiags = new HashMap<WarningsContext.WarningType, List<String>>();
        for (DoctorDiagnostic diag : this.diagnostics) {
            if (!groupedDiags.containsKey(diag.type)) {
                groupedDiags.put(diag.type, new ArrayList());
            }
            List messages = (List)groupedDiags.get(diag.type);
            messages.add(diag.message);
        }
        return groupedDiags;
    }

    public static class DiagnosticDefinition {
        public String displayableType;
        public String translationId;
        public String description;
        public String documentationAnchor;

        public DiagnosticDefinition(String displayableType, String translationId, String description, String documentationAnchor) {
            this.displayableType = displayableType;
            this.translationId = translationId;
            this.description = description;
            this.documentationAnchor = documentationAnchor;
        }
    }

    private static class MLTaskConfig {
        private final MLTask.MLTaskType mlTaskType;
        private final MLTask.BackendType backendType;
        private final PredictionMLTask.PredictionType predictionType;

        public MLTaskConfig(MLTask.MLTaskType mlTaskType, MLTask.BackendType backendType, PredictionMLTask.PredictionType predictionType) {
            this.mlTaskType = mlTaskType;
            this.backendType = backendType;
            this.predictionType = predictionType;
        }

        public MLTaskConfig(MLTask.MLTaskType mlTaskType, MLTask.BackendType backendType) {
            this.mlTaskType = mlTaskType;
            this.backendType = backendType;
            this.predictionType = null;
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            MLTaskConfig other = (MLTaskConfig)o;
            return Objects.equals((Object)this.mlTaskType, (Object)other.mlTaskType) && Objects.equals((Object)this.backendType, (Object)other.backendType) && Objects.equals((Object)this.predictionType, (Object)other.predictionType);
        }

        public int hashCode() {
            return Objects.hash(new Object[]{this.mlTaskType, this.backendType, this.predictionType});
        }
    }

    public static enum DiagnosticsTypes {
        ML_DIAGNOSTICS_DATASET_SANITY_CHECKS(WarningsContext.WarningType.ML_DIAGNOSTICS_DATASET_SANITY_CHECKS, DATASET_SANITY_CHECKS, new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.KERAS, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.KERAS, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.KERAS, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.CLUSTERING, MLTask.BackendType.PY_MEMORY), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.DEEP_HUB, PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.DEEP_HUB, PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_OBJECT_DETECTION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.TIMESERIES_FORECAST), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.CAUSAL_REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.CAUSAL_BINARY_CLASSIFICATION)),
        ML_DIAGNOSTICS_MODELING_PARAMETERS(WarningsContext.WarningType.ML_DIAGNOSTICS_MODELING_PARAMETERS, MODELING_PARAMETERS, new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.CLUSTERING, MLTask.BackendType.PY_MEMORY), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.TIMESERIES_FORECAST)),
        ML_DIAGNOSTICS_RUNTIME(WarningsContext.WarningType.ML_DIAGNOSTICS_RUNTIME, TRAINING_SPEED, new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.KERAS, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.KERAS, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.KERAS, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.DEEP_HUB, PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.DEEP_HUB, PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_OBJECT_DETECTION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.TIMESERIES_FORECAST), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.CAUSAL_REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.CAUSAL_BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.CLUSTERING, MLTask.BackendType.PY_MEMORY)),
        ML_DIAGNOSTICS_REPRODUCIBILITY(WarningsContext.WarningType.ML_DIAGNOSTICS_REPRODUCIBILITY, TRAINING_REPRODUCIBILITY, new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.TIMESERIES_FORECAST)),
        ML_DIAGNOSTICS_TRAINING_OVERFIT(WarningsContext.WarningType.ML_DIAGNOSTICS_TRAINING_OVERFIT, OVERFIT_DETECTION, new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.REGRESSION)),
        ML_DIAGNOSTICS_LEAKAGE_DETECTION(WarningsContext.WarningType.ML_DIAGNOSTICS_LEAKAGE_DETECTION, LEAKAGE_DETECTION, new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.KERAS, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.KERAS, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.KERAS, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.DEEP_HUB, PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.DEEP_HUB, PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_OBJECT_DETECTION)),
        ML_DIAGNOSTICS_MODEL_CHECK(WarningsContext.WarningType.ML_DIAGNOSTICS_MODEL_CHECK, MODEL_CHECK, new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.REGRESSION)),
        ML_DIAGNOSTICS_ML_ASSERTIONS(WarningsContext.WarningType.ML_DIAGNOSTICS_ML_ASSERTIONS, ML_ASSERTIONS, new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.REGRESSION)),
        ML_DIAGNOSTICS_ABNORMAL_PREDICTIONS_DETECTION(WarningsContext.WarningType.ML_DIAGNOSTICS_ABNORMAL_PREDICTIONS_DETECTION, ABNORMAL_PREDICTIONS_DETECTION, new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.KERAS, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.KERAS, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.KERAS, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.MULTICLASS), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, null, PredictionMLTask.PredictionType.REGRESSION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.DEEP_HUB, PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_CLASSIFICATION)),
        ML_DIAGNOSTICS_SCORING_DATASET_SANITY_CHECKS(WarningsContext.WarningType.ML_DIAGNOSTICS_SCORING_DATASET_SANITY_CHECKS, SCORING_DATASET_SANITY_CHECKS, new MLTaskConfig[0]),
        ML_DIAGNOSTICS_EVALUATION_DATASET_SANITY_CHECKS(WarningsContext.WarningType.ML_DIAGNOSTICS_EVALUATION_DATASET_SANITY_CHECKS, EVALUATION_DATASET_SANITY_CHECKS, new MLTaskConfig[0]),
        ML_DIAGNOSTICS_TIMESERIES_RESAMPLING_CHECKS(WarningsContext.WarningType.ML_DIAGNOSTICS_TIMESERIES_RESAMPLING_CHECKS, TIME_SERIES_RESAMPLING_CHECKS, new MLTaskConfig[0]),
        ML_DIAGNOSTICS_CAUSAL_TREATMENT_CHECKS(WarningsContext.WarningType.ML_DIAGNOSTICS_CAUSAL_TREATMENT_CHECKS, CAUSAL_TREATMENT_CHECKS, new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.CAUSAL_BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.CAUSAL_REGRESSION)),
        ML_DIAGNOSTICS_CAUSAL_PROPENSITY_CHECKS(WarningsContext.WarningType.ML_DIAGNOSTICS_CAUSAL_PROPENSITY_CHECKS, CAUSAL_PROPENSITY_CHECKS, new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.CAUSAL_BINARY_CLASSIFICATION), new MLTaskConfig(MLTask.MLTaskType.PREDICTION, MLTask.BackendType.PY_MEMORY, PredictionMLTask.PredictionType.CAUSAL_REGRESSION)),
        LLM_EVALUATION_COMPUTATION_ERROR(WarningsContext.WarningType.LLM_EVALUATION_COMPUTATION_ERROR, LLM_EVALUATION_COMPUTATION_CHECKS, new MLTaskConfig(null, null, null), new MLTaskConfig(null, null, null));

        public final DiagnosticDefinition diagnosticDefinition;
        public final WarningsContext.WarningType warningType;
        private final List<MLTaskConfig> supportedMLTaskConfigs;

        private DiagnosticsTypes(WarningsContext.WarningType warningType, DiagnosticDefinition diagnosticDefinition, MLTaskConfig ... supportedMLTaskConfigs) {
            this.warningType = warningType;
            this.diagnosticDefinition = diagnosticDefinition;
            this.supportedMLTaskConfigs = Arrays.asList(supportedMLTaskConfigs);
        }

        public static DiagnosticsTypes getByWarningType(WarningsContext.WarningType warningType) {
            for (DiagnosticsTypes x : DiagnosticsTypes.values()) {
                if (!x.warningType.equals((Object)warningType)) continue;
                return x;
            }
            throw new IllegalArgumentException();
        }
    }

    public static class DoctorDiagnostic {
        public WarningsContext.WarningType type;
        public String message;
        public String step;
        @JSON.FileTransient
        public String displayableType;

        static {
            JSON.registerAdapter(DoctorDiagnostic.class, (Object)new JsonDeserializer<DoctorDiagnostic>(){

                public DoctorDiagnostic deserialize(JsonElement jsonElement, Type scriptType, JsonDeserializationContext ctx) throws JsonParseException {
                    JsonObject obj = jsonElement.getAsJsonObject();
                    DoctorDiagnostic diagnostic = new DoctorDiagnostic();
                    diagnostic.type = (WarningsContext.WarningType)ctx.deserialize(obj.get("type"), WarningsContext.WarningType.class);
                    diagnostic.message = obj.get("message").getAsString();
                    diagnostic.step = obj.get("step").getAsString();
                    diagnostic.displayableType = DiagnosticsTypes.getByWarningType((WarningsContext.WarningType)diagnostic.type).diagnosticDefinition.displayableType;
                    return diagnostic;
                }
            });
        }
    }
}

