import logging

import numpy as np
import scipy.stats
from scipy.sparse import diags

from dataiku.doctor.utils import dku_nonaninf

logger = logging.getLogger(__name__)


class Denormalizer(object):
    """
    Post-processing on the coefficients of a linear model.
    Scales back coefficients, intercepts and std thereof to maintain homogeneity with the original variable.

    Each rescaler must expose `in_col` (name of the rescaled feature), `shift` and `inv_scale`.
    The formulas below correspond to a preprocessing of the form `scaled = (raw - shift) * inv_scale`.
    An `inv_scale` equal to 0.0 is treated as 1.0 in every method.
    """
    def __init__(self, rescalers):
        # feature name -> rescaler; features without an entry are left untouched by all methods
        self.scalings = {rescaler.in_col: rescaler for rescaler in rescalers}

    @staticmethod
    def _effective_inv_scale(scaler):
        """Return the inverse scale of `scaler`, replacing a zero value by the neutral 1.0."""
        return scaler.inv_scale if scaler.inv_scale != 0.0 else 1.0

    def denormalize_feature_value(self, feature_name, feature_value):
        """
        Map a rescaled feature value back to the original scale of the feature.
        :param feature_name: name of the feature
        :param feature_value: value of the feature after rescaling
        :return: `feature_value / inv_scale + shift`, or `feature_value` unchanged if the feature is not rescaled
        """
        scaler = self.scalings.get(feature_name, None)
        if scaler is None:
            return feature_value
        return (feature_value / self._effective_inv_scale(scaler)) + scaler.shift

    def denormalize_coef(self, feature_name, coef_value):
        """
        Express a coefficient (or its standard error) with respect to the original, non-rescaled feature.
        :param feature_name: name of the feature the coefficient applies to
        :param coef_value: coefficient fitted on the rescaled feature
        :return: `coef_value * inv_scale`, or `coef_value` unchanged if the feature is not rescaled
        """
        scaler = self.scalings.get(feature_name, None)
        if scaler is None:
            return coef_value
        return coef_value * self._effective_inv_scale(scaler)

    def denormalize_intercept(self, intercept_value, feature_names, coef_values):
        """
        Compute the intercept of the model expressed on the original, non-rescaled features.
        :param intercept_value: intercept fitted on the rescaled features
        :param feature_names: names of the features, aligned with `coef_values`
        :param coef_values: coefficients fitted on the rescaled features
        :return: `intercept_value` minus `coef * shift * inv_scale` for every rescaled feature
        """
        denormalized_intercept_value = intercept_value
        for feature_name, coef_value in zip(feature_names, coef_values):
            scaler = self.scalings.get(feature_name, None)
            if scaler is None:
                # no rescaling for this feature (e.g. dummy features): nothing to add
                continue
            denormalized_intercept_value -= coef_value * scaler.shift * self._effective_inv_scale(scaler)
        return denormalized_intercept_value

    def denormalize_intercept_stderr(self, intercept_stderr, feature_names, coef_stderr_values):
        """
        Compute the standard error of the denormalized intercept.
        :param intercept_stderr: standard error of the intercept fitted on the rescaled features
        :param feature_names: names of the features, aligned with `coef_stderr_values`
        :param coef_stderr_values: standard errors of the coefficients fitted on the rescaled features
        :return: square root of the sum of the squared error contributions
        """
        # NB: underlying zero-correlation between coefficients error hypothesis
        squared_res = intercept_stderr ** 2
        for feature_name, coef_stderr_value in zip(feature_names, coef_stderr_values):
            scaler = self.scalings.get(feature_name, None)
            if scaler is None:
                # no rescaling for this feature (e.g. dummy features): nothing to add
                continue
            # Same zero guard as denormalize_intercept, so that the error term matches
            # the `coef * shift * inv_scale` term that was subtracted from the intercept.
            squared_res += (scaler.shift * self._effective_inv_scale(scaler) * coef_stderr_value)**2
        return np.sqrt(squared_res)


def compute_coefs_if_available(clf, train_X, prepared_X, train_y, rescalers, perf, prediction_is_classification):
    """
    Publish the linear coefficients of the model, when they can be extracted, as `perf["lmCoefficients"]`.

    The estimator's `coef_` and `intercept_` attributes are extracted first; if either extraction
    yields None, the function returns without touching `perf`.
    :param clf: linear model
    :param train_X: the multiframe; its `columns()` give the feature names
    :param prepared_X: input array
    :param train_y: expectations array
    :param list rescalers: preprocessing steps to rescale the values
    :param dict perf: dictionary in which we must set the lmCoefficients attribute
    :param bool prediction_is_classification: false if prediction type is regression
    :return: None, the result is written into `perf`
    """
    feature_names = train_X.columns()

    raw_coefs = _extract_coef_from_clf(clf, feature_names, prediction_is_classification)
    if raw_coefs is None:
        return
    raw_intercept = _extract_intercept_from_clf(clf, prediction_is_classification)
    if raw_intercept is None:
        return

    logger.info("Computing regression coefs")
    logger.info("FEATURES %s CLF COEF %s" % (len(feature_names), len(raw_coefs)))
    logger.info("CLF Intercept: %s" % raw_intercept)

    lm_coefficients = {"variables": [], "coefs": []}

    # Coefficients as fitted, with their significance statistics when these can be computed
    # (stderr, tstat, pvalue, intercept stderr, intercept tstat, intercept pvalue)
    significance = _compute_lm_significance(clf, raw_coefs, raw_intercept, prepared_X, train_y,
                                            prediction_is_classification)
    _add_coefs_to_perf(lm_coefficients, raw_intercept, raw_coefs, feature_names, *significance)

    # Same coefficients expressed on the original (non-rescaled) features
    _add_rescaled_coefs_to_perf(lm_coefficients, prepared_X, rescalers)

    perf["lmCoefficients"] = lm_coefficients


def _extract_intercept_from_clf(clf, prediction_is_classification):
    """
    Read the `intercept_` attribute of the model and convert it to a float.

    :param clf: model whose `intercept_` attribute is read
    :param bool prediction_is_classification: when true, `intercept_` must be a list or a numpy array
                                              whose first dimension has length 1
    :return: the intercept as a float, or None (after logging the reason) when it is missing,
             has an unexpected type or shape, or cannot be converted to a number
    """
    base_message = "Not computing linear coefficients because "
    if not hasattr(clf, "intercept_"):
        logger.info(base_message + "attribute `intercept_` is not present")
        return None
    try:
        intercept = clf.intercept_
    except AttributeError:
        logger.info(base_message + "`intercept_` is present, but could not be retrieved")
        return None

    if prediction_is_classification:
        if not isinstance(intercept, (list, np.ndarray)):
            logger.info(base_message + "`intercept_` has wrong format: '{}'".format(type(intercept)))
            return None
        intercept = np.asarray(intercept)
        # Check ndim first: a 0-d array has an empty shape, so reading shape[0] would raise IndexError
        if intercept.ndim == 0 or intercept.shape[0] != 1:
            logger.info(base_message + "`intercept_` has the wrong shape")
            return None

    try:
        # intercept is either a float (regression) or a 1d array with only one element (classification)
        if isinstance(intercept, np.ndarray):
            # Unwrap explicitly: float() on an array with ndim > 0 is deprecated in recent numpy versions.
            # item() raises ValueError unless the array holds exactly one element.
            intercept = intercept.item()
        intercept = float(intercept)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are handled here; KeyboardInterrupt / SystemExit must propagate
        logger.info(base_message + "`intercept_` cannot be converted to a number")
        return None
    return intercept


def _extract_coef_from_clf(clf, features, prediction_is_classification):
    """
    Read the `coef_` attribute of the model as a numpy array aligned with `features`.

    :param clf: model whose `coef_` attribute is read
    :param features: feature names; the returned array must have the same length
    :param bool prediction_is_classification: when true, `coef_` must be a 2d array with a single row,
                                              and that row is returned
    :return: the coefficients as a numpy array, or None (after logging the reason) when they are missing,
             have an unexpected type or shape, or do not match the number of features
    """
    reason_prefix = "Not computing linear coefficients because "

    if not hasattr(clf, "coef_"):
        logger.info(reason_prefix + "attribute `coef_` is not present")
        return None
    try:
        raw_coefs = clf.coef_
    except AttributeError:
        logger.info(reason_prefix + "`coef_` is present, but could not be retrieved")
        return None

    if not isinstance(raw_coefs, (list, np.ndarray)):
        logger.info(reason_prefix + "`coef_` has wrong format: '{}'".format(type(raw_coefs)))
        return None

    # Lists are converted so that the shape checks below can be applied uniformly
    coef_array = np.array(raw_coefs) if isinstance(raw_coefs, list) else raw_coefs

    if prediction_is_classification:
        is_single_row = coef_array.ndim == 2 and coef_array.shape[0] == 1
        if not is_single_row:
            logger.info(reason_prefix + "`coef_` or `intercept_` have the wrong shape")
            return None
        coef_array = coef_array[0]

    if len(coef_array) != len(features):
        logger.info(reason_prefix + "misalignment between features and coefficients")
        return None
    return coef_array


def _compute_lm_significance(clf, coefs, intercept, prepared_X, train_y, prediction_is_classification):
    """
    Estimate standard errors, t-statistics and p-values for the coefficients and the intercept of a linear model.

    :param clf: fitted model; `predict` is called for regressions, `predict_proba` for classifications
    :param coefs: coefficients of the model, one per column of `prepared_X`
    :param float intercept: intercept of the model
    :param prepared_X: input array the predictions are computed on
    :param train_y: expectations array
    :param bool prediction_is_classification: false if prediction type is regression
    :return: tuple (stderr, tstat, pvalue, intercept_stderr, intercept_tstat, intercept_pvalue).
             All six entries are None when the statistics are not computed: sparse input, more than
             1000 features, non-positive degrees of freedom, no `predict_proba` on a classifier,
             non-invertible information matrix, or NaN in the coefficient p-values.
    """
    # Function-scope imports: guarantees that the scipy submodules used below are actually loaded
    import scipy.linalg
    import scipy.stats
    from scipy.sparse import issparse

    # The t_stat(coefX) is:  coefX / stddev(coefX)
    # The stddev of all coefficients is given by:
    #  a. sigma^2 * (X^T*X)^-1
    #     for regressions
    #     where sigma^2 = sum(square_errors) / degrees_of_freedom
    # b.  (X^T * diag(probas*(1-probas)) * X)^-1
    #     for binary classifications
    #     see e.g. https://stats200.stanford.edu/Lecture26.pdf
    # NB: These estimates of the variance of coefficients are assuming homoscedasticity of the data.
    #     In the heteroscedastic case:
    #       - the estimations of the variances are biased for both linear and logistic regression
    #       - coefficients of logistic regression are biased (still unbiased for linear regression)
    # => TODO: use an estimator of the variance of coefficients that is robust to heteroscedasticity (e.g. "Object-oriented Computation of Sandwich Estimators")
    default_value = (None, None, None, None, None, None)

    X = prepared_X
    y = train_y

    # The dense matrix algebra below does not support sparse inputs, whatever their storage format
    if issparse(X):
        logger.warning("Sparse input matrix: not computing linear model significance")
        return default_value

    # We refuse to invert too big matrices (we have to invert a coef*coef matrix)
    if X.shape[1] > 1000:
        logger.warning("Information matrix too large to compute linear model significance: {} features".format(X.shape[1]))
        return default_value

    d_freedom = float(X.shape[0]-X.shape[1]-1)
    if d_freedom <= 0:
        # The sample variance and the t distribution are undefined without positive degrees of freedom
        logger.warning("Not enough rows to compute linear model significance: {} rows for {} features".format(X.shape[0], X.shape[1]))
        return default_value

    if prediction_is_classification:
        if not hasattr(clf, "predict_proba"):
            logger.warning("Cannot get probabilities from classifier. No `predict_proba` method")
            return default_value
        # Second column of predict_proba, as a column vector
        predicted = np.matrix(clf.predict_proba(X)[:, 1]).T
    else:
        predicted = np.matrix(clf.predict(X)).T

    # Change X and Y into numpy matrices for easier operations, and add constant column to X
    X = np.hstack((np.ones((X.shape[0], 1)), np.matrix(X)))
    y = np.matrix(y).T

    coefs_with_intercept = np.hstack((intercept, coefs))
    logger.info("Coefs: %s" % coefs_with_intercept)

    # Sample variance (sigma^2 = sum(square_errors) / d_freedom )
    sigmasq = np.sum(np.square(predicted - y)) / d_freedom
    logger.info("Sample variance: %s" % sigmasq)

    if prediction_is_classification:
        probas = predicted.A[:, 0]
        diag = diags([probas * (1-probas)], [0])
        logger.info("Diagonal in information_matrix computation: %s" % str(diag))
        information_matrix = X.T * diag * X
    else:
        information_matrix = X.T * X

    # Quick check that we can inverse information_matrix
    if np.isclose(scipy.linalg.det(information_matrix), 0.0):
        logger.info("Singular variance matrix: information matrix not invertible.")
        return default_value

    logger.info("information matrix (X^T*X or X^T*Diag*X) shape: %s" % (information_matrix.shape,))

    # Compute the covariance matrix
    try:
        inverse_information = information_matrix.I
    except np.linalg.LinAlgError:
        # The determinant check above can let a numerically singular matrix through
        logger.info("Singular variance matrix: information matrix not invertible.")
        return default_value
    if prediction_is_classification:
        cvm = inverse_information
    else:
        cvm = sigmasq * inverse_information

    # Standard errors for the coefficients: the sqrt of the diagonal elements of the covariance matrix.
    # Index 0 is the constant column (intercept), the following ones are the features.
    std_errors = np.sqrt(cvm.diagonal().A[0])
    logger.info("Coefficient standard errors: %s" % std_errors)

    se = std_errors[1:]  # Remove the constant

    # T statistic for each beta. (coef / coef_stddev)
    base_t_stat = coefs/se

    # P-value for each beta: upper-tail probability P(T > |t|) of the t distribution.
    # The survival function equals 1 - cdf but does not round to 0.0 for very small p-values.
    # NOTE(review): a conventional two-sided p-value would be twice this quantity. Left unchanged because
    # _add_rescaled_coefs_to_perf derives the rescaled intercept p-value the same way - confirm the intent.
    betas_p_value = scipy.stats.t.sf(abs(base_t_stat), d_freedom)

    # Same for intercept
    ise = std_errors[0]
    itstat = intercept/ise
    ipval = scipy.stats.t.sf(abs(itstat), d_freedom)

    if np.isnan(betas_p_value).any():
        logger.info("NaN found in p-values")
        return default_value

    return se, base_t_stat, betas_p_value, ise, itstat, ipval


def _add_coefs_to_perf(perf_coefs, intercept, coefs, features, stderrs, tstats, pvalues, istderr, itstat, ipvalue):
    """
    Fill `perf_coefs` with the non-zero coefficients and the intercept statistics.

    Variables whose coefficient equals 0.0 are skipped. When `tstats` is not None, the standard error,
    t-statistic and p-value of each kept variable are stored as well (and an empty "rescaledStderr"
    list is created); otherwise only the variable names and coefficients are stored.
    :param dict perf_coefs: dictionary to fill; must already hold the "variables" and "coefs" lists
    :param intercept: intercept of the model
    :param coefs: coefficients, aligned with `features`
    :param features: feature names
    :param stderrs: standard errors of the coefficients (only read when `tstats` is not None)
    :param tstats: t-statistics of the coefficients, or None when significance is not available
    :param pvalues: p-values of the coefficients (only read when `tstats` is not None)
    :param istderr: standard error of the intercept
    :param itstat: t-statistic of the intercept
    :param ipvalue: p-value of the intercept
    """
    if tstats is None:
        for name, value in zip(features, coefs):
            if value == 0.0:
                continue
            logger.info("variable=%s coef=%s" % (name, value))
            perf_coefs["variables"].append(name)
            perf_coefs["coefs"].append(value)
    else:
        for list_key in ("stderr", "rescaledStderr", "tstat", "pvalue"):
            perf_coefs[list_key] = []

        for name, value, std_error, t_value, p_value in zip(features, coefs, stderrs, tstats, pvalues):
            if value == 0.0:
                continue
            logger.info("Variable=%s coef=%s" % (name, value))
            perf_coefs["variables"].append(name)
            perf_coefs["coefs"].append(value)
            perf_coefs["stderr"].append(std_error)
            perf_coefs["tstat"].append(t_value)
            perf_coefs["pvalue"].append(p_value)

    # Intercept statistics are always written, possibly as None
    intercept_entries = (
        ("interceptStderr", istderr),
        ("interceptTstat", itstat),
        ("interceptPvalue", ipvalue),
        ("interceptCoef", intercept),
    )
    for entry_key, entry_value in intercept_entries:
        perf_coefs[entry_key] = entry_value


def _add_rescaled_coefs_to_perf(perf_coefs, prepared_X, rescalers):
    """
    Add to `perf_coefs` the coefficients and intercept expressed on the original (non-rescaled) features.

    Always writes "rescaledCoefs" and "rescaledInterceptCoef". When `perf_coefs` holds a "stderr" entry,
    also writes "rescaledStderr", "rescaledInterceptStderr", "rescaledInterceptTstat" and
    "rescaledInterceptPvalue" (the latter is None when the t-statistic is None).
    :param dict perf_coefs: dictionary holding "variables", "coefs", "interceptCoef" and, when significance
                            was computed, "stderr" and "interceptStderr"
    :param prepared_X: input array, only used for its shape (degrees of freedom of the t distribution)
    :param list rescalers: preprocessing steps to rescale the values
    """
    denorm = Denormalizer(rescalers)
    variables = perf_coefs["variables"]
    coefs = perf_coefs["coefs"]

    perf_coefs["rescaledCoefs"] = [denorm.denormalize_coef(name, value) for name, value in zip(variables, coefs)]
    perf_coefs["rescaledInterceptCoef"] = denorm.denormalize_intercept(perf_coefs["interceptCoef"], variables, coefs)

    if "stderr" not in perf_coefs:
        return

    stderrs = perf_coefs["stderr"]
    perf_coefs["rescaledStderr"] = [denorm.denormalize_coef(name, value) for name, value in zip(variables, stderrs)]
    # The error propagation starts from the standard error of the intercept, not from the intercept itself
    perf_coefs["rescaledInterceptStderr"] = denorm.denormalize_intercept_stderr(perf_coefs["interceptStderr"], variables, stderrs)
    perf_coefs["rescaledInterceptTstat"] = dku_nonaninf(perf_coefs["rescaledInterceptCoef"] / perf_coefs["rescaledInterceptStderr"])

    if perf_coefs["rescaledInterceptTstat"] is not None:
        df = float(prepared_X.shape[0]-prepared_X.shape[1]-1)
        # Survival function: same quantity as 1 - cdf, without rounding to 0.0 for very small p-values
        perf_coefs["rescaledInterceptPvalue"] = scipy.stats.t.sf(abs(perf_coefs["rescaledInterceptTstat"]), df)
    else:
        perf_coefs["rescaledInterceptPvalue"] = None
