from __future__ import annotations

from typing import List, TypedDict

import kuzu
from pandas import DataFrame

from solutions.graph.kuzu.schema.schema_extractor import NodeTableSchema


# Summary record for one group of neighbours of a node:
#   label -> label (table name) shared by the neighbours in this group
#   count -> number of neighbours carrying that label
AdjacentNodeGroupInfo = TypedDict(
    "AdjacentNodeGroupInfo",
    {
        "label": str,
        "count": int,
    },
)


def get_adjacent_node_groups_info(
    conn: kuzu.Connection, node_schema: NodeTableSchema, pk_value: str | int
) -> List[AdjacentNodeGroupInfo]:
    """Count the distinct neighbours of one node, grouped by neighbour label.

    The node is located by its primary-key value in the table described by
    ``node_schema``. Relationships are matched in either direction, and each
    neighbour is counted once per label (``COUNT(DISTINCT t)``) no matter how
    many relationships connect it to the node.

    Args:
        conn: Connection used to run the Cypher query.
        node_schema: Schema of the node's table; its ``label`` and the single
            property flagged ``is_primary_key`` are used to build the query.
        pk_value: Primary-key value of the node. It is passed as a query
            parameter, never interpolated into the query text.

    Returns:
        One ``{"label": ..., "count": ...}`` entry per neighbour label, with
        ``count`` as a plain Python ``int``. The list is empty when the query
        returns no rows. Order is whatever the database returns.

    Raises:
        ValueError: If ``node_schema`` does not declare exactly one
            primary-key property.
    """
    pk_props = [prop for prop in node_schema["properties"] if prop["is_primary_key"]]
    # Raise explicitly instead of using `assert`: assertions are stripped under
    # `python -O`. Without this check the code would then raise an opaque
    # IndexError, or silently pick the first of several keys.
    if len(pk_props) != 1:
        raise ValueError(f"Expected one primary key property, got {pk_props}.")
    pk_name = pk_props[0]["name"]

    # Identifiers (label, property name) cannot be bound as parameters, so they
    # are interpolated inside backticks. The lookup value stays a bound parameter.
    query_result = conn.execute(
        f"""
        MATCH (n:`{node_schema['label']}` {{`{pk_name}`: $pk_value}})-[]-(t)
        RETURN LABEL(t) AS targetLabel, COUNT(DISTINCT t) AS count
        """,
        parameters={"pk_value": pk_value},
    )
    df: DataFrame = query_result.get_as_df()  # type: ignore

    # Normalise `count` to a plain int regardless of how pandas boxes the value
    # (it may be a numpy scalar depending on column dtypes), so the result
    # matches AdjacentNodeGroupInfo.
    return [
        {"label": row["targetLabel"], "count": int(row["count"])}
        for row in df.to_dict("records")
    ]
