from random import Random
from typing import Iterator
from dataiku.eda.types import Literal

from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.grouping.grouping import GroupingResult
from dataiku.eda.types import SubsampledGroupingModel, SubsampledGroupingResultModel


class SubsampledGrouping(Grouping):
    """Grouping that produces one group capped at ``max_rows`` rows.

    A data frame with at most ``max_rows`` rows is passed through as-is.
    A larger one is reduced to ``max_rows`` rows drawn with a random
    generator seeded by ``seed``, so the same seed and input length always
    select the same positions.
    """

    def __init__(self, max_rows: int, seed: int):
        """Store the row cap and the seed used for the random draw."""
        self.seed = seed
        self.max_rows = max_rows

    @staticmethod
    def get_type() -> Literal["subsampled"]:
        """Return the type tag identifying this grouping."""
        return "subsampled"

    def describe(self) -> str:
        """Return a short description; only the row cap is shown, not the seed."""
        template = "Subsampled(maxRows=%s)"
        return template % self.max_rows

    @staticmethod
    def build(params: SubsampledGroupingModel) -> 'SubsampledGrouping':
        """Create a grouping from the ``maxRows`` and ``seed`` entries of *params*."""
        return SubsampledGrouping(max_rows=params['maxRows'], seed=params['seed'])

    def compute_groups(self, idf: ImmutableDataFrame) -> 'SubsampledGroupingResult':
        """Wrap *idf*, subsampled if it exceeds ``max_rows``, in a result object."""
        row_count = len(idf)
        selected = idf

        if row_count > self.max_rows:
            # Random.sample draws without replacement, so every picked
            # position is distinct. A fresh generator is created per call,
            # which keeps the draw reproducible for a given seed.
            rng = Random(self.seed)
            positions = rng.sample(range(row_count), self.max_rows)
            # NOTE(review): presumably indexing with a list of positions
            # selects those rows in sampled order - confirm against
            # ImmutableDataFrame.__getitem__.
            selected = idf[positions]

        return SubsampledGroupingResult(selected)


class SubsampledGroupingResult(GroupingResult):
    """Result of a subsampled grouping: a single group backed by one data frame."""

    def __init__(self, idf: ImmutableDataFrame):
        """Keep the data frame that forms the only group."""
        self.idf = idf

    def serialize(self) -> SubsampledGroupingResultModel:
        """Return the serialized form, which carries only the grouping type tag."""
        grouping_type = SubsampledGrouping.get_type()
        return {"type": grouping_type}

    def iter_groups(self) -> Iterator[ImmutableDataFrame]:
        """Yield the stored data frame as the one and only group."""
        yield from (self.idf,)
