from dataiku.eda.types import Literal

import scipy.stats as sps

from dataiku.eda.computations.computation import UnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.exceptions import NotEnoughDataError
from dataiku.eda.types import ShapiroModel, ShapiroResultModel


class Shapiro(UnivariateComputation):
    """Univariate computation that runs scipy's Shapiro-Wilk test on one column."""

    @staticmethod
    def get_type() -> Literal["shapiro"]:
        """Return the type tag of this computation, also used in its result."""
        return "shapiro"

    @staticmethod
    def build(params: ShapiroModel) -> 'Shapiro':
        """Create a ``Shapiro`` computation from ``params``, reading its ``column`` entry."""
        column = params['column']
        return Shapiro(column)

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> ShapiroResultModel:
        """Run ``scipy.stats.shapiro`` on the configured column of ``idf``.

        The values come from ``idf.float_col_no_missing(self.column)``
        (presumably the column as floats with missing entries dropped --
        confirm against ``ImmutableDataFrame``). ``ctx`` is not used here.

        Returns a dict with the keys ``type``, ``statistic``, ``pvalue`` and
        ``warnings`` (a list of user-facing messages, possibly empty).

        Raises ``NotEnoughDataError`` when fewer than three values are
        available.
        """
        values = idf.float_col_no_missing(self.column)
        sample_size = len(values)

        if sample_size < 3:
            raise NotEnoughDataError("At least three values are required")

        # Scipy emits its own warning for large samples; mirror it in the
        # result so that the user actually gets to see it.
        if sample_size > 5000:
            warnings = ["p-value may not be accurate for N > 5000"]
        else:
            warnings = []

        w_statistic, p_value = sps.shapiro(values)

        return dict(
            type=self.get_type(),
            statistic=w_statistic,
            pvalue=p_value,
            warnings=warnings,
        )
