import logging
import numpy as np
from econml._cate_estimator import BaseCateEstimator
from econml.grf import CausalForest
from econml.metalearners import SLearner, TLearner, XLearner
from econml.utilities import broadcast_unit_treatments, transpose, check_inputs, inverse_onehot, check_models
import sklearn
from sklearn import clone
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils import check_array
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.calibration import CalibratedClassifierCV

from dataiku.base.utils import package_is_at_least
from dataiku.core import doctor_constants
from dataiku.doctor.causal.utils.metrics import get_causal_scorer
from dataiku.doctor.crossval.search_runner import CausalSearchRunner
from dataiku.doctor.prediction.common import get_input_hyperparameter_space, build_cv, get_selection_mode, safe_del
from dataiku.doctor.prediction.common import HyperparametersSpace, IntegerHyperparameterDimension, TrainableModel, \
    CategoricalHyperparameterDimension, TabularPredictionAlgorithm
from dataiku.doctor.prediction.classification_fit import CLASSIFICATION_ALGORITHMS
from dataiku.doctor.prediction.regression_fit import REGRESSION_ALGORITHMS
from dataiku.doctor.utils.skcompat import dku_calibrated_classifier_cv

logger = logging.getLogger(__name__)

"""
The classes below allow meta-learners to use the `predict_proba` method of classifiers.
This is required because econml calls the `predict` method regardless of the context (classification or regression).
We simply edit the code of each wrapped class to replace the calls to the `predict` method with calls to the
`predict_proba` method of the underlying classifiers.
In the `fit` method of DkuXLearnerClassifier, the asterisk argument syntax is removed for Python 2 compatibility.
Permalink for econml version 0.13.0: https://github.com/microsoft/EconML/blob/v0.13.0/econml/metalearners/_metalearners.py
"""


class DkuSLearnerClassifier(SLearner):
    """S-learner variant that derives effects from `predict_proba` instead of `predict`.

    Only `const_marginal_effect` is overridden; the single difference with the parent
    implementation is flagged by the "Upstream" comment below.
    """

    def const_marginal_effect(self, X=None):
        """Compute the effect of every non-control treatment relative to treatment 0.

        :param X: feature matrix; when None, a 1x1 array of zeros is used instead
        :return: ndarray of differences between the probability predicted under each
                 treatment and the probability predicted under treatment 0 (the
                 treatment-0 column itself is dropped from the result)
        """
        features = np.zeros((1, 1)) if X is None else X
        features = check_array(features)
        n_arms = self._d_t[0] + 1  # every treatment value, control included
        Xs, Ts = broadcast_unit_treatments(features, n_arms)
        feat_arr = np.concatenate((Xs, Ts), axis=1)
        # Upstream econml code:
        #   prediction = self.overall_model.predict(feat_arr).reshape((-1, self._d_t[0] + 1,) + self._d_y)
        # Column 1 of predict_proba is used here instead (presumably the positive class).
        proba = self.overall_model.predict_proba(feat_arr)[:, 1]
        prediction = proba.reshape((-1, n_arms) + self._d_y)
        if not self._d_y:
            # Repeat the treatment-0 column so it can be subtracted from every arm
            baseline = np.repeat(prediction[:, 0], n_arms).reshape(prediction.shape)
            return (prediction - baseline)[:, 1:]
        prediction = transpose(prediction, (0, 2, 1))
        baseline = np.repeat(prediction[:, :, 0], n_arms).reshape(prediction.shape)
        return (prediction - baseline)[:, :, 1:]


class DkuTLearnerClassifier(TLearner):
    """T-learner variant that derives effects from `predict_proba` instead of `predict`.

    Only `const_marginal_effect` is overridden; the difference with the parent
    implementation is flagged by the "Upstream" comment below.
    """

    def const_marginal_effect(self, X):
        """Compute the effect of every non-control treatment relative to treatment 0.

        :param X: feature matrix, validated with `check_array`
        :return: ndarray built from, for each treatment model, the difference between its
                 column-1 probability and the column-1 probability of `self.models[0]`
        """
        X = check_array(X)
        # The control model's prediction does not depend on the treatment arm:
        # compute it once instead of once per arm.
        control_proba = self.models[0].predict_proba(X)[:, 1]
        # Upstream econml code:
        #   taus.append(self.models[ind + 1].predict(X) - self.models[0].predict(X))
        taus = [
            self.models[ind + 1].predict_proba(X)[:, 1] - control_proba
            for ind in range(self._d_t[0])
        ]
        taus = np.column_stack(taus).reshape((-1,) + self._d_t + self._d_y)  # shape as of m*d_t*d_y
        if self._d_y:
            taus = transpose(taus, (0, 2, 1))  # shape as of m*d_y*d_t
        return taus


class DkuXLearnerClassifier(XLearner):
    """X-learner variant whose outcome models are queried with `predict_proba` instead of `predict`.

    Only `fit` is overridden. The lines that differ from the parent implementation are
    flagged with a "Replace:" comment showing the code they stand in for.
    """

    @BaseCateEstimator._wrap_fit
    # TODO: revert removal of asterisk argument (+ check for X against None) when support for Python 2 is dropped
    def fit(self, Y, T, X=None, inference=None):
        """Fit the outcome, CATE and propensity models.

        :param Y: outcomes; a 2D single-column array is flattened to 1D
        :param T: treatments; one-hot encoded (first category dropped) and then mapped back
                  with `inverse_onehot` (presumably to integer indices where 0 is the
                  first category, i.e. the control -- confirm against econml)
        :param X: features; mandatory despite the None default (see TODO above)
        :param inference: not used in this body (possibly consumed by the `_wrap_fit`
                          decorator -- confirm against econml)
        :raises ValueError: when X is None
        """
        if X is None:
            raise ValueError("fit method of DkuXLearnerClassifier requires the X keyword argument to be specified as a ndarray")
        # Check inputs
        Y, T, X, _ = check_inputs(Y, T, X, multi_output_T=False)
        if Y.ndim == 2 and Y.shape[1] == 1:
            Y = Y.flatten()
        categories = self.categories
        if categories != 'auto':
            categories = [categories]  # OneHotEncoder expects a 2D array with features per column
        # The dense-output keyword of OneHotEncoder is `sparse_output` from sklearn 1.2, `sparse` before
        if package_is_at_least(sklearn, '1.2'):
            self.transformer = OneHotEncoder(categories=categories, sparse_output=False, drop='first')
        else:
            self.transformer = OneHotEncoder(categories=categories, sparse=False, drop='first')
        T = self.transformer.fit_transform(T.reshape(-1, 1))
        # Number of encoded columns, i.e. number of treatment values minus the dropped first one
        self._d_t = T.shape[1:]
        T = inverse_onehot(T)
        self.models = check_models(self.models, self._d_t[0] + 1)
        if self.cate_models is None:
            self.cate_models = [clone(model, safe=False) for model in self.models]
        else:
            self.cate_models = check_models(self.cate_models, self._d_t[0] + 1)
        self.propensity_models = []
        self.cate_treated_models = []
        self.cate_controls_models = []

        # Estimate response function
        for ind in range(self._d_t[0] + 1):
            self.models[ind].fit(X[T == ind], Y[T == ind])
        # One (treated CATE model, control CATE model, propensity model) triple per non-control treatment
        for ind in range(self._d_t[0]):
            self.cate_treated_models.append(clone(self.cate_models[ind + 1], safe=False))
            self.cate_controls_models.append(clone(self.cate_models[0], safe=False))
            self.propensity_models.append(clone(self.propensity_model, safe=False))
            # Replace:
            # imputed_effect_on_controls = self.models[ind + 1].predict(X[T == 0]) - Y[T == 0]
            # imputed_effect_on_treated = Y[T == ind + 1] - self.models[0].predict(X[T == ind + 1])
            imputed_effect_on_controls = self.models[ind + 1].predict_proba(X[T == 0])[:,1] - Y[T == 0]
            imputed_effect_on_treated = Y[T == ind + 1] - self.models[0].predict_proba(X[T == ind + 1])[:,1]
            self.cate_controls_models[ind].fit(X[T == 0], imputed_effect_on_controls)
            self.cate_treated_models[ind].fit(X[T == ind + 1], imputed_effect_on_treated)
            # The propensity model only sees the rows of treatment 0 and of the current treatment
            X_concat = np.concatenate((X[T == 0], X[T == ind + 1]), axis=0)
            T_concat = np.concatenate((T[T == 0], T[T == ind + 1]), axis=0)
            self.propensity_models[ind].fit(X_concat, T_concat)


def get_predictions_from_causal_model_single_treatment(dku_causal_model, multiframe):
    """
    Predict the CATE of a single-treatment causal model on the "TRAIN" block of a MultiFrame.

    :param DkuCausalModel dku_causal_model: a trained causal model able to predict CATE
    :param MultiFrame multiframe: MultiFrame with preprocessed data
    :return: 4-tuple (cate, prediction_actual, prediction_do_control, prediction_do_treatment).
             Only `cate` (the output of `predict_effect`) is currently populated, the three
             other elements are always None.
             TODO @causal: update method or signature, see sc-104815
    """
    features = multiframe["TRAIN"].as_np_array()
    # TODO @causal: multi-treatment will yield non-binary treatment Series
    cate = dku_causal_model.predict_effect(features)
    # TODO @causal: get CATE decomposition into predictions with do(treatment) and do(control) whenever possible, see sc-104815
    prediction_do_control = None
    prediction_do_treatment = None
    # TODO @causal: reconcile predictions based on actual treatment value, see sc-104815
    prediction_actual = None
    if not (prediction_do_control is None or prediction_do_treatment is None):
        # Not reachable until the decomposition above is implemented
        prediction_actual = np.where(multiframe["treatment"], prediction_do_treatment, prediction_do_control)

    return cate, prediction_actual, prediction_do_control, prediction_do_treatment


def get_predictions_from_causal_model_multi_treatment(dku_causal_model, multiframe):
    """
    Predict the CATEs of a multi-treatment causal model on the "TRAIN" block of a MultiFrame.

    :param DkuMultiTreatmentCausalModelsWrapper dku_causal_model: a trained causal model able to predict CATE
    :param MultiFrame multiframe: MultiFrame with preprocessed data
    :return: the output of `dku_causal_model.predict_all_effects_as_dict`, returned as is
             TODO @causal: update method or signature, see sc-104815
    """
    train_features = multiframe["TRAIN"].as_np_array()
    return dku_causal_model.predict_all_effects_as_dict(train_features)


class AbstractDkuCausalModel(object):
    """Base interface for causal model wrappers: both methods must be overridden."""

    def handles_multi_value_treatment(self):
        """Not implemented here; subclasses are expected to override it."""
        raise NotImplementedError

    def fit(self, X, y, treatment, **params):
        """Not implemented here; subclasses are expected to override it."""
        raise NotImplementedError


class DkuCausalModel(AbstractDkuCausalModel):
    """
    Common wrapper giving a single API to the different causal learning algorithms.

    Note: assumes binary treatment

    Note: do not rename as this class is pickled during training and unpickled during scoring (recipe and API)
    """

    def __init__(self, model):
        """Store the wrapped model; `_version` is set to 1."""
        self._version = 1
        self._model = model

    def handles_multi_value_treatment(self):
        """Always False for this wrapper."""
        return False

    def predict_effect(self, X):
        """Not implemented here; subclasses are expected to override it."""
        raise NotImplementedError


class DkuMetaLearnerCausalModel(DkuCausalModel):
    """Causal model wrapper around a meta-learner.

    Note: do not rename as this class is pickled during training and unpickled during scoring (recipe and API)
    """

    def fit(self, X, y, treatment, **params):
        """Fit the wrapped model, which takes (y, treatment) positionally and X as keyword."""
        self._model.fit(y, treatment, X=X, **params)

    def predict_effect(self, X):
        """Return the flattened output of the wrapped model's `effect` method."""
        effects = self._model.effect(X)
        return effects.ravel()

    def get_trained_clf_base_learner(self, meta_learner):
        """Return the base learner exposed for the given meta-learner type.

        :param meta_learner: one of the S/T/X learner constants of `doctor_constants`
        :return: `overall_model` for an S-learner, `models[0]` for a T- or X-learner
        :raises ValueError: for any other value
        """
        if meta_learner == doctor_constants.S_LEARNER:
            return self._model.overall_model
        if meta_learner not in {doctor_constants.T_LEARNER, doctor_constants.X_LEARNER}:
            raise ValueError("Unknown meta-learner: {}".format(meta_learner))
        return self._model.models[0]


class DkuCausalForestCausalModel(DkuCausalModel):
    """Causal model wrapper around a causal forest.

    Note: do not rename as this class is pickled during training and unpickled during scoring (recipe and API)
    """

    def fit(self, X, y, treatment, **params):
        """Fit the wrapped model, which takes its arguments in (X, treatment, y) order."""
        self._model.fit(X, treatment, y, **params)

    def predict_effect(self, X):
        """Return the flattened output of the wrapped model's `predict` method."""
        predictions = self._model.predict(X)
        return predictions.ravel()

    def get_params(self):
        """Return the parameters of the wrapped model."""
        return self._model.get_params()


class DkuMultiTreatmentCausalModelsWrapper(AbstractDkuCausalModel):
    """
    Common wrapper giving a single API to the different causal learning algorithms,
    holding one single-treatment causal model per non-control treatment value.

    Note: assumes multi-value treatment

    Note: do not rename as this class is pickled during training and unpickled during scoring (recipe and API)
    """

    def __init__(self, models, treatment_map):
        """
        :param models: mapping of treatment value -> single-treatment causal model
        :param treatment_map: object exposing `items_except_control()`, which yields
                              (treatment value, treatment index) pairs
        """
        self._version = 1
        self._treatment_map = treatment_map
        self._models = models

    def handles_multi_value_treatment(self):
        """Always True for this wrapper."""
        return True

    def fit(self, X, y, treatments, **params):
        """Fit each per-treatment model on the rows of treatment index 0 and of its own treatment.

        The treatment passed to each model is the boolean `treatments != 0` restricted to those rows.
        """
        is_control = treatments == 0
        for treatment_value, treatment_index in self._treatment_map.items_except_control():
            rows = is_control | (treatments == treatment_index)
            single_model = self._models[treatment_value]
            single_model.fit(X[rows], y[rows], (treatments[rows] != 0), **params)

    def predict_all_effects_as_dict(self, X):
        """Return a dict of treatment value -> effect predicted by the matching model."""
        return {
            treatment_value: self._models[treatment_value].predict_effect(X)
            for treatment_value, _ in self._treatment_map.items_except_control()
        }

    def predict_single_effect(self, X, treatment):
        """Return the effect predicted by the model registered for `treatment`."""
        return self._models[treatment].predict_effect(X)

    def _any_single_treatment_model(self):
        """Return the first model of the (treatment -> model) dict."""
        return next(iter(self._models.values()))

    def get_trained_clf_base_learner(self, meta_learner):
        """Return the base learner of one of the wrapped models (used only for meta-learners)."""
        # Any model will do: they all have the same actual params
        return self._any_single_treatment_model().get_trained_clf_base_learner(meta_learner)

    def get_params(self):
        """Return the params of one of the wrapped models (used only for causal forests)."""
        # Any model will do: their params are all the same
        return self._any_single_treatment_model().get_params()


class CausalLearning(object):
    """Factory turning the causal modeling params into algorithms and causal model wrappers.

    Reads `modeling_params["causal_method"]` and, only when the method is a meta-learner,
    `modeling_params["meta_learner"]`.
    """

    def __init__(self, modeling_params):
        self.method = modeling_params["causal_method"]
        self.meta_learner = modeling_params["meta_learner"] if self.method == doctor_constants.META_LEARNER else None

    def get_causal_algorithm(self, algorithm_name, is_classifier):
        """Return the prediction algorithm object matching the configured causal method.

        :param algorithm_name: key of the base learner in CLASSIFICATION_ALGORITHMS or
                               REGRESSION_ALGORITHMS (only read for meta-learners)
        :param is_classifier: selects which of the two registries is searched
        :raises Exception: when the base learner is not registered
        :raises ValueError: when the causal method is not supported
        """
        if self.method == doctor_constants.META_LEARNER:
            if is_classifier:
                if algorithm_name not in CLASSIFICATION_ALGORITHMS:
                    raise Exception("Classification algorithm not available in Python: %s" % algorithm_name)
                base_learner = CLASSIFICATION_ALGORITHMS[algorithm_name]
            else:
                if algorithm_name not in REGRESSION_ALGORITHMS:
                    raise Exception("Regression algorithm not available in Python: %s" % algorithm_name)
                base_learner = REGRESSION_ALGORITHMS[algorithm_name]
            return MetaLearnerCausalPredictionAlgorithm(base_learner, self.meta_learner)
        elif self.method == "CAUSAL_FOREST":
            return CausalForestAlgorithm()

        raise ValueError("Unsupported causal learning method: {}".format(self.method))

    def get_dku_causal_model(self, estimator, is_classifier, treatment_map=None):
        """Build the causal model wrapper: single-treatment when no treatment map is given, multi-treatment otherwise."""
        # Identity check on purpose: `== None` would call the treatment map's own __eq__
        if treatment_map is None:
            # Binary treatment
            return self.get_dku_causal_model_single_treatment(estimator, is_classifier)
        # Multi-value treatment
        return self.get_dku_causal_model_multi_treatment(estimator, is_classifier, treatment_map)

    def get_dku_causal_model_multi_treatment(self, estimator, is_classifier, treatment_map):
        """Build one single-treatment causal model (on a clone of `estimator`) per non-control treatment value."""
        models = {}
        for t, _ in treatment_map.items_except_control():
            models[t] = self.get_dku_causal_model_single_treatment(clone(estimator), is_classifier)
        return DkuMultiTreatmentCausalModelsWrapper(models, treatment_map)

    def get_dku_causal_model_single_treatment(self, estimator, is_classifier):
        """Wrap `estimator` into the causal model matching the configured method.

        For meta-learners, the classifier-specific `Dku*LearnerClassifier` classes are used when
        `is_classifier` is true, the plain econml classes otherwise. For causal forests, a new
        `CausalForest` is built from the params of `estimator`.

        :raises ValueError: on an unknown meta-learner or causal method
        """
        if self.method == doctor_constants.META_LEARNER:
            if self.meta_learner == doctor_constants.S_LEARNER:
                logger.info("Using a S-learner")
                econml_model = DkuSLearnerClassifier(overall_model=estimator) if is_classifier else SLearner(overall_model=estimator)
            elif self.meta_learner == doctor_constants.T_LEARNER:
                logger.info("Using a T-learner")
                econml_model = DkuTLearnerClassifier(models=estimator) if is_classifier else TLearner(models=estimator)
            elif self.meta_learner == doctor_constants.X_LEARNER:
                logger.info("Using a X-learner")
                # TODO: use a non-generic model (LinearRegression below) for CATE
                from sklearn.linear_model import LinearRegression
                econml_model = DkuXLearnerClassifier(models=estimator, cate_models=LinearRegression()) if is_classifier else XLearner(models=estimator, cate_models=LinearRegression())
            else:
                raise ValueError("Unknown meta-learner: {}".format(self.meta_learner))
            return DkuMetaLearnerCausalModel(econml_model)
        elif self.method == "CAUSAL_FOREST":
            full_params = estimator.get_params()
            subforest_size = full_params["subforest_size"]
            # NOTE(review): the rounding below yields 0 when n_estimators < subforest_size;
            # confirm the hyperparameter space prevents that case.
            params = {
                "n_estimators": subforest_size * (full_params["n_estimators"] // subforest_size),  # n_estimators must be divisible by subforest_size
                "max_depth": full_params["max_depth"],
                "min_samples_leaf": full_params["min_samples_leaf"],
                "criterion": full_params["criterion"],
            }
            econml_model = CausalForest(random_state=1337,
                                        verbose=2,
                                        honest=full_params["honest"],
                                        n_jobs=full_params["n_jobs"])
            econml_model.set_params(**params)
            return DkuCausalForestCausalModel(econml_model)
        raise ValueError("Unknown causal learning method: {}".format(self.method))


class CausalForestAlgorithm(TabularPredictionAlgorithm):
    """Prediction algorithm entry for the econml `CausalForest`."""

    algorithm = "CAUSAL_FOREST"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the trainable model: a `CausalForest` plus its hyperparameter space.

        `honest` and `n_jobs` are read directly from `input_hp_space`; `modeling_params` and
        `core_params` are not used here.
        """
        searched_dimensions = {
            "n_estimators": IntegerHyperparameterDimension,
            "max_depth": IntegerHyperparameterDimension,
            "min_samples_leaf": IntegerHyperparameterDimension,
            "criterion": CategoricalHyperparameterDimension,
        }
        hp_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=searched_dimensions,
        )
        forest = CausalForest(random_state=1337,
                              verbose=2,
                              honest=input_hp_space["honest"],
                              n_jobs=input_hp_space["n_jobs"])
        return TrainableModel(forest, hyperparameters_space=hp_space)

    def actual_params(self, ret, dku_causal_model, fit_params):
        """Write the trained forest params into `ret` (mutated in place) and return them wrapped.

        :return: dict with `ret` under "resolved" and an empty dict under "other"
        """
        safe_del(ret, "causal_forest_grid")
        params = dku_causal_model.get_params()
        logger.info("Causal Forest Params are %s " % params)

        copied_names = ("n_estimators", "max_depth", "min_samples_leaf", "criterion", "honest")
        forest_params = {name: params[name] for name in copied_names}
        selection_mode = get_selection_mode(params["max_features"])
        forest_params["selection_mode"] = selection_mode
        ret["causal_forest_params"] = forest_params

        # max_features is reported under a key that depends on the selection mode
        if selection_mode == "number":
            forest_params["max_features"] = params["max_features"]
        elif selection_mode == "prop":
            forest_params["max_feature_prop"] = params["max_features"]

        return {"resolved": ret, "other": {}}

    def get_search_runner(self, core_params, modeling_params, column_labels=None, model_folder_context=None, treatment_map=None):
        """Assemble the `CausalSearchRunner` used to search hyperparameters for this algorithm."""
        logger.info(
            "Create causal model from params: {} for algorithm {}".format(
                modeling_params, self.algorithm)
        )

        hp_space_definition = get_input_hyperparameter_space(modeling_params, self.algorithm)
        trainable_model = self.model_from_params(hp_space_definition, modeling_params, core_params)
        trainable_model.set_column_labels(column_labels)

        grid_search_params = modeling_params.get("grid_search_params", {})
        seed = grid_search_params.get("seed", 0)
        trainable_model.hyperparameters_space.set_random_state(seed)

        search_settings = self.get_search_settings(grid_search_params, trainable_model)

        is_binary_classification = core_params[doctor_constants.PREDICTION_TYPE] == doctor_constants.CAUSAL_BINARY_CLASSIFICATION
        cv = build_cv(modeling_params, column_labels, is_binary_classification)

        return CausalSearchRunner(
            trainable_model=trainable_model,
            cv=cv,
            search_settings=search_settings,
            causal_scorer=get_causal_scorer(modeling_params),
            causal_learning=CausalLearning(modeling_params),
            model_folder_context=model_folder_context,
            evaluation_metric=modeling_params["metrics"]["evaluationMetric"],
            propensity_settings=modeling_params["propensityModeling"],
            treatment_map=treatment_map,
        )


class MetaLearnerCausalPredictionAlgorithm(TabularPredictionAlgorithm):
    """Prediction algorithm entry for a meta-learner: model building is delegated to the base algorithm."""

    def __init__(self, base_algorithm, meta_learner):
        """
        :param base_algorithm: algorithm object of the base learner
        :param meta_learner: meta-learner type, may be None
        """
        self.meta_learner = meta_learner
        self.base_algorithm = base_algorithm

    def get_search_runner(self, core_params, modeling_params, column_labels=None, model_folder_context=None, treatment_map=None):
        """Assemble the `CausalSearchRunner` used to search hyperparameters of the base learner."""
        meta_learner_suffix = "" if self.meta_learner is None else " | meta-learner {}".format(self.meta_learner)
        logger.info(
            "Create causal model from params: {} for algorithm {}{}".format(
                modeling_params, self.base_algorithm.algorithm, meta_learner_suffix
            )
        )

        # Same name as the module-level import, kept local as in the original code
        from dataiku.doctor.crossval.search_runner import CausalSearchRunner
        hp_space_definition = get_input_hyperparameter_space(modeling_params, self.base_algorithm.algorithm)
        trainable_model = self.model_from_params(hp_space_definition, modeling_params, core_params)
        trainable_model.set_column_labels(column_labels)

        grid_search_params = modeling_params.get("grid_search_params", {})
        seed = grid_search_params.get("seed", 0)
        trainable_model.hyperparameters_space.set_random_state(seed)

        search_settings = self.get_search_settings(grid_search_params, trainable_model)

        is_binary_classification = core_params[doctor_constants.PREDICTION_TYPE] == doctor_constants.CAUSAL_BINARY_CLASSIFICATION
        cv = build_cv(modeling_params, column_labels, is_binary_classification)

        return CausalSearchRunner(
            trainable_model=trainable_model,
            cv=cv,
            search_settings=search_settings,
            causal_scorer=get_causal_scorer(modeling_params),
            causal_learning=CausalLearning(modeling_params),
            model_folder_context=model_folder_context,
            evaluation_metric=modeling_params["metrics"]["evaluationMetric"],
            propensity_settings=modeling_params["propensityModeling"],
            treatment_map=treatment_map,
        )

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Delegate to the base algorithm's `model_from_params`."""
        base = self.base_algorithm
        return base.model_from_params(input_hp_space, modeling_params, core_params)

    def actual_params(self, ret, dku_causal_model, fit_params):
        """Delegate to the base algorithm's `actual_params`, applied to the trained base learner of the causal model."""
        base_learner = dku_causal_model.get_trained_clf_base_learner(self.meta_learner)
        return self.base_algorithm.actual_params(ret, base_learner, fit_params)


def train_propensity_model(X, t, calibrate_proba, calibration_data_ratio):
    """
    Train a logistic regression predicting the treatment `t` from `X`, optionally calibrated.

    :param X: features, indexed with arrays of row positions when calibrating
    :param t: treatment values, indexed the same way
    :param calibrate_proba: when truthy, part of the data is held out to fit an isotonic calibration
    :param calibration_data_ratio: passed as `test_size` of the stratified split, i.e. the share
                                   of data reserved for calibration
    :return: the fitted LogisticRegression, or the fitted calibrated classifier wrapping it
    """
    base_model = LogisticRegression()
    if not calibrate_proba:
        base_model.fit(X, t)
        return base_model

    # Learn the probability calibration function on data not used to train the classifier,
    # keeping the same ratio of treated/control in both splits.
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=calibration_data_ratio, random_state=1234)
    # NB: the splitter produces a single split
    train_indices, calibrate_indices = next(splitter.split(X, t))
    base_model.fit(X[train_indices], t[train_indices])
    calibrated_model = dku_calibrated_classifier_cv(base_model, cv="prefit", method="isotonic")
    calibrated_model.fit(X[calibrate_indices], t[calibrate_indices])
    return calibrated_model
