from dataiku.eda.types import Literal

import numpy as np
import scipy.stats as sps

from dataiku.eda.computations.computation import UnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.exceptions import NotEnoughDataError, UnknownObjectType
from dataiku.eda.types import AlternativeHypothesis, ZTest1SampModel, ZTest1SampResultModel


class ZTest1Samp(UnivariateComputation):
    """One-sample z-test of a column's mean against a hypothesized mean.

    The population standard deviation is supplied by the caller
    (``known_std_dev``); it is not estimated from the data. The test reports
    the z statistic, its p-value under the standard normal distribution, the
    sample mean, and a confidence interval for the mean whose shape (two-sided
    or one-sided) follows the chosen alternative hypothesis.
    """

    def __init__(self, column: str, hypothesized_mean: float, known_std_dev: float, alternative: AlternativeHypothesis, confidence_level: float):
        """
        :param column: name of the column whose mean is tested
        :param hypothesized_mean: mean under the null hypothesis
        :param known_std_dev: population standard deviation, taken as known.
            NOTE(review): not validated here; a zero value yields a non-finite
            statistic and a negative one flips its sign — confirm callers
            guarantee a positive value.
        :param alternative: one of "TWO_SIDED", "LOWER", "GREATER"
        :param confidence_level: confidence level of the reported interval,
            used as a probability passed to the normal quantile function
        """
        super().__init__(column)
        self.hypothesized_mean = hypothesized_mean
        self.known_std_dev = known_std_dev
        self.alternative = alternative
        self.confidence_level = confidence_level

    @staticmethod
    def get_type() -> Literal["ztest_1samp"]:
        """Return the identifier of this computation type."""
        return "ztest_1samp"

    @staticmethod
    def build(params: ZTest1SampModel) -> 'ZTest1Samp':
        """Construct a ZTest1Samp from its camelCase parameter mapping."""
        return ZTest1Samp(
            params['column'],
            params['hypothesizedMean'],
            params['knownStdDev'],
            params['alternative'],
            params['confidenceLevel']
        )

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> ZTest1SampResultModel:
        """Run the z-test on ``self.column`` of ``idf``.

        :returns: a dict with keys "type", "statistic" (z), "pvalue", "mean"
            (sample mean), "ciLower" and "ciUpper". For a one-sided
            alternative, the unbounded side of the interval is ``None``.
        :raises NotEnoughDataError: if the column has no usable values
        :raises UnknownObjectType: if ``self.alternative`` is not one of
            TWO_SIDED, LOWER, GREATER
        """
        # presumably returns the column as floats with missing values dropped
        # (per the helper's name) — TODO confirm
        series = idf.float_col_no_missing(self.column)
        count = len(series)

        if count < 1:
            raise NotEnoughDataError("Z-test requires at least one value")

        # Computed once and reused for the statistic, the interval and the result.
        mean = series.mean()
        sem = self.known_std_dev / np.sqrt(count)
        zstatistic = (mean - self.hypothesized_mean) / sem

        if self.alternative == "TWO_SIDED":
            # H1: true mean != hypothesized mean
            pvalue = sps.norm.sf(np.abs(zstatistic)) * 2

            # Split the non-covered probability evenly between both tails.
            zcrit = sps.norm.ppf(1 - (1 - self.confidence_level) / 2.0)
            ci_lower = mean - zcrit * sem
            ci_upper = mean + zcrit * sem
        elif self.alternative == "LOWER":
            # H1: true mean < hypothesized mean
            pvalue = sps.norm.cdf(zstatistic)

            # One-sided interval: only an upper bound is reported.
            zcrit = sps.norm.ppf(self.confidence_level)
            ci_lower = None
            ci_upper = mean + zcrit * sem
        elif self.alternative == "GREATER":
            # H1: true mean > hypothesized mean
            pvalue = sps.norm.sf(zstatistic)

            # One-sided interval: only a lower bound is reported. zcrit is
            # negative for confidence_level > 0.5, so adding it lowers the bound.
            zcrit = sps.norm.ppf(1 - self.confidence_level)
            ci_lower = mean + zcrit * sem
            ci_upper = None
        else:
            raise UnknownObjectType("Alternative must be one of TWO_SIDED, LOWER, GREATER")

        return {
            "type": self.get_type(),
            "statistic": zstatistic,
            "pvalue": pvalue,
            "mean": mean,
            "ciLower": ci_lower,
            "ciUpper": ci_upper
        }
