from enum import Enum
from dataiku.eda.types import Literal

import numpy as np
import scipy.stats as sps

from dataiku.eda.computations.univariate.abstract_multi_sample import AbstractMultiSampleUnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame, FloatVector
from dataiku.eda.exceptions import DegenerateCaseError
from dataiku.eda.exceptions import GroupsAreNotDisjoint
from dataiku.eda.exceptions import UnknownObjectType
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.stats.multitest import multitest_correction
from dataiku.eda.types import AlternativeHypothesis, PairwiseTTestModel, PairwiseTTestResultModel, PValueAdjustmentMethod, VarianceAssumption


class VarianceAssumptionEnum(str, Enum):
    """Variance assumptions recognised by the pairwise t-test.

    Members also derive from ``str``, so each one compares equal to its plain
    string value (e.g. ``VarianceAssumptionEnum.EQUAL == "EQUAL"``).
    """
    # Groups are assumed to share one variance
    EQUAL = "EQUAL"
    # Groups are not assumed to share one variance
    UNEQUAL = "UNEQUAL"


class UnknownVarianceAssumptionException(UnknownObjectType):
    """Error for a variance assumption that is not a ``VarianceAssumptionEnum`` value."""

    # Built once at class-creation time; lists every accepted value in
    # declaration order.
    DEFAULT_MESSAGE = "Variance assumption must be one of {}".format(
        ", ".join(member.value for member in VarianceAssumptionEnum)
    )

# Pairwise unpaired t-test
class PairwiseTTest(AbstractMultiSampleUnivariateComputation):
    """Unpaired t-tests run between pairs of groups of a single column.

    Every pair of groups is compared, or only the pairs involving the first
    group when ``one_vs_all`` is set. The raw p-values are additionally
    corrected for multiple testing using ``adjustment_method``.
    """

    def __init__(
        self,
        column: str,
        grouping: Grouping,
        one_vs_all: bool,
        variance_assumption: VarianceAssumption,
        adjustment_method: PValueAdjustmentMethod,
        alternative: AlternativeHypothesis,
        confidence_level: float,
    ):
        """Store the test settings; ``column`` and ``grouping`` go to the base class."""
        super().__init__(column, grouping)
        self.one_vs_all = one_vs_all
        self.variance_assumption: VarianceAssumption = variance_assumption
        self.adjustment_method: PValueAdjustmentMethod = adjustment_method
        self.alternative: AlternativeHypothesis = alternative
        self.confidence_level = confidence_level

    @staticmethod
    def get_type() -> Literal["pairwise_ttest"]:
        """Return the type tag of this computation."""
        return "pairwise_ttest"

    @staticmethod
    def build(params: PairwiseTTestModel) -> 'PairwiseTTest':
        """Create a ``PairwiseTTest`` from its serialized parameters."""
        return PairwiseTTest(
            column=params['column'],
            grouping=Grouping.build(params["grouping"]),
            one_vs_all=params['oneVsAll'],
            variance_assumption=params['varianceAssumption'],
            adjustment_method=params['adjustmentMethod'],
            alternative=params['alternative'],
            confidence_level=params['confidenceLevel'],
        )

    @staticmethod
    def _ttest_impl(se_dif: float, dof: float, series1: FloatVector, series2: FloatVector, alternative: AlternativeHypothesis, confidence_level: float):
        """Run one t-test given a precomputed standard error and degrees of freedom.

        :param se_dif: standard error of the difference of the two means
        :param dof: degrees of freedom of the t distribution
        :param series1: first sample
        :param series2: second sample
        :param alternative: "TWO_SIDED", "LOWER" or "GREATER"
        :param confidence_level: confidence level of the interval on the mean difference
        :return: tuple ``(statistic, pvalue, mean_dif, ci_lower, ci_upper)`` where
            ``mean_dif`` is mean(series1) - mean(series2); for a one-sided
            alternative the unbounded side of the interval is ``None``
        :raises UnknownObjectType: if ``alternative`` is not a supported value
        """
        delta = np.mean(series1) - np.mean(series2)
        t_stat = np.divide(delta, se_dif)

        if alternative == "TWO_SIDED":
            # H1: series 1 mean != series 2 mean
            p_value = sps.t.sf(np.abs(t_stat), dof) * 2
            margin = sps.t.ppf(1 - (1 - confidence_level) / 2.0, dof) * se_dif
            return t_stat, p_value, delta, delta - margin, delta + margin

        if alternative == "LOWER":
            # H1: series 1 mean < series 2 mean -> only an upper bound exists
            p_value = sps.t.cdf(t_stat, dof)
            upper_bound = delta + sps.t.ppf(confidence_level, dof) * se_dif
            return t_stat, p_value, delta, None, upper_bound

        if alternative == "GREATER":
            # H1: series 1 mean > series 2 mean -> only a lower bound exists
            p_value = sps.t.sf(t_stat, dof)
            # The quantile at (1 - confidence_level) is added as-is to obtain the lower bound
            lower_bound = delta + sps.t.ppf(1 - confidence_level, dof) * se_dif
            return t_stat, p_value, delta, lower_bound, None

        raise UnknownObjectType("Alternative must be one of TWO_SIDED, LOWER, GREATER")

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> PairwiseTTestResultModel:
        """Run the pairwise t-tests on ``idf`` and return the collected results.

        ``ctx`` is not used by this computation.

        :return: dict with the computation type plus one list per metric
            (statistics, raw and adjusted p-values, degrees of freedom, mean
            differences, confidence interval bounds), one entry per compared pair
        :raises DegenerateCaseError: if the groups do not carry enough distinct values
        :raises UnknownVarianceAssumptionException: if the variance assumption is unsupported
        :raises GroupsAreNotDisjoint: if two compared groups share rows
        :raises UnknownObjectType: if the alternative hypothesis is unsupported
        """
        # NOTE(review): _compute_groups, _pooled_sd_and_dof and _welch_se_and_dof
        # are not defined in this part of the file; presumably inherited.
        samples, grouped_idfs, _ = self._compute_groups(idf)

        def constant_flags():
            # One flag per group: True when all of its values equal the first one
            return [np.all(np.equal(values, values[0])) for values in samples]

        assume_equal_variance = self.variance_assumption == VarianceAssumptionEnum.EQUAL
        if assume_equal_variance:
            if np.all(constant_flags()):
                raise DegenerateCaseError("At least one population must have at least two different values for {}".format(self.column))

            # With equal variances, a single pooled SD and dof serve every pair
            pooled_sd, dof = self._pooled_sd_and_dof(samples)
        elif self.variance_assumption == VarianceAssumptionEnum.UNEQUAL:
            if np.any(constant_flags()):
                raise DegenerateCaseError("All populations must have at least two different values for {}".format(self.column))
        else:
            raise UnknownVarianceAssumptionException()

        statistics, pvalues, mean_diffs = [], [], []
        ci_lowers, ci_uppers, dofs = [], [], []

        for ref_idx, ref_idf in enumerate(grouped_idfs):
            # if 1 vs ALL, only the first group (the reference) is compared with the others
            if self.one_vs_all and ref_idx > 0:
                break

            for other_idx, other_idf in enumerate(grouped_idfs):
                # Visit each unordered pair only once
                if other_idx <= ref_idx:
                    continue

                shared_rows = ref_idf & other_idf
                if len(shared_rows) > 0:
                    # We should never end up here, this is likely a programming mistake from the caller of EDA compute
                    raise GroupsAreNotDisjoint()

                other_sample = samples[other_idx]
                ref_sample = samples[ref_idx]

                if assume_equal_variance:
                    se_dif = pooled_sd * np.sqrt(1.0 / len(other_sample) + 1.0 / len(ref_sample))
                else:
                    # Only UNEQUAL can reach this point (validated before the loop):
                    # standard error and dof are specific to each pair
                    se_dif, dof = self._welch_se_and_dof(other_sample, ref_sample)

                # The later group is passed first, so the mean difference is (later group - earlier group)
                outcome = self._ttest_impl(se_dif, dof, other_sample, ref_sample, self.alternative, self.confidence_level)

                collectors = (statistics, pvalues, mean_diffs, ci_lowers, ci_uppers)
                for collector, value in zip(collectors, outcome):
                    collector.append(value)
                dofs.append(dof)

        adjusted_pvalues = multitest_correction(pvalues, self.adjustment_method)

        # !! warning: pvalues are adjusted for multiple testing, but confidence intervals bounds are not
        # do not use the confidence intervals if the adjustment method is not None
        return {
            "type": self.get_type(),
            "statistics": statistics,
            "pvalues": pvalues,
            "adjustedPvalues": adjusted_pvalues,
            "dofs": dofs,
            "meanDiffs": mean_diffs,
            "ciLowers": ci_lowers,
            "ciUppers": ci_uppers
        }
