import numpy as np
import pandas as pd
from itertools import product
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.diagnostics.diagnostics import DiagnosticType


# Lower bound on (treatment frequency * number of distinct treatments); used by check_imbalanced_treatment.
TREATMENT_IMBALANCE_THRESHOLD = 0.2
# Lower bound on the share of rows held by each (treatment, outcome) combination; used by
# check_imbalanced_treatment_outcome.
TREATMENT_OUTCOME_IMBALANCE_THRESHOLD = 0.05
# Propensity model AUC above which check_treatment_randomness reports the treatment as not random.
TREATMENT_RANDOMNESS_AUC_THRESHOLD = 0.65
# Bound applied to the smoothed per-bin ratio of the two propensity histograms in check_treatment_positivity.
TREATMENT_POSITIVITY_RATIO_THRESHOLD = 0.05
# Constant added to both histograms before taking their ratio in check_treatment_positivity
# (keeps the denominator non-zero for empty bins, assuming non-negative histogram values).
TREATMENT_POSITIVITY_RATIO_SMOOTHING_BETA = 0.1
# Calibration loss above which check_propensity_model_calibration reports a poorly calibrated propensity model.
TREATMENT_PROPENSITY_CALIBRATION_THRESHOLD = 0.2


def is_diagnostic_enabled(diagnostics_params, diagnostics_type):
    """ Tell whether a given diagnostic type is enabled in the diagnostics parameters.

        :param diagnostics_params: dict with a global "enabled" flag and a "settings" list of
            {"type": ..., "enabled": ...} entries. May be None or empty.
        :param diagnostics_type: name of the diagnostic type to look up in the settings.
        :return: True only if diagnostics are globally enabled and the requested type is listed
            as enabled; False in every other case (including a missing "settings" entry or an
            unknown type).
        :rtype: bool
    """
    if not diagnostics_params or not diagnostics_params.get("enabled", False):
        return False
    # "settings" may be absent or explicitly None: treat both as "nothing enabled"
    settings = diagnostics_params.get("settings") or []
    # Built as a dict so that, as before, the last entry wins when a type is listed twice
    enabled_by_type = {setting["type"]: setting["enabled"] for setting in settings}
    return bool(enabled_by_type.get(diagnostics_type, False))


def check_treatment_randomness(diagnostics_params, propensity_scores, treatment_map=None):
    """ Flag treatments that the input variables are able to predict.

        The AUC stored in the propensity performance results is compared to
        TREATMENT_RANDOMNESS_AUC_THRESHOLD; a diagnostic is recorded when the AUC is above it.
        With a treatment_map (multi-valued treatment), one AUC is read per non-control treatment.
        Does nothing when the causal treatment checks are disabled."""
    diagnostic_type_name = "ML_DIAGNOSTICS_CAUSAL_TREATMENT_CHECKS"
    if not is_diagnostic_enabled(diagnostics_params, diagnostic_type_name):
        return

    if treatment_map is not None:
        # Multi-valued treatment
        for treatment_name, _ in treatment_map.items_except_control():
            metrics = propensity_scores["propensityPerfMulti"][treatment_name]["tiMetrics"]
            if metrics["auc"] > TREATMENT_RANDOMNESS_AUC_THRESHOLD:
                message = 'The treatment "{}" is not random: the propensity model performed significantly better than a random model'.format(treatment_name)
                diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_CAUSAL_TREATMENT_CHECKS, message)
        return

    # Binary treatment
    metrics = propensity_scores["propensityPerf"]["tiMetrics"]
    if metrics["auc"] > TREATMENT_RANDOMNESS_AUC_THRESHOLD:
        message = "The treatment is not random: the propensity model performed significantly better than a random model"
        diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_CAUSAL_TREATMENT_CHECKS, message)


def check_treatment_positivity(diagnostics_params, propensity_scores, treatment_map=None):
    """ Check that no group of individuals lacks exposure to any treatment.
        This is done by comparing the probability distribution supports of the propensity scores (approximated by histograms) for each treatment.
        For multi-valued treatments, the control is used as a pivot, in order to avoid checking all pairs.

        For every histogram bin the smoothed ratio probaDistribs[0] / probaDistribs[1] is computed:
          - a ratio below TREATMENT_POSITIVITY_RATIO_THRESHOLD reports a lack of exposure to the control;
          - a ratio above 1 / TREATMENT_POSITIVITY_RATIO_THRESHOLD reports a lack of exposure to the treatment.
        NOTE(review): this presumes probaDistribs[0] is the control histogram and probaDistribs[1] the
        treated one, as the diagnostic messages imply - confirm against the producer of probaDistribData.

        :param diagnostics_params: diagnostics settings, checked through is_diagnostic_enabled
        :param propensity_scores: dict holding "propensityPerf" (binary) or "propensityPerfMulti" (multi-valued) results
        :param treatment_map: None for a binary treatment, otherwise an object exposing items_except_control()
        :return: None, diagnostics are recorded through diagnostics.add_or_update
    """
    if not is_diagnostic_enabled(diagnostics_params, "ML_DIAGNOSTICS_CAUSAL_TREATMENT_CHECKS"):
        return
    control_ratio_bound = TREATMENT_POSITIVITY_RATIO_THRESHOLD
    # The treatment side is the mirror image of the control side: the ratio must be too HIGH, not too low
    treatment_ratio_bound = 1 / TREATMENT_POSITIVITY_RATIO_THRESHOLD
    if treatment_map is None:
        # Binary treatment
        test_prediction_distribs = propensity_scores["propensityPerf"]["probaDistribData"]["probaDistribs"]
        test_histogram_smoothed_ratios = _smoothed_histogram_ratios(test_prediction_distribs)
        if np.any(test_histogram_smoothed_ratios < control_ratio_bound):
            diagnostic_message = "Some groups of the test population have a very low probability of being exposed to the control"
            diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_CAUSAL_TREATMENT_CHECKS, diagnostic_message)
        if np.any(test_histogram_smoothed_ratios > treatment_ratio_bound):
            diagnostic_message = "Some groups of the test population have a very low probability of being exposed to the treatment"
            diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_CAUSAL_TREATMENT_CHECKS, diagnostic_message)
    else:
        # Multi-valued treatment
        positivity_failure_control_group = False
        for treatment, _ in treatment_map.items_except_control():
            test_prediction_distribs = propensity_scores["propensityPerfMulti"][treatment]["probaDistribData"]["probaDistribs"]
            test_histogram_smoothed_ratios = _smoothed_histogram_ratios(test_prediction_distribs)
            # The control diagnostic is shared by all treatments, so it is only emitted once after the loop
            positivity_failure_control_group = positivity_failure_control_group or bool(np.any(test_histogram_smoothed_ratios < control_ratio_bound))
            if np.any(test_histogram_smoothed_ratios > treatment_ratio_bound):
                diagnostic_message = "Some groups of the test population have a very low probability of being exposed to the treatment \"{}\"".format(treatment)
                diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_CAUSAL_TREATMENT_CHECKS, diagnostic_message)
        if positivity_failure_control_group:
            diagnostic_message = "Some groups of the test population have a very low probability of being exposed to the control"
            diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_CAUSAL_TREATMENT_CHECKS, diagnostic_message)


def _smoothed_histogram_ratios(prediction_distribs):
    """ Return the per-bin ratio prediction_distribs[0] / prediction_distribs[1], after adding
        TREATMENT_POSITIVITY_RATIO_SMOOTHING_BETA to both histograms."""
    numerator = np.array(prediction_distribs[0]) + TREATMENT_POSITIVITY_RATIO_SMOOTHING_BETA
    denominator = np.array(prediction_distribs[1]) + TREATMENT_POSITIVITY_RATIO_SMOOTHING_BETA
    return numerator / denominator


def check_propensity_model_calibration(diagnostics_params, propensity_scores, treatment_map=None):
    """ Flag propensity models whose calibration loss is too high.

        The calibration loss stored in the propensity performance results is compared to
        TREATMENT_PROPENSITY_CALIBRATION_THRESHOLD; a diagnostic is recorded when the loss is above it.
        With a treatment_map (multi-valued treatment), one loss is read per non-control treatment.
        Does nothing when the causal propensity checks are disabled."""
    diagnostic_type_name = "ML_DIAGNOSTICS_CAUSAL_PROPENSITY_CHECKS"
    if not is_diagnostic_enabled(diagnostics_params, diagnostic_type_name):
        return

    if treatment_map is not None:
        # Multi-valued treatment
        for treatment_name, _ in treatment_map.items_except_control():
            metrics = propensity_scores["propensityPerfMulti"][treatment_name]["tiMetrics"]
            if metrics["calibrationLoss"] > TREATMENT_PROPENSITY_CALIBRATION_THRESHOLD:
                message = 'The propensity model for treatment "{}" is poorly calibrated: consider enabling calibration or increasing the calibration data ratio'.format(treatment_name)
                diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_CAUSAL_PROPENSITY_CHECKS, message)
        return

    # Binary treatment
    metrics = propensity_scores["propensityPerf"]["tiMetrics"]
    if metrics["calibrationLoss"] > TREATMENT_PROPENSITY_CALIBRATION_THRESHOLD:
        message = "The propensity model is poorly calibrated: consider enabling calibration or increasing the calibration data ratio"
        diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_CAUSAL_PROPENSITY_CHECKS, message)


def check_imbalanced_treatment(diagnostics_params, treatment_series, treatment_map=None):
    """ Report treatment groups (control included) that are too rare in the dataset.

        Non-finite treatment values are ignored. A treatment is reported when its frequency,
        multiplied by the number of distinct treatments, is below TREATMENT_IMBALANCE_THRESHOLD.
        Does nothing when the dataset sanity checks are disabled."""
    if not is_diagnostic_enabled(diagnostics_params, "ML_DIAGNOSTICS_DATASET_SANITY_CHECKS"):
        return

    raw_values = treatment_series.values
    finite_values = raw_values[np.isfinite(raw_values)]
    unique_treatments, counts = np.unique(finite_values, return_counts=True)
    n_treatments = unique_treatments.size
    total = np.sum(counts)

    # Frequency normalised by the number of treatments, so the threshold is relative to a uniform split
    is_rare = counts / total * n_treatments < TREATMENT_IMBALANCE_THRESHOLD
    if not np.any(is_rare):
        return

    rare_treatment_names = [pretty_treatment(code, treatment_map) for code in unique_treatments[is_rare]]
    if len(rare_treatment_names) == 1:
        prefix = "A treatment occurs"
    else:
        prefix = "Several treatments occur"
    percent_bound = int(100 * TREATMENT_IMBALANCE_THRESHOLD / n_treatments)
    message = prefix + " in less than {}% of cases in the dataset: {}".format(percent_bound, rare_treatment_names)
    diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_DATASET_SANITY_CHECKS, message)


def check_imbalanced_treatment_outcome(diagnostics_params, treatment_series, target_series, is_regression, treatment_map=None, target_map=None):
    """ Check that no combination of a treatment (including control) group and an outcome (quantized in 4 quartiles for regressions)
        is present with a potentially problematic low frequency.

        Two situations are reported, the first one taking precedence:
          - some (treatment, outcome) combinations do not occur at all;
          - some combinations hold less than TREATMENT_OUTCOME_IMBALANCE_THRESHOLD of the rows.

        :param diagnostics_params: diagnostics settings, checked through is_diagnostic_enabled
        :param treatment_series: pandas Series of treatment codes
        :param target_series: pandas Series of outcomes (quantized here when is_regression is True)
        :param is_regression: True to bucket the target into quartiles, False to use target values as classes
        :param treatment_map: optional mapping used by pretty_treatment to display treatment names
        :param target_map: optional mapping used by pretty_class to display class names
        :return: None, diagnostics are recorded through diagnostics.add_or_update
    """
    if not is_diagnostic_enabled(diagnostics_params, "ML_DIAGNOSTICS_DATASET_SANITY_CHECKS"):
        return
    if is_regression:
        quantized_target, bins = pd.qcut(target_series, 4, labels=False, duplicates='drop', retbins=True)
        # qcut labels start at 0 while pretty_quartile expects 1-based quartile numbers
        quantized_target += 1
        n_q_target = len(bins) - 1
    else:
        quantized_target = target_series
        n_q_target = 2
    # groupby() silently drops NaN keys, so NaN must not be counted as a treatment / outcome value either,
    # otherwise spurious "missing" combinations containing NaN would be reported
    n_treatments = treatment_series.dropna().unique().size
    df = pd.DataFrame({"treatment": treatment_series, "target": quantized_target})
    group = df.groupby(["treatment", "target"]).size()
    if group.size < n_q_target * n_treatments:
        # Missing (targets x treatments) group
        observed_treatments = df["treatment"].dropna().unique()
        observed_targets = df["target"].dropna().unique()
        missing_set = set(product(observed_treatments, observed_targets)) - set(group.index)
        n_missing = len(missing_set)
        if is_regression:
            pretty_missing_set = pretty_combinations(missing_set, treatment_map=treatment_map, is_regression=True)
            diagnostic_message = "{} (treatment, outcome quartile) combinations missing: {}".format(n_missing, pretty_missing_set)
        else:
            pretty_missing_set = pretty_combinations(missing_set, treatment_map=treatment_map, is_regression=False, target_map=target_map)
            diagnostic_message = "{} (treatment, outcome) combinations missing: {}".format(n_missing, pretty_missing_set)
        diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_DATASET_SANITY_CHECKS, diagnostic_message)
        return
    # Share of all rows held by each observed combination
    below_threshold_mask = group.values / treatment_series.size < TREATMENT_OUTCOME_IMBALANCE_THRESHOLD
    if np.any(below_threshold_mask):
        # target_map is forwarded so that classes are displayed the same way as in the "missing" message
        below_threshold_set = pretty_combinations(set(group.index[below_threshold_mask]), treatment_map=treatment_map, is_regression=is_regression, target_map=target_map)
        diagnostic_message = "{} (treatment, outcome) combinations occur in less than {}% of cases in the dataset: {}".format(len(below_threshold_set), int(100 * TREATMENT_OUTCOME_IMBALANCE_THRESHOLD), below_threshold_set)
        diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_DATASET_SANITY_CHECKS, diagnostic_message)


def pretty_combinations(combinations, treatment_map=None, is_regression=True, target_map=None):
    """ Turn (treatment, outcome) pairs into displayable (treatment label, outcome label) pairs.

        The outcome is rendered with pretty_quartile for regressions and with pretty_class otherwise.
        Returns a list, in the iteration order of `combinations`."""
    readable = []
    for treatment_code, outcome in combinations:
        treatment_label = pretty_treatment(treatment_code, treatment_map)
        if is_regression:
            outcome_label = pretty_quartile(outcome)
        else:
            outcome_label = pretty_class(outcome, target_map)
        readable.append((treatment_label, outcome_label))
    return readable


def pretty_quartile(n):
    """ Return the display name of a 1-based quartile number, e.g. 2 -> "2nd quartile".
        Raises KeyError for any value other than 1, 2, 3 or 4."""
    ordinals = dict(zip((1, 2, 3, 4), ("1st", "2nd", "3rd", "4th")))
    return "{} quartile".format(ordinals[n])


def pretty_class(c, target_map=None):
    """ Return the display label of class `c`.

        Without a target_map, `c` is returned unchanged. Otherwise target_map (label -> code) is
        inverted and the label mapped to `c` is returned; KeyError is raised when no label maps to `c`."""
    if target_map is not None:
        label_by_code = dict((code, label) for label, code in target_map.items())
        return label_by_code[c]
    return c


def pretty_treatment(t, treatment_map=None):
    """ Return the display name of treatment code `t`.

        Without a treatment_map, code 0 is displayed as "control" and any other code as "treated".
        Otherwise treatment_map (name -> code) is inverted and the name mapped to int(t) is returned;
        KeyError is raised when no name maps to that code."""
    if treatment_map is not None:
        name_by_code = dict((code, name) for name, code in treatment_map.items())
        return name_by_code[int(t)]
    if int(t) == 0:
        return "control"
    return "treated"
