from typing import Iterator, List, Optional
from dataiku.eda.types import Literal

import numpy as np
import pandas as pd

from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.exceptions import InvalidParams
from dataiku.eda.filtering.anum_filter import AnumFilter
from dataiku.eda.filtering.not_filter import NotFilter
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.grouping.grouping import GroupingResult
from dataiku.eda.types import AnumGroupingModel, AnumGroupingResultModel


class AnumGrouping(Grouping):
    """Group rows by the distinct values of a text ("anum") column.

    The most frequent values (optionally capped at ``max_values``) each become
    their own group. The remaining rows can optionally be merged into one
    trailing "others" group.
    """

    def __init__(self, column: str, max_values: Optional[int], exclude_values: Optional[List[str]], regroup_others: bool):
        """
        :param column: name of the column whose values define the groups
        :param max_values: maximum number of value groups (None means no limit)
        :param exclude_values: values passed to ``AnumFilter``; rows matched by
            that filter are dropped (through ``NotFilter``) before grouping.
            None means nothing is excluded.
        :param regroup_others: whether rows that are not covered by a value
            group are emitted as an extra "others" group
        """
        self.column = column
        self.max_values = max_values
        self.exclude_values = exclude_values
        self.regroup_others = regroup_others

    @staticmethod
    def get_type() -> Literal["anum"]:
        """Identifier of this grouping kind (also the serialized result "type")."""
        return "anum"

    def describe(self) -> str:
        """Short human-readable description, e.g. ``Anum(col, max=10)``."""
        max_val = "" if self.max_values is None else ", max=%s" % self.max_values
        return "Anum(%s%s)" % (self.column, max_val)

    @staticmethod
    def build(params: AnumGroupingModel) -> 'AnumGrouping':
        """Create an instance from its parameter model.

        ``column`` is required. ``maxValues``, ``excludeValues`` and
        ``groupOthers`` are optional and default to None.
        """
        return AnumGrouping(
            params['column'],
            params.get('maxValues'),
            params.get('excludeValues'),
            params.get('groupOthers')
        )

    def compute_groups(self, idf: ImmutableDataFrame) -> 'AnumGroupingResult':
        """Split ``idf`` into one group per top value of ``self.column``.

        Values are ranked by decreasing row count, with ties broken by
        increasing category code. At most ``self.max_values`` of them are
        kept.

        Rows whose value is not kept (including missing values, if any) go to
        a single "others" group. That group is emitted, after the value
        groups, only when it is non-empty and ``self.regroup_others`` is
        truthy.

        :raises InvalidParams: if ``max_values`` is negative
        """
        if self.max_values is not None and self.max_values < 0:
            raise InvalidParams("max_values must be a non-negative integer")

        if self.exclude_values is None:
            filtered_idf = idf
        else:
            filtered_idf = NotFilter(AnumFilter(self.column, self.exclude_values)).apply(idf)

        series = filtered_idf.text_col(self.column)
        codes = series.codes

        # Find the most frequent values ordered by (count desc, value asc)
        #
        # Multiple methods have been considered:
        # 1- pd.Series.value_counts
        # 2- pd.Categorical.value_counts
        # 3- np.unique (regular sort)
        # 4- np.partition? pd.Series.nlargest? (partial sort)
        #
        # Observations:
        # - (3) is slower than (1) for large series
        # - (2) is very fast ONLY if nb. categories << nb. of rows
        # - (4) would be the ideal candidate but it does not seem as easy as the others to implement
        #
        # => (1) was picked because it performs well enough in general
        code_counts = pd.Series(codes, copy=False).value_counts(sort=False)

        # pandas Categorical encodes missing values as code -1, which is also
        # the key used below for the "others" group. A -1 code must never
        # become a top value: categories[-1] would label missing rows with the
        # last category, and the same rows would also be reported as
        # "others". Such rows fall into "others" instead.
        code_counts = code_counts[code_counts.index >= 0]

        # Order by (count desc, code asc) + limit.
        # NOTE(review): code order equals value order only if the categories
        # returned by text_col are sorted; presumably they are - confirm.
        value_idx = np.lexsort((code_counts.index, -code_counts.values))[:self.max_values]
        top_codes = code_counts.index[value_idx]

        # Merge all the lesser-used values together under the key -1
        group_key = np.where(np.isin(codes, top_codes), codes, -1)

        # Row indices of each group key
        value_to_row_indices = pd.Series(group_key, copy=False).groupby(group_key).indices

        # One idf per top value, along with the corresponding value label
        idfs = [filtered_idf[value_to_row_indices[value_code]] for value_code in top_codes]
        values = [series.categories[value_code] for value_code in top_codes]

        # Add 'others' if requested and not empty
        has_all_values = -1 not in value_to_row_indices
        has_others = not has_all_values and bool(self.regroup_others)
        if has_others:
            idfs.append(filtered_idf[value_to_row_indices[-1]])

        return AnumGroupingResult(self.column, idfs, values, has_others, has_all_values)


class AnumGroupingResult(GroupingResult):
    """Groups produced by an anum grouping, together with their value labels."""

    def __init__(self, column: str, idfs: List[ImmutableDataFrame], values: List[str], has_others: bool, has_all_values: bool):
        """
        :param column: name of the grouped column
        :param idfs: group data frames, yielded in this order by ``iter_groups``
        :param values: value label of each value group
        :param has_others: whether an extra "others" group is present in ``idfs``
        :param has_all_values: whether every value got its own group
        """
        self.idfs = idfs
        self.values = values
        self.column = column
        self.has_others = has_others
        self.has_all_values = has_all_values

    def serialize(self) -> AnumGroupingResultModel:
        """Describe this result as a plain dict; the group data frames are not included."""
        grouping_type = AnumGrouping.get_type()
        return {
            "type": grouping_type,
            "column": self.column,
            "values": self.values,
            "hasOthers": self.has_others,
            "hasAllValues": self.has_all_values
        }

    def iter_groups(self) -> Iterator[ImmutableDataFrame]:
        """Yield the group data frames in their stored order."""
        yield from self.idfs
