import numpy as np
import pandas as pd

from typing import Tuple, List

from dataiku.eda.computations.computation import UnivariateComputation
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame, FloatVector
from dataiku.eda.exceptions import DegenerateCaseError
from dataiku.eda.exceptions import GroupsAreNotDisjoint
from dataiku.eda.exceptions import NotEnoughDataError
from dataiku.eda.exceptions import NotEnoughGroupsError
from dataiku.eda.grouping.grouping import Grouping, GroupingResult


class AbstractMultiSampleUnivariateComputation(UnivariateComputation):
    """Base class for univariate computations that compare several samples of one column.

    The samples are obtained by splitting the input rows with ``grouping``.
    Subclasses use the helpers below to build the samples, validate them, and
    derive pooled / Welch statistics.
    """

    def __init__(self, column: str, grouping: Grouping):
        """
        :param column: name of the column whose values are compared across groups
        :param grouping: grouping used to split the rows into samples
        """
        super().__init__(column)
        self.grouping = grouping

    @staticmethod
    def _pooled_sd_and_dof(samples: List[FloatVector]) -> Tuple[float, int]:
        """Return the pooled standard deviation of ``samples`` and its degrees of freedom.

        pooled_sd = sqrt(sum_i((n_i - 1) * var_i) / (N - k)), where ``var_i`` is the
        unbiased (ddof=1) variance of sample i, ``N`` is the total number of values
        and ``k`` is the number of samples. The returned degrees of freedom are ``N - k``.

        A sample with a single value gets a variance of 0 and adds nothing to the
        weighted sum. Assumes every sample is non-empty, as ``_compute_groups``
        guarantees.

        NOTE(review): if every sample has exactly one value, ``N - k`` is 0 and the
        division below raises ZeroDivisionError.
        """
        samples_vars = [np.var(sample, ddof=1) if len(sample) > 1 else 0 for sample in samples]
        sum_of_weighted_vars = sum((len(sample) - 1) * var for sample, var in zip(samples, samples_vars))
        total_dof = sum(len(sample) for sample in samples) - len(samples)
        return np.sqrt(sum_of_weighted_vars / total_dof), total_dof

    @staticmethod
    def _welch_se_and_dof(sample1: FloatVector, sample2: FloatVector) -> Tuple[float, float]:
        """Return the standard error of the difference of means and the Welch degrees of freedom.

        se = sqrt(var1 / n1 + var2 / n2), using unbiased (ddof=1) variances. The
        degrees of freedom come from the Welch–Satterthwaite approximation.

        A sample with at most one value contributes 0 to the standard error. In
        that case the returned degrees of freedom are 0.
        """
        n1 = len(sample1)
        n2 = len(sample2)

        # Squared standard error of each sample mean (0 when the variance is undefined).
        s1 = np.var(sample1, ddof=1) / n1 if n1 > 1 else 0
        s2 = np.var(sample2, ddof=1) / n2 if n2 > 1 else 0

        se_dif = np.sqrt(s1 + s2)

        if n1 <= 1 or n2 <= 1:
            dof = 0
        else:
            # Welch–Satterthwaite approximation.
            # NOTE(review): if both samples have zero variance, this divides numpy
            # 0.0 by 0.0, which yields nan (with a RuntimeWarning) rather than raising.
            dof = (s1 + s2) ** 2 / (s1 ** 2 / (n1 - 1) + s2 ** 2 / (n2 - 1))
        return se_dif, dof

    def _compute_groups(self, idf: ImmutableDataFrame) -> Tuple[List[FloatVector], List[ImmutableDataFrame], GroupingResult]:
        """Split ``idf`` into one sample per group produced by ``self.grouping``.

        Rows whose value in ``self.column`` is not finite (NaN or +/-inf) are
        dropped from every group.

        :returns: ``(samples, grouped_idfs, computed_groups)``, where ``samples[i]``
            holds the values of ``self.column`` in ``grouped_idfs[i]`` (the filtered
            group) and ``computed_groups`` is the result of ``self.grouping.compute_groups``
        :raises NotEnoughGroupsError: if fewer than two groups are produced
        :raises NotEnoughDataError: if a group has no finite value for ``self.column``
        """
        computed_groups = self.grouping.compute_groups(idf)
        grouped_idfs = [gidf[np.isfinite(gidf.float_col(self.column))] for gidf in computed_groups.iter_groups()]

        if len(grouped_idfs) < 2:
            raise NotEnoughGroupsError("At least two independent populations are required")

        if any(len(group_idf) == 0 for group_idf in grouped_idfs):
            raise NotEnoughDataError("At least one population is empty or does not have any value for {}".format(self.column))

        samples = [group_idf.float_col(self.column) for group_idf in grouped_idfs]

        return samples, grouped_idfs, computed_groups

    def _check_disjoint_groups_and_not_degenerate_case(self, grouped_idfs: List[ImmutableDataFrame]) -> ImmutableDataFrame:
        """Check that the groups are disjoint and not degenerate, then return their union.

        :param grouped_idfs: the groups to check, typically those returned by
            ``_compute_groups``; must be non-empty (IndexError otherwise)
        :returns: the union of all groups
        :raises GroupsAreNotDisjoint: if the union has fewer rows than the groups combined
        :raises DegenerateCaseError: if all values of ``self.column`` in the union are equal
        """
        # Check that the groups are independent:
        # - Sample independence is assumed by ANOVA / Median Mood.
        #   If they are not, ANOVA / Median Mood can still be computed
        #   but the result is worthless from a statistical point of view
        # - If the groups are not disjoint then the assumption is clearly violated
        merged = grouped_idfs[0]
        for group_idf in grouped_idfs[1:]:
            # Explicit rebinding (rather than `|=`) so grouped_idfs[0] is never mutated in place.
            merged = merged | group_idf

        # Presumably `|` is a row union, so any row shared by two groups makes the
        # union smaller than the summed group sizes.
        summed_size = sum(len(group_idf) for group_idf in grouped_idfs)
        merged_size = len(merged)

        if summed_size != merged_size:
            raise GroupsAreNotDisjoint()

        # Make sure values are not all equal (degenerate case)
        values = merged.float_col(self.column)
        if np.all(values == values[0]):
            raise DegenerateCaseError("At least two different values of {} are required but all values are equal".format(self.column))

        return merged
