from dataiku.eda.types import Literal

from dataiku.eda.computations.computation import Computation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.types import GroupedComputationModel, GroupedComputationResultModel


class GroupedComputation(Computation):
    """Apply an inner computation separately to every group of a grouping.

    The result holds the serialized groups together with one inner result
    per group, in the order in which ``iter_groups()`` yields the groups.
    """

    def __init__(self, computation: Computation, grouping: Grouping):
        # Computation that is run once per group.
        self.computation = computation
        # Grouping used to split the input data frame into groups.
        self.grouping = grouping

    @staticmethod
    def get_type() -> Literal["grouped"]:
        """Return the type tag of this computation."""
        return "grouped"

    @staticmethod
    def build(params: GroupedComputationModel) -> 'GroupedComputation':
        """Create a GroupedComputation from its model.

        ``params['computation']`` is passed to ``Computation.build`` and
        ``params['grouping']`` to ``Grouping.build``, in that order.
        """
        inner_computation = Computation.build(params['computation'])
        grouping = Grouping.build(params['grouping'])
        return GroupedComputation(inner_computation, grouping)

    @staticmethod
    def _require_result_checking() -> bool:
        # apply() already checks the serialized groups itself, and each
        # per-group result goes through apply_safe(). Presumably this flag
        # tells the base class to skip a second check (verify in Computation).
        return False

    def describe(self) -> str:
        """Return a short label of the form ``GroupBy(<grouping description>)``."""
        return f"GroupBy({self.grouping.describe()})"

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> GroupedComputationResultModel:
        """Compute the groups of ``idf`` and run the inner computation on each.

        :param idf: data frame to split into groups.
        :param ctx: context. Group computation runs under a "ComputeGroups"
            sub-context. Each group runs under a bracketed sub-context named
            after its zero-based index.
        :return: a dict with the type tag, the checked serialized groups and
            the list of per-group results.
        """
        with ctx.sub("ComputeGroups"):
            groups = self.grouping.compute_groups(idf)

        per_group_results = []
        for group_number, group_df in enumerate(groups.iter_groups()):
            with ctx.sub(str(group_number), brackets=True) as group_ctx:
                outcome = self.computation.apply_safe(group_df, group_ctx)
            per_group_results.append(outcome)

        return {
            "type": GroupedComputation.get_type(),
            "groups": Computation._check_and_fix_result(groups.serialize()),
            # Already checked by apply_safe() on the inner computation.
            "results": per_group_results,
        }
