import logging

import numpy as np
import pandas as pd
from statsmodels.stats import proportion

from dataiku.core import doctor_constants
from dataiku.core.scipycompat import binom_test
from dataiku.doctor.causal.utils.metrics import compute_auuc_score
from dataiku.doctor.causal.utils.metrics import compute_cate_histograms
from dataiku.doctor.causal.utils.metrics import compute_causal_model_variables_importance
from dataiku.doctor.causal.utils.metrics import compute_cumulative_sums
from dataiku.doctor.causal.utils.metrics import compute_net_uplift_score
from dataiku.doctor.causal.utils.metrics import compute_qini_curve
from dataiku.doctor.causal.utils.metrics import compute_qini_score
from dataiku.doctor.causal.utils.metrics import compute_uplift_curve
from dataiku.doctor.causal.utils.models import get_predictions_from_causal_model_single_treatment
from dataiku.doctor.causal.utils.models import get_predictions_from_causal_model_multi_treatment
from dataiku.doctor.prediction.classification_scoring import BinaryClassificationModelScorer
from dataiku.doctor.prediction.decisions_and_cuts import DecisionsAndCuts
from dataiku.doctor.prediction.scoring_base import PredictionModelScorer
from dataiku.doctor.utils import remove_all_nan
from dataiku.doctor.utils import dku_nonaninf
from dataiku.doctor.utils.metrics import check_test_set_ok_for_classification

logger = logging.getLogger(__name__)
# Name of the predicted effect column in the predicted data (suffixed with "_<treatment>" for multi-valued treatments)
PREDICTED_EFFECT_COLUMN = "predicted_effect"


def causal_prediction_scorer_with_valid(modeling_params, dku_causal_model, valid, out_folder_context, input_df_index,
                                        is_regression, propensity_model=None, treatment_map=None):
    """
    Build the causal scorer matching the treatment setting (binary or multi-valued) for a processed validation set.

    :param dict modeling_params: modeling choices of the current ML task
    :param dku_causal_model: fitted causal model used to predict the treatment effect(s)
    :param valid: processed validation data, indexed by "target", "treatment" and "TRAIN"
    :param out_folder_context: folder handed to the scorer for its outputs
    :param input_df_index: index of the validation dataframe prior to any processing
    :param bool is_regression: True casts the target to float; False casts it to int and checks it is usable
                               for classification
    :param propensity_model: optional classifier; its predict_proba output gives the treatment propensities
    :param treatment_map: None for a binary treatment, otherwise the mapping of the multi-valued treatment
    :return: a BinaryTreatmentCausalPredictionModelScorer when treatment_map is None,
             a MultiValueTreatmentCausalPredictionModelScorer otherwise
    """
    target_dtype = float if is_regression else int
    valid_y = valid["target"].astype(target_dtype)
    if not is_regression:
        check_test_set_ok_for_classification(valid_y)

    valid_t = valid["treatment"]
    valid_proba_t = None if propensity_model is None else propensity_model.predict_proba(valid["TRAIN"].as_np_array())

    if treatment_map is not None:
        # Multi-valued treatment: one predicted effect per treatment
        cates_by_treatment = get_predictions_from_causal_model_multi_treatment(dku_causal_model, valid)
        return MultiValueTreatmentCausalPredictionModelScorer(
            modeling_params=modeling_params, out_folder_context=out_folder_context,
            test_cates_dict=cates_by_treatment, test_y=valid_y, test_t=valid_t,
            test_proba_t=valid_proba_t, test_X=valid["TRAIN"], test_df_index=input_df_index,
            treatment_map=treatment_map)

    # Binary treatment: only the first returned value (the CATE) is needed
    cate, _, _, _ = get_predictions_from_causal_model_single_treatment(dku_causal_model, valid)
    return BinaryTreatmentCausalPredictionModelScorer(
        modeling_params=modeling_params, out_folder_context=out_folder_context,
        test_cate=cate, test_y=valid_y, test_t=valid_t,
        test_proba_t=valid_proba_t, test_X=valid["TRAIN"], test_df_index=input_df_index)


class CausalModelIntrinsicScorer(object):
    """
    Writes the intrinsic performance data of a causal model to "iperf.json", adding the raw variables
    importance when the model does not handle a multi-valued treatment.
    """

    # TODO: inherit from PredictionModelIntrinsicScorer once refactored ?
    def __init__(self, dku_causal_model, train_X, out_folder_context, initial_intrinsic_perf_data):
        self.dku_causal_model = dku_causal_model
        self.train_X = train_X
        self.out_folder_context = out_folder_context
        self.initial_intrinsic_perf_data = initial_intrinsic_perf_data

    def score_and_save(self):
        """Complete initial_intrinsic_perf_data (in place) and write it to iperf.json."""
        iperf = self.initial_intrinsic_perf_data
        # Variables importance is only properly defined for binary treatment models
        is_binary_treatment = not self.dku_causal_model.handles_multi_value_treatment()
        if is_binary_treatment:
            raw_importance = compute_causal_model_variables_importance(self.dku_causal_model, self.train_X)
            if raw_importance:
                iperf["rawImportance"] = raw_importance
        self.out_folder_context.write_json("iperf.json", iperf)


class AbstractPropensityScorer(object):
    """
    Base class for scoring the propensity model, i.e. the model predicting the treatment variable from the
    same input variables as the causal model.

    Its performance data is used to assess the "positivity hypothesis": control and treatment groups should
    properly overlap in the feature space. Subclasses additionally compute a statistical test of the
    hypothesis that the treatment variable is random.
    """

    def score(self):
        """Compute and return the propensity performance data. Must be implemented by subclasses."""
        raise NotImplementedError()


class BinaryTreatmentPropensityScorer(AbstractPropensityScorer):
    """
    Propensity scorer for a binary treatment.

    It scores the propensity model as a binary classifier of the treatment. It then tests whether the best
    accuracy across decision cuts beats a reference accuracy, max(targetAvg, 1 - targetAvg). The test is a
    one-sided binomial test, plus a Wilson confidence interval on that accuracy.
    """

    # Confidence level of the Wilson interval reported for the accuracy
    confidence_level = 0.95

    def __init__(self, modeling_params, out_folder_context, decisions_and_cuts, test_t, treatment_map):
        self.bc_scorer = BinaryClassificationModelScorer(modeling_params=modeling_params,
                                                         out_folder_context=out_folder_context,
                                                         decisions_and_cuts=decisions_and_cuts, test_y=test_t,
                                                         target_map=treatment_map)

    @staticmethod
    def _count_successes(accuracy, n_test):
        """
        Number of correctly classified test rows implied by an accuracy over n_test rows.

        Rounded to the nearest integer. accuracy * n_test is a float product that can land just below the
        exact count (e.g. 0.29 * 100 == 28.999999999999996). The binomial test either truncates such a value
        (losing one success) or rejects non-integer counts.
        """
        return int(round(accuracy * n_test))

    def score(self):
        """
        :return: the binary classification perf data of the propensity model, completed with a
                 "binomialTreatmentTest" entry
        """
        self.bc_scorer.score(with_assertions=False)
        res = self.bc_scorer.ret
        global_metrics = res["globalMetrics"]

        # Reference accuracy: the larger of targetAvg[0] and its complement
        # NOTE(review): presumably the accuracy of always predicting the majority treatment group — confirm
        target_avg = global_metrics["targetAvg"][0]
        accuracy_ref = max(target_avg, 1 - target_avg)

        # NOTE(review): testWeight is presumably the number of test rows, as no sample weights are given here
        n_test = global_metrics["testWeight"]
        # Best accuracy across all decision cuts
        accuracy = max(res["perCutData"]["accuracy"])
        n_successes = self._count_successes(accuracy, n_test)

        accuracy_lower, accuracy_upper = proportion.proportion_confint(
            n_successes, n_test, method="wilson", alpha=(1 - self.confidence_level)
        )
        # One-sided test: is the propensity model more accurate than the reference accuracy?
        p_value = binom_test(n_successes, int(round(n_test)), p=accuracy_ref, alternative='greater')

        res["binomialTreatmentTest"] = {
            "confidenceLevel": self.confidence_level,
            "nTest": n_test,
            "accuracyReference": accuracy_ref,
            "accuracy": accuracy,
            "accuracyLower": accuracy_lower,
            "accuracyUpper": accuracy_upper,
            "pValue": p_value,
        }
        return res


class MultiTreatmentPropensityScorer(AbstractPropensityScorer):
    """
    Propensity scorer for a multi-valued treatment.

    For each non-control treatment t, the rows are restricted to those that received either the control
    (encoded 0) or t. The propensity model is then scored on them as a binary classifier. Its two relevant
    probabilities are renormalized so that they are conditional on "treatment in {control, t}".
    """

    def __init__(self, modeling_params, out_folder_context, proba_t, test_t, treatment_map):
        self.modeling_params = modeling_params
        self.out_folder_context = out_folder_context
        self.proba_t = proba_t
        self.test_t = test_t
        self.treatment_map = treatment_map

    def score(self):
        """
        :return: dict mapping each non-control treatment name to its binary propensity perf data
        """
        perf_by_treatment = {}
        for treatment_name, treatment_code in self.treatment_map.items_except_control():
            # Rows that received either the control or this treatment
            in_pair = (self.test_t == 0) | (self.test_t == treatment_code)
            is_treated = self.test_t[in_pair] == treatment_code
            pair_target_map = {"0": 0, "1": 1}
            # P(treatment) conditionally on "treatment in {control, treatment}"
            proba_control = self.proba_t[in_pair, 0]
            proba_treated = self.proba_t[in_pair, treatment_code]
            proba_treated_cond = proba_treated / (proba_control + proba_treated)
            pair_probas = np.vstack([1 - proba_treated_cond, proba_treated_cond]).T
            pair_decisions_and_cuts = DecisionsAndCuts.from_probas(pair_probas, pair_target_map)
            pair_scorer = BinaryTreatmentPropensityScorer(self.modeling_params, self.out_folder_context,
                                                          pair_decisions_and_cuts, is_treated, pair_target_map)
            perf_by_treatment[treatment_name] = pair_scorer.score()
        return perf_by_treatment


class AbstractCausalPredictionModelScorer(PredictionModelScorer):
    """
    Common base of the causal scorers.

    Scoring first builds the predicted DataFrame, then computes the causal performance data, which is
    returned from self.ret.
    """

    def __init__(self, modeling_params, out_folder_context, test_y, test_t, test_X=None, test_df_index=None, test_proba_t=None):
        super(AbstractCausalPredictionModelScorer, self).__init__(modeling_params, test_X=test_X, assertions=None)
        self.out_folder_context = out_folder_context
        self.test_y = test_y
        self.test_t = test_t
        self.test_proba_t = test_proba_t
        self.test_df_index = test_df_index
        # Filled by subclasses through _initialize_propensity_scorer when propensities are available
        self.propensity_scorer = None

    def _do_score(self, with_assertions, treat_metrics_failure_as_error=True):
        # TODO use flag treat_metrics_failure_as_error to soft fail on causal metrics (sc-154901)
        # with_assertions is not used: causal scorers are built with assertions=None
        self._prepare_predicted_df()
        self._compute_perf_data()
        return self.ret

    def _initialize_propensity_scorer(self):
        """Create the propensity scorer suited to the treatment type; only called when test_proba_t is given."""
        raise NotImplementedError()

    def _compute_perf_data(self):
        """Compute the causal performance data into self.ret / self.perf_data."""
        raise NotImplementedError()

    def _prepare_predicted_df(self):
        """Build self.predicted_df, holding:
          - the predicted effect(s)
          - the propensities, when available
          - an index realigned on the original test set, so that dropped rows appear as missing values
        """
        raise NotImplementedError()


class BinaryTreatmentCausalPredictionModelScorer(AbstractCausalPredictionModelScorer):
    """Scores a causal model trained on a binary treatment (0 = control, 1 = treated)."""

    def __init__(self, modeling_params, out_folder_context, test_cate, test_y, test_t, test_proba_t=None, test_X=None, test_df_index=None):
        """
        :param dict modeling_params: storing the modeling choices of the current ML task (see PredictionModelingParams.java in backend)
        :param dataiku.base.folder_context.FolderContext|None out_folder_context: directory where predicted data and perf.json will be written
        :param ndarray test_cate: prediction of the conditional average treatment effect (CATE); flattened before use
        :param Series test_y: 1-dimensional array representing the ground truth outcome on the test set
                              (class index for classification, float for regression)
        :param Series test_t: 1-dimensional array representing the treatment value on the test set (0 = control, 1 = treated)
        :param ndarray | None test_proba_t: 2-dimensional prediction of proba of treatment if propensity is enabled, as the result of
                                            a predict_proba call from a sklearn classifier (column 1 = probability of being treated).
                                            If None, no propensity performance metrics will be computed
        :param dataiku.doctor.multiframe.MultiFrame | None test_X: The "TRAIN" value returned from processing the test dataset via pipeline.process().
                                                                   If None, no data will be written on disk (e.g. training recipes).
        :param test_df_index: Pandas index of the input dataframe of the original test set, prior to any processing
        """

        super(BinaryTreatmentCausalPredictionModelScorer, self).__init__(modeling_params, out_folder_context, test_y, test_t,
                                                                         test_X=test_X, test_df_index=test_df_index, test_proba_t=test_proba_t)

        self.test_cate = test_cate
        if test_proba_t is not None:
            self._initialize_propensity_scorer()

    def _prepare_predicted_df(self):
        # Flatten the CATE as _compute_perf_data does, so that a (n, 1) array still yields a valid column
        df = pd.DataFrame({PREDICTED_EFFECT_COLUMN: np.ravel(self.test_cate)})
        if self.test_proba_t is not None:
            # propensity scores have been computed
            # test_proba_t is the result of a predict_proba call from a sklearn classifier which encodes the
            # probability of the positive class into the 1-indexed column
            df["propensity"] = self.test_proba_t[:, 1]

        # Realign on the original test set index: rows dropped during processing end up with missing values
        df.index = self.test_X_index
        full = pd.DataFrame(index=self.test_df_index)
        df = full.join(df, how="left")
        self.predicted_df = df

    def _compute_perf_data(self):
        test_cate = self.test_cate.ravel()
        causal_metrics = _compute_causal_metrics_single_treatment(test_cate, self.test_t.to_numpy(), self.test_y.to_numpy(), self.test_proba_t,
                                                                  self.modeling_params["metrics"])
        self.ret["causalPerf"] = causal_metrics
        if self.propensity_scorer is not None:
            self.ret["propensityPerf"] = self.propensity_scorer.score()
        self.ret = remove_all_nan(self.ret)
        self.perf_data = self.ret

    def _initialize_propensity_scorer(self):
        treatment_target_map = {"0": 0, "1": 1}  # No remapping since treatment variable is preprocessed to mean 1 = treated and 0 = control
        decisions_and_cuts = DecisionsAndCuts.from_probas(self.test_proba_t, treatment_target_map)
        self.propensity_scorer = BinaryTreatmentPropensityScorer(self.modeling_params, self.out_folder_context, decisions_and_cuts,
                                                                 self.test_t, treatment_target_map)


class MultiValueTreatmentCausalPredictionModelScorer(AbstractCausalPredictionModelScorer):
    """Scores a causal model trained on a multi-valued treatment (control encoded as 0)."""

    def __init__(self, modeling_params, out_folder_context, test_cates_dict, test_y, test_t, test_proba_t=None, test_X=None, test_df_index=None, treatment_map=None):
        """
        :param dict modeling_params: storing the modeling choices of the current ML task (see PredictionModelingParams.java in backend)
        :param dataiku.base.folder_context.FolderContext|None out_folder_context: directory where predicted data and perf.json will be written
        :param dict test_cates_dict: treatment name -> prediction of the conditional average treatment effect (CATE) of that treatment
        :param Series test_y: 1-dimensional array representing the ground truth outcome on the test set
                              (class index for classification, float for regression)
        :param Series test_t: 1-dimensional array representing the encoded treatment value on the test set (0 = control)
        :param ndarray | None test_proba_t: 2-dimensional prediction of proba of treatment if propensity is enabled, as the result of
                                            a predict_proba call from a sklearn classifier (one column per encoded treatment value).
                                            If None, no propensity performance metrics will be computed
        :param dataiku.doctor.multiframe.MultiFrame | None test_X: The "TRAIN" value returned from processing the test dataset via pipeline.process().
                                                                   If None, no data will be written on disk (e.g. training recipes).
        :param test_df_index: Pandas index of the input dataframe of the original test set, prior to any processing
        :param treatment_map: mapping of treatment names to encoded treatment values; its items() include the
                              control, its items_except_control() exclude it
        """

        super(MultiValueTreatmentCausalPredictionModelScorer, self).__init__(modeling_params, out_folder_context, test_y, test_t,
                                                                             test_X=test_X, test_df_index=test_df_index,
                                                                             test_proba_t=test_proba_t)
        self.test_cates_dict = test_cates_dict
        self.treatment_map = treatment_map
        if test_proba_t is not None:
            self._initialize_propensity_scorer()

    def _prepare_predicted_df(self):
        # One predicted effect column per treatment
        df = pd.DataFrame({PREDICTED_EFFECT_COLUMN + "_" + t: cate for t, cate in self.test_cates_dict.items()})
        if self.test_proba_t is not None:
            # LogisticRegression.classes_ = np.unique(t) which returns the sorted unique elements of the array t,
            # so the column index of each treatment's propensity is its encoded value
            for k, v in self.treatment_map.items():
                df["propensity_" + k] = self.test_proba_t[:, v]
        # Realign on the original test set index: rows dropped during processing end up with missing values
        df.index = self.test_X_index
        full = pd.DataFrame(index=self.test_df_index)
        df = full.join(df, how="left")
        self.predicted_df = df

    def _compute_perf_data(self):
        # Uplift and Qini curves
        # NOTE(review): this replaces any content already in self.ret, whereas the binary scorer adds keys to it — confirm intended
        self.ret = _compute_causal_metrics_multi_treatment(self.test_y.to_numpy(), self.test_t.to_numpy(), self.test_cates_dict, self.test_proba_t,
                                                           self.treatment_map, self.modeling_params["metrics"])
        # The cross-treatment average ("All" method) is exposed as the main causal perf
        self.ret["causalPerf"] = self.ret["causalPerfMultiAll"]

        if self.propensity_scorer is not None:
            self.ret["propensityPerfMulti"] = self.propensity_scorer.score()
        self.ret = remove_all_nan(self.ret)
        self.perf_data = self.ret

    def _initialize_propensity_scorer(self):
        # LogisticRegression.classes_ = np.unique(t) which returns the sorted unique elements of the array t
        self.propensity_scorer = MultiTreatmentPropensityScorer(self.modeling_params, self.out_folder_context, self.test_proba_t,
                                                                self.test_t, self.treatment_map)


def _compute_cross_treatment_average_metrics(metrics, weights):
    """
    :param dict metrics: treatment -> metrics
    :param dict weights: treatment -> weight
    :return: float average metrics
    """
    assert metrics.keys() == weights.keys()
    sum = 0.
    total_weight = 0.
    for t in metrics.keys():
        if metrics[t] is None:
            continue
        sum += metrics[t] * weights[t]
        total_weight += weights[t]
    return sum / total_weight


def _compute_causal_metrics_single_treatment(test_cate, test_t, test_y, test_proba_t, metrics_params):
    """
    Compute the causal performance metrics of a binary-treatment model on a test set.

    The returned dict holds:
      - "testTotalPopSize": number of test rows (before any inverse propensity filtering)
      - "testTreatedPopSize": last value of the cumulative treated count (weighted when IPW is enabled)
      - "testATE": ATE derived from the cumulative sums (treated mean outcome minus control mean outcome,
        assuming compute_cumulative_sums returns cumulative counts and outcome sums — TODO confirm)
      - "upliftGainCurve" / "qiniCurve": lists of {"x": population %, "y": value}, at most 200 points
      - "cateHistogram": histogram of the predicted CATE
      - "netUpliftPoint": the net uplift point used (0.5 when absent from metrics_params)
      - "raw": AUUC, Qini and net uplift scores
      - "normalized": the raw scores divided by |testATE|

    :param ndarray test_cate: predicted CATE for each test row
    :param ndarray test_t: treatment indicator for each test row (rows equal to 0 are control, rows equal to 1 treated)
    :param ndarray test_y: observed outcome for each test row
    :param ndarray or NoneType test_proba_t: 2-D predicted treatment probabilities, column 0 for control and
                                             column 1 for treated; required when inverse propensity weighting is enabled
    :param dict metrics_params: metrics settings; "causalWeighting" and "netUpliftPoint" are read here
    :return: dict of causal metrics (see above)
    """
    with_ipw = metrics_params.get("causalWeighting") == doctor_constants.INVERSE_PROPENSITY
    causal_metrics = dict()
    test_n_samples = test_y.shape[0]
    # Recorded before the IPW filtering below
    causal_metrics["testTotalPopSize"] = test_n_samples
    # Uniform weights unless inverse propensity weighting is enabled
    sample_weights = np.ones(test_n_samples)
    if with_ipw:
        assert test_proba_t is not None, "Inverse propensity weighting evaluation metrics require propensity scores"
        # Weight each row by 1 / P(treatment actually received)
        treatment_0_mask = test_t == 0
        treatment_1_mask = test_t == 1
        sample_weights[treatment_0_mask] = 1 / test_proba_t[treatment_0_mask, 0]
        sample_weights[treatment_1_mask] = 1 / test_proba_t[treatment_1_mask, 1]
        # Drop rows whose weight is not finite (e.g. a propensity of exactly 0)
        mask = np.isfinite(sample_weights)
        sample_weights = sample_weights[mask]
        test_y = test_y[mask]
        test_n_samples = test_y.shape[0]
        test_t = test_t[mask]
        test_cate = test_cate[mask]
    cum_n_control, cum_n_treated, cum_y_control, cum_y_treated = compute_cumulative_sums(
        test_y,
        test_t,
        test_cate,
        sample_weights=sample_weights)
    causal_metrics["testTreatedPopSize"] = cum_n_treated[-1]
    causal_metrics["testATE"] = cum_y_treated[-1] / cum_n_treated[-1] - cum_y_control[-1] / cum_n_control[-1]
    # x coordinate of each curve point: percentage of the population, from 100/n up to 100
    xaxis = (100 * np.arange(1, test_n_samples + 1) / test_n_samples)
    uplift_curve = compute_uplift_curve(cum_n_control, cum_n_treated, cum_y_control, cum_y_treated)
    qini_curve = compute_qini_curve(cum_n_control, cum_n_treated, cum_y_control, cum_y_treated)
    max_nb_data_points = 200
    if test_n_samples > max_nb_data_points:
        # only keep max_nb_data_points data points
        # indices are evenly spread and end at the last point (n - 1); index 0 is not kept
        sampling_ratio = (test_n_samples - 1) / max_nb_data_points
        subsample_indices = (sampling_ratio * np.arange(1, max_nb_data_points + 1)).round().astype(int)
        xaxis = xaxis[subsample_indices]
        uplift_curve = uplift_curve[subsample_indices]
        qini_curve = qini_curve[subsample_indices]
    causal_metrics["upliftGainCurve"] = [{"x": x, "y": y} for x, y in zip(xaxis, uplift_curve)]
    causal_metrics["qiniCurve"] = [{"x": x, "y": y} for x, y in zip(xaxis, qini_curve)]
    # TODO: baseline with prediction when available ?
    # CATE histogram
    causal_metrics["cateHistogram"] = compute_cate_histograms(test_cate)
    # AUUC, QINI and net uplift
    if "netUpliftPoint" in metrics_params:
        net_uplift_point = metrics_params["netUpliftPoint"]
    else:
        net_uplift_point = 0.5
        logger.warning("Undefined net uplift point, using {} as default value".format(net_uplift_point))
    causal_metrics["netUpliftPoint"] = net_uplift_point
    # NOTE(review): these scores get no sample_weights. With IPW they use the filtered rows but not the
    # inverse propensity weights, unlike the curves and ATE above. Confirm this is intended.
    # dku_nonaninf presumably maps NaN/inf to None.
    causal_metrics["raw"] = {
        "auuc": dku_nonaninf(compute_auuc_score(test_y, test_t, test_cate, normalized=False)),
        "qini": dku_nonaninf(compute_qini_score(test_y, test_t, test_cate, normalized=False)),
        "netUplift": dku_nonaninf(
            compute_net_uplift_score(test_y, test_t, test_cate, net_uplift_point, normalized=False)),
    }
    # Normalized scores: raw score / |ATE|; None when the raw score is None
    causal_metrics["normalized"] = {
        metrics_name: dku_nonaninf(causal_metrics["raw"][metrics_name] / abs(causal_metrics["testATE"]))
        if causal_metrics["raw"][metrics_name] is not None else None
        for metrics_name in causal_metrics["raw"]
    }
    return causal_metrics


def _compute_causal_metrics_per_treatment(test_y, test_t, test_cates_dict, test_proba_t, treatment_map, metrics_params,
                                          with_ipw, treated_masks):
    """
    Compute single-treatment causal metrics for each non-control treatment t, on the control rows
    (test_t == 0) plus the treated rows selected by treated_masks[t].

    :param dict treated_masks: treatment name -> boolean ndarray selecting the rows treated with t to keep
    :return: tuple (treatment name -> single-treatment causal metrics,
                    treatment name -> number of kept treated rows)
    """
    per_treatment_metrics = {}
    weights = {}
    control_mask = test_t == 0
    for t, i in treatment_map.items_except_control():
        treated_mask = treated_masks[t]
        weights[t] = np.sum(treated_mask)
        mask = control_mask | treated_mask
        # Keep only the control and treatment-t propensity columns, so that column 1 means "treated"
        test_proba_t_single = test_proba_t[mask][:, [0, i]] if with_ipw else None
        per_treatment_metrics[t] = _compute_causal_metrics_single_treatment(
            test_cates_dict[t][mask], test_t[mask] == i, test_y[mask], test_proba_t_single, metrics_params)
    return per_treatment_metrics, weights


def _aggregate_causal_metrics_across_treatments(per_treatment_metrics, weights):
    """
    Aggregate per-treatment causal metrics into weighted averages across treatments.

    :param dict per_treatment_metrics: treatment name -> output of _compute_causal_metrics_single_treatment
    :param dict weights: treatment name -> weight of that treatment in the averages
    :return: dict with "raw" (weighted AUUC and Qini; net uplift is not aggregated), "testATE" (weighted ATE)
             and "normalized" (raw metrics divided by |weighted ATE|, consistently with the single-treatment metrics)
    """
    def weighted_average(extract):
        return _compute_cross_treatment_average_metrics(
            {t: extract(metrics) for t, metrics in per_treatment_metrics.items()}, weights)

    raw = {
        "auuc": weighted_average(lambda metrics: metrics["raw"]["auuc"]),
        "qini": weighted_average(lambda metrics: metrics["raw"]["qini"]),
        "netUplift": None,  # TODO: ???
    }
    weighted_ate = weighted_average(lambda metrics: metrics["testATE"])
    normalized = {
        metrics_name: dku_nonaninf(raw_value / abs(weighted_ate))
        if raw_value is not None and weighted_ate is not None else None
        for metrics_name, raw_value in raw.items()
    }
    return {"raw": raw, "testATE": weighted_ate, "normalized": normalized}


def _compute_causal_metrics_multi_treatment(test_y, test_t, test_cates_dict, test_proba_t, treatment_map, metrics_params):
    """
    Compute causal metrics for a multi-valued treatment by reducing it to one binary problem per treatment.

    Method 1 ("All"): for each treatment t, compute causal metrics on the union of the control group and the
    treatment=t group. Then average across treatments with weight = cardinal(treatment=t).

    Method 2 ("Restrict"): same, but only keep the treatment=t rows for which t is also the treatment with
    the highest predicted effect. Average with weight = cardinal(treatment=t & max(effects) = effects[t]).

    :param ndarray test_y: outcome on the test set
    :param ndarray test_t: encoded treatment on the test set, 0 being the control
    :param dict test_cates_dict: treatment name -> predicted CATE array
    :param ndarray or NoneType test_proba_t: predicted treatment probabilities, one column per encoded treatment;
                                             required when inverse propensity weighting is enabled
    :param TreatmentMap treatment_map: mapping of treatment names to encoded values
    :param dict metrics_params: metrics settings ("causalWeighting" is read here)
    :return: dict with keys "causalPerfMultiPerTreatmentAll", "causalPerfMultiAll",
             "causalPerfMultiPerTreatmentRestrict" and "causalPerfMultiRestrict"
    """
    with_ipw = metrics_params.get("causalWeighting") == doctor_constants.INVERSE_PROPENSITY
    if with_ipw:
        assert test_proba_t is not None, "Inverse propensity weighting evaluation metrics require propensity scores"

    causal_metrics = {}

    # Method 1: all rows treated with t
    treated_masks_all = {t: test_t == i for t, i in treatment_map.items_except_control()}
    per_treatment_all, weights_all = _compute_causal_metrics_per_treatment(
        test_y, test_t, test_cates_dict, test_proba_t, treatment_map, metrics_params, with_ipw, treated_masks_all)
    causal_metrics["causalPerfMultiPerTreatmentAll"] = per_treatment_all
    causal_metrics["causalPerfMultiAll"] = _aggregate_causal_metrics_across_treatments(per_treatment_all, weights_all)

    # Method 2: only rows treated with t for which t has the highest predicted effect
    test_cate_arr = np.vstack([test_cates_dict[t] for t, i in treatment_map.items_except_control() if i != 0])
    max_test_cate = np.max(test_cate_arr, axis=0)
    treated_masks_restrict = {
        t: (test_t == i) & (test_cates_dict[t] == max_test_cate) for t, i in treatment_map.items_except_control()
    }
    per_treatment_restrict, weights_restrict = _compute_causal_metrics_per_treatment(
        test_y, test_t, test_cates_dict, test_proba_t, treatment_map, metrics_params, with_ipw, treated_masks_restrict)
    causal_metrics["causalPerfMultiPerTreatmentRestrict"] = per_treatment_restrict
    causal_metrics["causalPerfMultiRestrict"] = _aggregate_causal_metrics_across_treatments(per_treatment_restrict,
                                                                                            weights_restrict)
    return causal_metrics
