from dataiku.eda.types import Literal

import numpy as np
from statsmodels.stats.weightstats import DescrStatsW

from dataiku.doctor.utils import dku_nonaninf
from dataiku.eda.computations.computation import UnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.exceptions import NotEnoughDataError, UnknownObjectType
from dataiku.eda.types import AlternativeHypothesis, TTest1SampModel, TTest1SampResultModel


class TTest1Samp(UnivariateComputation):
    """One-sample Student's t-test on a single numerical column.

    Compares the mean of the column's non-missing values against
    ``hypothesized_mean`` using statsmodels' ``DescrStatsW`` and reports the
    t statistic, p-value, degrees of freedom, sample mean and a confidence
    interval for the mean at ``confidence_level``.
    """

    # Dataiku alternative-hypothesis names -> statsmodels ``alternative`` keywords.
    # Insertion order is significant: it is the order listed in the error
    # message raised for an unsupported alternative.
    _SM_ALTERNATIVES = {
        'TWO_SIDED': 'two-sided',
        'LOWER': 'smaller',
        'GREATER': 'larger'
    }

    def __init__(self, column: str, hypothesized_mean: float, alternative: AlternativeHypothesis, confidence_level: float):
        super().__init__(column)
        # Mean under the null hypothesis.
        self.hypothesized_mean = hypothesized_mean
        # Expected to be one of the keys of _SM_ALTERNATIVES; validated in apply().
        self.alternative = alternative
        # Converted to alpha = 1 - confidence_level for the confidence interval.
        self.confidence_level = confidence_level

    @staticmethod
    def get_type() -> Literal["ttest_1samp"]:
        """Return the type tag identifying this computation."""
        return "ttest_1samp"

    @staticmethod
    def build(params: TTest1SampModel) -> 'TTest1Samp':
        """Create a TTest1Samp from its serialized (camelCase) parameters."""
        return TTest1Samp(
            column=params['column'],
            hypothesized_mean=params['hypothesizedMean'],
            alternative=params['alternative'],
            confidence_level=params['confidenceLevel'],
        )

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> TTest1SampResultModel:
        """Run the one-sample t-test on ``self.column`` of ``idf``.

        :param idf: data frame holding the column to test.
        :param ctx: computation context (not used by this computation).
        :returns: result dict with keys ``type``, ``statistic``, ``pvalue``,
            ``dof``, ``mean``, ``ciLower`` and ``ciUpper``.
        :raises NotEnoughDataError: fewer than two values, or every value is
            identical (the standard error of the mean would be zero).
        :raises UnknownObjectType: ``self.alternative`` is not a supported name.
        """
        values = idf.float_col_no_missing(self.column)

        # The length test must come first: it guards the values[0] access.
        too_few = len(values) < 2
        if too_few or np.all(np.equal(values, values[0])):
            raise NotEnoughDataError("T-test requires at least two different values")

        sm_alternative = self._SM_ALTERNATIVES.get(self.alternative)
        if sm_alternative is None:
            raise UnknownObjectType("Alternative must be one of {}".format(", ".join(self._SM_ALTERNATIVES)))

        stats = DescrStatsW(values)
        t_stat, p_value, degrees_of_freedom = stats.ttest_mean(self.hypothesized_mean, alternative=sm_alternative)
        # For one-sided alternatives statsmodels returns an interval with one
        # infinite bound; dku_nonaninf presumably maps such non-finite values
        # to a JSON-safe placeholder (TODO confirm against its definition).
        lower_bound, upper_bound = stats.tconfint_mean(alpha=1 - self.confidence_level, alternative=sm_alternative)

        return {
            "type": self.get_type(),
            "statistic": t_stat,
            "pvalue": p_value,
            "dof": degrees_of_freedom,
            "mean": np.mean(values),
            "ciLower": dku_nonaninf(lower_bound),
            "ciUpper": dku_nonaninf(upper_bound)
        }
