from typing import List, Tuple
from dataiku.eda.types import Literal

import numpy as np
import scipy.stats as sps

from dataiku.eda.computations.computation import Computation
from dataiku.eda.computations.univariate.abstract_multi_sample import AbstractMultiSampleUnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame, FloatVector
from dataiku.eda.exceptions import NotEnoughDataError, DegenerateCaseError
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.types import OneWayAnovaModel, OneWayAnovaResultModel


class OneWayANOVA(AbstractMultiSampleUnivariateComputation):
    """One-way ANOVA of ``column`` across the populations produced by ``grouping``.

    Besides the F statistic and p-value returned by ``scipy.stats.f_oneway``,
    the result carries, for every population, its mean and a two-sided
    Student t confidence interval. The interval is built from the standard
    deviation and degrees of freedom returned by the inherited
    ``_pooled_sd_and_dof`` helper.
    """

    def __init__(self, column: str, grouping: Grouping, confidence_level: float):
        super().__init__(column, grouping)
        # Two-sided confidence level of the per-population mean intervals
        # (used as a fraction: the t quantile is taken at 1 - (1 - level) / 2).
        self.confidence_level = confidence_level

    @staticmethod
    def get_type() -> Literal["one_way_anova"]:
        """Return the identifier of this computation type."""
        return "one_way_anova"

    @staticmethod
    def build(params: OneWayAnovaModel) -> 'OneWayANOVA':
        """Instantiate the computation from its serialized parameters."""
        return OneWayANOVA(
            column=params['column'],
            grouping=Grouping.build(params["grouping"]),
            confidence_level=params['confidenceLevel'],
        )

    @staticmethod
    def _ci_mean(pooled_sd: float, dof: int, series: FloatVector, confidence_level: float) -> Tuple[float, float, float]:
        """Compute the mean of ``series`` with a two-sided t confidence interval.

        :param pooled_sd: standard deviation used to derive the standard error
        :param dof: degrees of freedom of the Student t distribution
        :param series: values of one population
        :param confidence_level: two-sided confidence level of the interval
        :return: ``(mean, ci_lower, ci_upper)``
        """
        alpha = 1 - confidence_level
        t_quantile = sps.t.ppf(1 - alpha / 2.0, dof)

        center = np.mean(series)
        std_error = pooled_sd * np.sqrt(1.0 / len(series))
        half_width = t_quantile * std_error

        return center, center - half_width, center + half_width

    def _check_constant_samples_different_values(self, samples: List[FloatVector]) -> None:
        """Reject inputs where every sample is constant but the constants differ.

        scipy's ``f_oneway`` documentation singles out the case where "all
        values in each group are identical, and there exist at least two
        groups with different values"; such input is reported here as a
        degenerate case instead of being passed to scipy.

        :raises DegenerateCaseError: if every sample holds a single distinct
            value and at least two samples hold different values
        """
        distinct_constants = set()
        for sample in samples:
            sample_values = np.unique(sample)
            if len(sample_values) != 1:
                # A non-constant (or empty) sample rules this edge case out.
                return
            distinct_constants.add(sample_values[0])

        if len(distinct_constants) < 2:
            return

        raise DegenerateCaseError(
            ("All values of {} in each population are identical, and there "
             "exist at least two populations with different values").format(self.column)
        )

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> OneWayAnovaResultModel:
        """Run the one-way ANOVA on ``idf`` and return the serialized result.

        :raises NotEnoughDataError: if the total number of values does not
            exceed the number of populations
        :raises DegenerateCaseError: if every population is constant and at
            least two populations differ (the inherited
            ``_check_disjoint_groups_and_not_degenerate_case`` may reject other
            inputs too; see its implementation)
        """
        samples, grouped_idfs, computed_groups = self._compute_groups(idf)
        merged = self._check_disjoint_groups_and_not_degenerate_case(grouped_idfs)

        # ANOVA needs strictly more values overall than there are populations.
        population_count = len(grouped_idfs)
        if len(merged) <= population_count:
            raise NotEnoughDataError("ANOVA requires (at least) one more value than the number of populations")

        # Refuse the input scipy treats as a special case before calling it.
        self._check_constant_samples_different_values(samples)

        statistic, pvalue = sps.f_oneway(*samples)

        # One (mean, lower, upper) triple per population, all sharing the
        # same standard deviation and degrees of freedom.
        pooled_sd, dof = self._pooled_sd_and_dof(samples)
        intervals = [
            self._ci_mean(pooled_sd, dof, sample, self.confidence_level)
            for sample in samples
        ]

        return {
            "type": self.get_type(),
            "statistic": statistic,
            "pvalue": pvalue,
            "groups": Computation._check_and_fix_result(computed_groups.serialize()),
            "dof": dof,
            "means": [interval[0] for interval in intervals],
            "ciLowers": [interval[1] for interval in intervals],
            "ciUppers": [interval[2] for interval in intervals],
        }
