from typing import Dict, List, TypedDict

from solutions.graph.graph_db_instance_manager import AbstractDbInstance


class NodePropertySchema(TypedDict, total=True):
    """One property (column) of a node table, as listed by ``TABLE_INFO``."""

    # Property name.
    name: str
    # Data type string as reported by ``TABLE_INFO``.
    type: str
    # Whether ``TABLE_INFO`` flags this property as the table's primary key.
    is_primary_key: bool


class EdgePropertySchema(TypedDict, total=True):
    """One property (column) of a relationship table, as listed by ``TABLE_INFO``."""

    # Property name.
    name: str
    # Data type string as reported by ``TABLE_INFO``.
    type: str


class NodeTableSchema(TypedDict, total=True):
    """Schema of a single node table: its label and its property columns."""

    # Node table label (table name).
    label: str
    # Property columns of the table, in the order ``TABLE_INFO`` lists them.
    properties: List[NodePropertySchema]


class EdgeTableSchema(TypedDict, total=True):
    """Schema of a single relationship table: label, endpoints and properties."""

    # Relationship table label (table name).
    label: str
    # Label of the source node table ("source table name" in SHOW_CONNECTION).
    from_group_label: str
    # Label of the destination node table ("destination table name" in SHOW_CONNECTION).
    to_group_label: str
    # Property columns of the table, in the order ``TABLE_INFO`` lists them.
    properties: List[EdgePropertySchema]


class GraphSchema(TypedDict, total=True):
    """Complete graph schema: every node table and every relationship table."""

    # Node table schemas keyed by node table label.
    node_groups: Dict[str, NodeTableSchema]
    # Relationship table schemas keyed by relationship table label.
    edge_groups: Dict[str, EdgeTableSchema]


def extract_schema(db_instance: AbstractDbInstance) -> GraphSchema:
    """Read the node and relationship table schemas of a Kuzu database.

    Opens a fresh connection on ``db_instance`` and lists every table with
    ``show_tables()``. For each ``NODE`` table it collects the properties
    reported by ``TABLE_INFO``. For each ``REL`` table it additionally
    resolves the endpoint node tables via ``SHOW_CONNECTION``.

    Values read from the query DataFrames are explicitly converted to plain
    ``str``/``bool``. This guarantees the declared TypedDict field types
    regardless of the dtypes the DataFrame rows carry.

    Args:
        db_instance: Database whose schema should be extracted.

    Returns:
        A ``GraphSchema`` whose ``node_groups`` / ``edge_groups`` map each
        table label to that table's schema.

    Raises:
        Exception: If ``show_tables()`` reports a table type other than
            ``NODE`` or ``REL``.
    """
    nodes: Dict[str, NodeTableSchema] = {}
    edges: Dict[str, EdgeTableSchema] = {}

    with db_instance.get_new_conn() as conn_context_manager, conn_context_manager.open() as conn:
        tables_df = conn.execute("CALL show_tables() RETURN type, name;").get_as_df()  # type: ignore
        for _, table_row in tables_df.iterrows():
            # Named `table_type` rather than `type` to avoid shadowing the builtin.
            table_type = table_row["type"]  # type: ignore
            group_label = str(table_row["name"])  # type: ignore

            if table_type == "NODE":
                nodes[group_label] = _extract_node_table(conn, group_label)
            elif table_type == "REL":
                edges[group_label] = _extract_edge_table(conn, group_label)
            else:
                raise Exception(f'Unexpected type encountered "{table_type}" when extracting Kuzu graph schema.')

    return {"node_groups": nodes, "edge_groups": edges}


def _table_info_df(conn, label: str):
    """Run ``TABLE_INFO`` for table ``label`` and return the result as a DataFrame."""
    return conn.execute(f'CALL TABLE_INFO("{label}") RETURN *;').get_as_df()  # type: ignore


def _extract_node_table(conn, label: str) -> NodeTableSchema:
    """Build the schema (label + properties) of node table ``label``."""
    props_df = _table_info_df(conn, label)
    return {
        "label": label,
        "properties": [
            {
                "name": str(prop["name"]),  # type: ignore
                "type": str(prop["type"]),  # type: ignore
                "is_primary_key": bool(prop["primary key"]),  # type: ignore
            }
            for _, prop in props_df.iterrows()
        ],
    }


def _extract_edge_table(conn, label: str) -> EdgeTableSchema:
    """Build the schema (label, endpoints, properties) of relationship table ``label``."""
    connections_df = conn.execute(f'CALL SHOW_CONNECTION("{label}") RETURN *;').get_as_df()  # type: ignore

    from_group_label = ""
    to_group_label = ""
    # The schema holds a single (from, to) pair per relationship table. If
    # SHOW_CONNECTION returns several rows, the last row is the one kept.
    # When it returns none, both labels stay empty strings.
    if not connections_df.empty:
        last_connection = connections_df.iloc[-1]
        from_group_label = str(last_connection["source table name"])
        to_group_label = str(last_connection["destination table name"])

    props_df = _table_info_df(conn, label)
    return {
        "label": label,
        "from_group_label": from_group_label,
        "to_group_label": to_group_label,
        "properties": [
            {
                "name": str(prop["name"]),  # type: ignore
                "type": str(prop["type"]),  # type: ignore
            }
            for _, prop in props_df.iterrows()
        ],
    }
