import numpy as np
import logging

logger = logging.getLogger(__name__)


def compute_causal_model_variables_importance(dku_causal_model, X_mf):
    """
    Recipe for variables importance computation:
    - Get predicted CATE from a causal model
    - Train a random forest regressor in order to approximate the predictions
    - Extract the variables importance from the "surrogate" model

    :param dku_causal_model: causal model; importances are only computed when it exposes ``predict_effect``
    :param X_mf: features container exposing ``columns()`` and ``as_np_array()``
    :return: dict with two parallel lists under the keys "variables" and "importances", or None when the model
        has no ``predict_effect`` method
    """
    if not hasattr(dku_causal_model, "predict_effect"):
        # Multi-treatment: variables importance ill-defined
        # Checked before touching X_mf so that the features are not converted to an array for nothing.
        return None

    features = X_mf.columns()
    X_np = X_mf.as_np_array()
    pred = dku_causal_model.predict_effect(X_np)

    # Local import: scikit-learn is only needed when a surrogate model is actually trained
    from sklearn.ensemble import RandomForestRegressor

    logger.info("Training a RandomForestRegressor surrogate model to compute variables importance")
    # Fixed random_state so that the surrogate model is built with the same seed on every call
    surrogate_model = RandomForestRegressor(min_samples_leaf=10, max_depth=7, n_jobs=4, n_estimators=25, random_state=1337)
    surrogate_model.fit(X_np, pred)
    logger.info("Done")

    coefs = {"variables": [], "importances": []}
    for variable, importance in zip(features, surrogate_model.feature_importances_):
        # Variables with an undefined (NaN) importance are left out of the result
        if np.isnan(importance):
            continue
        coefs["variables"].append(variable)
        # Negative importances are clipped to 0
        coefs["importances"].append(importance if importance >= 0 else 0.0)
    return coefs


def compute_cumulative_sums(y, t_binary, cate, sample_weights=None):
    """
    Rank the rows by decreasing predicted CATE and accumulate, along that ranking, the (weighted) number of
    treated / control rows and the (weighted) sum of their outcomes.

    :param ndarray y: dtype float (regressions) or bool/int classification
    :param ndarray t_binary: dtype bool or int
    :param ndarray cate: dtype float
    :param ndarray sample_weights: dtype float; when None every row gets a weight of 1
    :return: 4 ndarray's (cum_n_control, cum_n_treated, cum_y_control, cum_y_treated)
    """
    # Row order going from the highest to the lowest predicted CATE
    descending_cate_order = np.argsort(cate)[::-1]
    outcome = y[descending_cate_order]
    is_treated = t_binary[descending_cate_order]
    weights = (
        np.ones(is_treated.shape[0])
        if sample_weights is None
        else sample_weights[descending_cate_order]
    )

    # Running (weighted) head counts of each group
    cum_n_treated = np.cumsum(is_treated * weights)
    is_control = 1 - is_treated
    cum_n_control = np.cumsum(is_control * weights)

    # Running (weighted) sums of the target for each group
    cum_y_treated = np.cumsum(is_treated * outcome * weights)
    cum_y_control = np.cumsum(is_control * outcome * weights)

    return cum_n_control, cum_n_treated, cum_y_control, cum_y_treated


def _compute_common_term_of_uplift_curves(cum_n_control, cum_n_treated, cum_y_control, cum_y_treated):
    """
    Compute, for each point of the curve, the difference between the running average outcome of the treated
    and the running average outcome of the control.

    Points where a running average is undefined (NaN, typically 0 / 0 because no row of that group has been
    accumulated yet) use the average of the whole group instead. Both groups are handled the same way: if a
    group is empty, its average is NaN and so is the result.

    :param ndarray cum_n_control: cumulated (weighted) number of control rows
    :param ndarray cum_n_treated: cumulated (weighted) number of treated rows
    :param ndarray cum_y_control: cumulated (weighted) outcome of control rows
    :param ndarray cum_y_treated: cumulated (weighted) outcome of treated rows
    :return: ndarray of the differences (treated minus control) of the running averages
    """
    # TODO @causal: update approach for multi treatment
    partial_avg_y_treated = _partial_average_with_population_fallback(cum_y_treated, cum_n_treated, "treated")
    partial_avg_y_control = _partial_average_with_population_fallback(cum_y_control, cum_n_control, "control")
    return partial_avg_y_treated - partial_avg_y_control


def _partial_average_with_population_fallback(cum_y, cum_n, group_name):
    """
    Running average cum_y / cum_n where undefined (NaN) points are replaced by the average on the whole group.

    :param ndarray cum_y: cumulated outcome of the group
    :param ndarray cum_n: cumulated size of the group
    :param str group_name: "treated" or "control", only used in the log message
    :return: new ndarray of running averages (the inputs are not modified)
    """
    # 0 / 0 is expected as long as no row of the group has been accumulated: do not emit numpy warnings for it
    with np.errstate(divide="ignore", invalid="ignore"):
        partial_avg = cum_y / cum_n
        population_avg = cum_y[-1] / cum_n[-1]

    nan_mask = np.isnan(partial_avg)
    nb_nan = np.count_nonzero(nan_mask)
    if nb_nan > 0:
        logger.info(
            "Average {} outcome is undefined for the first {} points of the curve. Using instead the average on the whole {} population: {}.".format(
                group_name, nb_nan, group_name, population_avg
            )
        )
        partial_avg[nan_mask] = population_avg
    return partial_avg


def compute_uplift_curve(cum_n_control, cum_n_treated, cum_y_control, cum_y_treated):
    """
    Build the uplift curve: at each point, the difference of the running average outcomes (treated minus
    control) multiplied by the share of the total population (treated + control) accumulated so far.

    :param ndarray cum_n_control: dtype float
    :param ndarray cum_n_treated: dtype float
    :param ndarray cum_y_control: dtype float
    :param ndarray cum_y_treated: dtype float
    :return: ndarray uplift_curve: the uplift curve y-values
    """
    running_avg_gap = _compute_common_term_of_uplift_curves(
        cum_n_control, cum_n_treated, cum_y_control, cum_y_treated
    )

    cum_population = cum_n_treated + cum_n_control
    total_population = cum_n_treated[-1] + cum_n_control[-1]
    scaled_gap = running_avg_gap * cum_population
    return scaled_gap / total_population


def compute_qini_curve(cum_n_control, cum_n_treated, cum_y_control, cum_y_treated):
    """
    Build the Qini curve: at each point, the difference of the running average outcomes (treated minus
    control) multiplied by the share of the treated population accumulated so far.

    :param ndarray cum_n_control: dtype float
    :param ndarray cum_n_treated: dtype float
    :param ndarray cum_y_control: dtype float
    :param ndarray cum_y_treated: dtype float
    :return: ndarray qini_curve: the Qini curve y-values
    """
    running_avg_gap = _compute_common_term_of_uplift_curves(
        cum_n_control, cum_n_treated, cum_y_control, cum_y_treated
    )

    total_treated = cum_n_treated[-1]
    scaled_gap = running_avg_gap * cum_n_treated
    return scaled_gap / total_treated


def compute_cate_histograms(cate):
    """
    Describe the distribution of the predicted CATE as a 25-bin histogram.

    :param ndarray cate: dtype float
    :return: list of dict distrib_cate, one dict per bin with the keys "bin_id", "bin_min", "bin_max" and "count"
    """
    bin_counts, bin_edges = np.histogram(cate, bins=25)
    # Consecutive edges delimit a bin: edge k is its lower bound, edge k + 1 its upper bound
    lower_edges = bin_edges[:-1]
    upper_edges = bin_edges[1:]
    return [
        {"bin_id": bin_id, "bin_min": lower, "bin_max": upper, "count": bin_count}
        for bin_id, (lower, upper, bin_count) in enumerate(zip(lower_edges, upper_edges, bin_counts))
    ]


def compute_auuc_score(y, t_binary, cate, normalized=True, sample_weights=None):
    """
    :param ndarray y: dtype float (regressions) or bool/int classification
    :param ndarray t_binary: dtype bool or int
    :param ndarray cate: dtype float
    :param bool normalized: whether to normalize the AUUC by the test ATE
    :param ndarray sample_weights: dtype float, optional. Rows whose weight is not finite (NaN or infinite) are ignored
    :return: float auuc_score: the (normalized or not) area under the uplift curve (AUUC), with the convention of subtracting the area under the line
        representing random treatment assignment.

    Note: when normalized, the score is divided by abs(test ATE) without any guard against a zero test ATE.
    """
    if sample_weights is not None:
        # A non-finite weight would propagate to every cumulative sum: drop the corresponding rows
        finite_weights_mask = np.isfinite(sample_weights)
        sample_weights = sample_weights[finite_weights_mask]
        y = y[finite_weights_mask]
        t_binary = t_binary[finite_weights_mask]
        cate = cate[finite_weights_mask]

    cum_n_control, cum_n_treated, cum_y_control, cum_y_treated = compute_cumulative_sums(
        y, t_binary, cate, sample_weights=sample_weights
    )
    uplift_curve = compute_uplift_curve(cum_n_control, cum_n_treated, cum_y_control, cum_y_treated)
    # The last point of the curve is used as the test ATE
    test_ate = uplift_curve[-1]

    # Area approximated by the mean of the curve values (vectorized, instead of the builtin sum over the array)
    area_under_uplift_curve = np.mean(uplift_curve)
    area_under_random_treatment_line = test_ate / 2
    auuc_score = area_under_uplift_curve - area_under_random_treatment_line

    if normalized:
        return auuc_score / abs(test_ate)
    return auuc_score


def compute_qini_score(y, t_binary, cate, normalized=True, sample_weights=None):
    """
    :param ndarray y: dtype float (regressions) or bool/int classification
    :param ndarray t_binary: dtype bool or int
    :param ndarray cate: dtype float
    :param bool normalized: whether to normalize the Qini score by the test ATE
    :param ndarray sample_weights: dtype float, optional. Rows whose weight is not finite (NaN or infinite) are ignored
    :return: float qini_score: the (normalized or not) Qini score, with the convention of subtracting the area under the line representing random
        treatment assignment.

    Note: when normalized, the score is divided by abs(test ATE) without any guard against a zero test ATE.
    """
    if sample_weights is not None:
        # Same policy as compute_auuc_score: a non-finite weight would propagate to every cumulative sum,
        # so the corresponding rows are dropped
        finite_weights_mask = np.isfinite(sample_weights)
        sample_weights = sample_weights[finite_weights_mask]
        y = y[finite_weights_mask]
        t_binary = t_binary[finite_weights_mask]
        cate = cate[finite_weights_mask]

    cum_n_control, cum_n_treated, cum_y_control, cum_y_treated = compute_cumulative_sums(
        y, t_binary, cate, sample_weights=sample_weights
    )
    qini_curve = compute_qini_curve(cum_n_control, cum_n_treated, cum_y_control, cum_y_treated)
    # The last point of the curve is used as the test ATE
    test_ate = qini_curve[-1]

    # Area approximated by the mean of the curve values (vectorized, instead of the builtin sum over the array)
    area_under_qini_curve = np.mean(qini_curve)
    area_under_random_treatment_line = test_ate / 2
    qini_score = area_under_qini_curve - area_under_random_treatment_line

    if normalized:
        return qini_score / abs(test_ate)
    return qini_score


def compute_net_uplift_score(y, t_binary, cate, level, normalized=True, sample_weights=None):
    """
    :param ndarray y: dtype float (regressions) or bool/int classification
    :param ndarray t_binary: dtype bool or int
    :param ndarray cate: dtype float
    :param float level: ranges from 0 to 1
    :param bool normalized: whether to normalize the net uplift score by the test ATE
    :param ndarray sample_weights: dtype float, optional. Rows whose weight is not finite (NaN or infinite) are ignored
    :return: float uplift_score: the (normalized or not) uplift score, with the convention of subtracting the random assignment uplift score.

    Note: when normalized, the score is divided by abs(test ATE) without any guard against a zero test ATE.
    """
    if sample_weights is not None:
        # Same policy as compute_auuc_score: a non-finite weight would propagate to every cumulative sum,
        # so the corresponding rows are dropped
        finite_weights_mask = np.isfinite(sample_weights)
        sample_weights = sample_weights[finite_weights_mask]
        y = y[finite_weights_mask]
        t_binary = t_binary[finite_weights_mask]
        cate = cate[finite_weights_mask]

    cum_n_control, cum_n_treated, cum_y_control, cum_y_treated = compute_cumulative_sums(
        y, t_binary, cate, sample_weights=sample_weights
    )
    uplift_curve = compute_uplift_curve(cum_n_control, cum_n_treated, cum_y_control, cum_y_treated)
    # The last point of the curve is used as the test ATE
    test_ate = uplift_curve[-1]

    # Position on the curve matching the requested level (0 -> first point, 1 -> last point)
    index = int((len(uplift_curve) - 1) * level)
    uplift_score = uplift_curve[index]
    random_assignment_uplift_score = level * test_ate
    net_uplift_score = uplift_score - random_assignment_uplift_score
    if normalized:
        return net_uplift_score / abs(test_ate)
    return net_uplift_score


def get_causal_scorer(modeling_params):
    """
    Instantiate the scorer matching the evaluation metric found in the modeling params.

    :param dict modeling_params: must hold a "metrics" entry with an "evaluationMetric" key ("AUUC", "QINI" or
        "NET_UPLIFT") and, for "NET_UPLIFT", an optional "netUpliftPoint" key (0.5 is used when it is missing)
    :return: a causal scorer instance
    :raises ValueError: when the evaluation metric is not one of the supported ones
    """
    metric_name = modeling_params["metrics"]["evaluationMetric"]

    if metric_name == "AUUC":
        return AUUCCausalScorer()
    if metric_name == "QINI":
        return QiniCausalScorer()
    if metric_name != "NET_UPLIFT":
        raise ValueError("Unsupported causal metrics: {}".format(metric_name))

    # Net uplift is the only metric needing an extra setting: the point of the curve at which it is measured
    if "netUpliftPoint" not in modeling_params["metrics"]:
        net_uplift_point = 0.5
        logger.warning("Undefined net uplift point, using {} as default value".format(net_uplift_point))
    else:
        net_uplift_point = modeling_params["metrics"]["netUpliftPoint"]
    return UpliftCausalScorer(net_uplift_point)


class CausalScorer(object):
    """
    Base class of the causal scorers: a callable that predicts the CATE with the given causal model and hands it
    over to the metric implemented by the subclass in `_score`.
    """

    def __call__(self, dku_causal_model, X, y_true, t_binary, sample_weights=None):
        """
        :param dku_causal_model: model exposing `predict_effect`
        :param X: features given as is to `predict_effect`
        :param y_true: observed outcomes
        :param t_binary: treatment indicator, at most 2 distinct values
        :param sample_weights: optional weights forwarded to the metric
        :return: whatever `_score` returns
        """
        nb_treatment_values = len(np.unique(t_binary))
        assert nb_treatment_values <= 2, "Binarized treatment expected in causal scorer"
        predicted_cate = dku_causal_model.predict_effect(X)
        return self._score(y_true, t_binary, predicted_cate, sample_weights=sample_weights)

    def _score(self, y_true, t_binary, cate, sample_weights=None):
        """Metric computation, to be implemented by subclasses."""
        raise NotImplementedError


class AUUCCausalScorer(CausalScorer):
    """Causal scorer relying on `compute_auuc_score` (called with its default `normalized` setting)."""

    def _score(self, y_true, t_binary, cate, sample_weights=None):
        auuc_score = compute_auuc_score(
            y_true,
            t_binary,
            cate,
            sample_weights=sample_weights,
        )
        return auuc_score


class QiniCausalScorer(CausalScorer):
    """Causal scorer relying on `compute_qini_score` (called with its default `normalized` setting)."""

    def _score(self, y_true, t_binary, cate, sample_weights=None):
        qini_score = compute_qini_score(
            y_true,
            t_binary,
            cate,
            sample_weights=sample_weights,
        )
        return qini_score


class UpliftCausalScorer(CausalScorer):
    """Causal scorer relying on `compute_net_uplift_score`, measured at the level given at construction."""

    def __init__(self, level):
        # Level forwarded to compute_net_uplift_score on each scoring
        self.level = level

    def _score(self, y_true, t_binary, cate, sample_weights=None):
        net_uplift_score = compute_net_uplift_score(
            y_true,
            t_binary,
            cate,
            self.level,
            sample_weights=sample_weights,
        )
        return net_uplift_score
