from dataiku.eda.types import Literal

import numpy as np
import scipy.stats as sps

from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.computations.univariate.abstract_pairwise import AbstractPairwiseUnivariateComputation
from dataiku.eda.exceptions import DegenerateCaseError
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.stats.multitest import multitest_correction
from dataiku.eda.types import PairwiseMoodTestModel, PairwiseMoodTestResultModel


# Pairwise unpaired Mood test
class PairwiseMoodTest(AbstractPairwiseUnivariateComputation):
    """Mood's median test (``scipy.stats.median_test``) on pairs of populations.

    Populations of ``self.column`` are formed by a ``Grouping``. The pairs to
    compare come from the inherited ``_iterate_pairs`` helper — presumably
    each-vs-each or one-vs-all depending on the ``oneVsAll`` flag; confirm in
    the base class. For every pair, the test statistic and p-value are
    collected. The p-values are then adjusted for multiple comparisons with
    ``multitest_correction``.
    """

    @staticmethod
    def get_type() -> Literal["pairwise_mood_test"]:
        """Return the identifier of this computation type."""
        return "pairwise_mood_test"

    @staticmethod
    def build(params: PairwiseMoodTestModel) -> 'PairwiseMoodTest':
        """Build the computation from its serialized parameters.

        :param params: model providing ``column``, ``grouping`` (itself built
            through ``Grouping.build``), ``oneVsAll`` and ``adjustmentMethod``.
        :returns: a new ``PairwiseMoodTest``; the arguments are passed
            positionally to the inherited constructor.
        """
        return PairwiseMoodTest(
            params['column'],
            Grouping.build(params["grouping"]),
            params['oneVsAll'],
            params['adjustmentMethod']
        )

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> PairwiseMoodTestResultModel:
        """Run the test on every pair of populations found in ``idf``.

        :param idf: data to split into populations.
        :param ctx: computation context (unused here).
        :returns: a result model with ``type``, ``statistics`` and ``pvalues``
            (one entry per pair, in iteration order), plus ``adjustedPvalues``
            (the p-values after multiple-testing correction).
        :raises DegenerateCaseError: if the test cannot be computed for at
            least one pair (constant values, empty sample, or any other input
            rejected by ``scipy.stats.median_test``).
        """
        samples, grouped_idfs, _ = self._compute_groups(idf)

        pvalues = []
        statistics = []

        for series_i, series_j in self._iterate_pairs(samples, grouped_idfs):
            statistic, pvalue = self._mood_test(series_i, series_j)
            pvalues.append(pvalue)
            statistics.append(statistic)

        adjusted_pvalues = multitest_correction(pvalues, self.adjustment_method)

        return {
            "type": self.get_type(),
            "statistics": statistics,
            "pvalues": pvalues,
            "adjustedPvalues": adjusted_pvalues,
        }

    def _mood_test(self, series_i, series_j):
        """Run Mood's median test on one pair; return ``(statistic, pvalue)``.

        Degenerate inputs are reported as ``DegenerateCaseError`` rather than
        the raw ``IndexError``/``ValueError`` they would otherwise produce.
        """
        # Explicit check for the most common degenerate case, keeping its
        # dedicated message. The length guard avoids an IndexError on an
        # empty first sample (that case is reported by scipy below instead).
        if len(series_i) > 0:
            some_value = series_i[0]
            if np.all(series_i == some_value) and np.all(series_j == some_value):
                raise DegenerateCaseError("All values of {} are equal for at least one pair of populations".format(self.column))

        try:
            # ties='ignore': values equal to the grand median are left out of
            # the contingency table.
            statistic, pvalue, _, _ = sps.median_test(series_i, series_j, ties='ignore')
        except ValueError as err:
            # scipy rejects other degenerate layouts with ValueError:
            # - an empty sample;
            # - all values on one side of the grand median;
            # - a sample whose values all equal the grand median (emptied by
            #   ties='ignore').
            # Surface them with the same error type as the check above.
            raise DegenerateCaseError(
                "Mood test cannot be computed on {} for at least one pair of populations: {}".format(self.column, err)
            ) from err

        return statistic, pvalue
