from dataiku.eda.types import Literal

from dataiku.eda.computations.computation import UnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.computations.univariate.test_distribution import TestDistribution
from dataiku.eda.distributions.distribution import Distribution
from dataiku.eda.exceptions import NoDataError
from dataiku.eda.types import FitDistributionModel, FitDistributionResultModel


class FitDistribution(UnivariateComputation):
    """Univariate computation that fits a distribution to one column.

    The fitted model is serialized into the result. When ``test`` is true,
    a ``TestDistribution`` computation is also run against the fitted model
    and its output is attached under the ``"test"`` key.
    """

    def __init__(self, column: str, distribution: Distribution, test: bool):
        """
        :param column: name of the column to analyse (passed to the base class)
        :param distribution: distribution whose ``fit`` method produces the model
        :param test: whether to also run ``TestDistribution`` on the fitted model
        """
        super().__init__(column)
        self.distribution = distribution
        self.test = test

    @staticmethod
    def get_type() -> Literal["fit_distribution"]:
        """Return the type tag identifying this computation."""
        return "fit_distribution"

    def describe(self) -> str:
        """Return a label of the form ``FitDistribution(<distribution class name>)``."""
        distribution_class_name = self.distribution.__class__.__name__
        return f"FitDistribution({distribution_class_name})"

    @staticmethod
    def build(params: FitDistributionModel) -> 'FitDistribution':
        """Construct an instance from its parameter mapping.

        Reads the ``column``, ``distribution`` and ``test`` entries. The
        ``distribution`` entry is converted with ``Distribution.build``.
        """
        column = params['column']
        distribution = Distribution.build(params['distribution'])
        run_test = params['test']
        return FitDistribution(column, distribution, run_test)

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> FitDistributionResultModel:
        """Fit the distribution on the values of ``idf.float_col_no_missing(column)``.

        :param idf: data frame holding the column
        :param ctx: computation context, forwarded to ``TestDistribution`` when testing
        :return: a dict with ``"type"``, ``"fit"`` (the serialized model) and,
            only when ``self.test`` is truthy, ``"test"``
        :raises NoDataError: if the extracted series is empty
        """
        values = idf.float_col_no_missing(self.column)
        # len() rather than bare truthiness: the series may be array-like,
        # where truth-testing is ambiguous.
        if not len(values):
            raise NoDataError()

        fitted_model = self.distribution.fit(values)
        result: FitDistributionResultModel = {
            "type": FitDistribution.get_type(),
            "fit": fitted_model.serialize(),
        }
        if not self.test:
            return result

        # The goodness-of-fit test reuses the model fitted above.
        result["test"] = TestDistribution(self.column, fitted_model).apply(idf, ctx)
        return result
