from typing import Iterator
from dataiku.eda.types import Literal

from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.filtering.filter import Filter
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.grouping.grouping import GroupingResult
from dataiku.eda.types import SubsetGroupingModel, SubsetGroupingResultModel


class SubsetGrouping(Grouping):
    """Grouping that produces one group: the rows kept by a filter.

    The grouping is parameterised by a single ``Filter``. Computing the
    groups applies that filter to the input data frame and wraps the
    result in a ``SubsetGroupingResult``.
    """

    def __init__(self, subset_filter: Filter):
        # The filter selecting the rows that make up the single group.
        self.filter = subset_filter

    @staticmethod
    def get_type() -> Literal["subset"]:
        """Return the type tag used for this grouping in serialized models."""
        type_tag = "subset"
        return type_tag

    @staticmethod
    def build(params: SubsetGroupingModel) -> 'SubsetGrouping':
        """Create a ``SubsetGrouping`` from its serialized model.

        :param params: model whose ``"filter"`` entry is handed to
            ``Filter.build``.
        :returns: a new ``SubsetGrouping`` wrapping the built filter.
        """
        filter_model = params["filter"]
        built_filter = Filter.build(filter_model)
        return SubsetGrouping(built_filter)

    def describe(self) -> str:
        """Return the human-readable label of this grouping."""
        label = "Subset"
        return label

    def compute_groups(self, idf: ImmutableDataFrame) -> 'SubsetGroupingResult':
        """Apply the filter to ``idf`` and return the one-group result.

        :param idf: data frame to split.
        :returns: a ``SubsetGroupingResult`` holding the filter and the
            output of ``self.filter.apply(idf)``.
        """
        active_filter = self.filter
        subset = active_filter.apply(idf)
        return SubsetGroupingResult(active_filter, subset)


class SubsetGroupingResult(GroupingResult):
    """Result of a ``SubsetGrouping``: one group of filtered rows.

    Keeps the filter that produced the group, for serialization, and the
    filtered data frame itself, for iteration.
    """

    def __init__(self, subset_filter: Filter, idf: ImmutableDataFrame):
        # Filter that produced the group; only used when serializing.
        self.filter = subset_filter
        # The single group's rows, as returned by the filter.
        self.idf = idf

    def serialize(self) -> SubsetGroupingResultModel:
        """Return the serialized model: the type tag plus the serialized filter."""
        type_tag = SubsetGrouping.get_type()
        filter_model = self.filter.serialize()
        return {"type": type_tag, "filter": filter_model}

    def iter_groups(self) -> Iterator[ImmutableDataFrame]:
        """Yield the groups of this result.

        There is always exactly one group: the stored filtered data frame.
        """
        yield from (self.idf,)
