from typing import Dict, List

import dataiku
from common.backend.utils.dataiku_api import dataiku_api
from dataiku.sql import SelectQuery


def replace_tables_in_ast(ast: SelectQuery) -> List[str]:
    """Rewrite dataset references in a query AST to their physical SQL tables.

    Walks ``ast._query`` twice.

    The first pass collects two things:

    * the aliases declared on TABLE nodes;
    * the names declared by items of ``with`` lists (CTEs).

    The second pass visits every TABLE node found under a ``table``,
    ``tableLike`` or ``from`` key and does one of the following:

    * If its name is one of the datasets listed in the webapp config entry
      ``sql_retrieval_table_list``, it overwrites the node's ``catalog`` /
      ``schema`` / ``name`` in place with the dataset's SQL location.
    * If its name is a CTE name, it leaves the node untouched.
    * If the node does not itself declare an alias and its name is an alias
      declared on another TABLE node, it leaves the node untouched.
    * Otherwise it rejects the query.

    Args:
        ast: The query to rewrite. Its ``_query`` dict is mutated in place.

    Returns:
        The physical table names (the ``table`` entry of each referenced
        dataset's location info), without duplicates and in no particular
        order.

    Raises:
        ValueError: If no dataset names are configured, or a referenced
            dataset is not an SQL dataset.
        Exception: If the query references a table that is not an allowed
            dataset, a CTE, or an alias declared elsewhere in the query.
    """
    tables_used = set()
    valid_aliases, cte_aliases = set(), set()

    dataset_names = dataiku_api.webapp_config.get("sql_retrieval_table_list", [])
    if not dataset_names:
        raise ValueError("No dataset names provided in the webapp config")
    # Build the allow-list once for O(1) membership tests.
    # Assumes the config value is a collection of dataset-name strings; verify.
    allowed_datasets = set(dataset_names)

    def _is_table_ref(key: str, value) -> bool:
        """Return True if ``value`` is a TABLE node stored under a table-position key."""
        return (
            key in ("table", "tableLike", "from")
            and isinstance(value, dict)
            and value.get("type") == "TABLE"
        )

    def _collect_aliases_rec_list(node: List, is_cte: bool = False) -> None:
        """Collect aliases from every dict/list item; record CTE names if ``is_cte``."""
        for item in node:
            if isinstance(item, dict):
                # Items of a "with" list declare CTEs; their alias is the CTE name.
                # The isinstance check must come first: non-dict items have no .get().
                if is_cte and (cte_alias := item.get("alias")):
                    cte_aliases.add(cte_alias)
                _collect_aliases_rec_dict(item)
            elif isinstance(item, list):
                # We don't use is_cte here because we can't have recursive CTEs
                # check dip src/main/java/com/dataiku/dip/sql/queries/QuerySQLWriter.java > writeCTEs
                _collect_aliases_rec_list(item)

    def _collect_aliases_rec_dict(node: Dict) -> None:
        """Record aliases declared on TABLE nodes and recurse into everything else."""
        for key, value in node.items():
            if _is_table_ref(key, value):
                if "alias" in value:
                    valid_aliases.add(value["alias"])
            elif isinstance(value, dict):
                _collect_aliases_rec_dict(value)
            elif isinstance(value, list):
                _collect_aliases_rec_list(value, is_cte=key == "with")

    def _expand_table_node(node: Dict) -> None:
        """Validate one TABLE node and point it at its dataset's physical table."""
        table_name = node["name"]

        if table_name not in allowed_datasets:
            if table_name in cte_aliases:
                return  # CTE names need no expansion.
            # A node carrying its own alias is a table source, not a reference to
            # an alias declared elsewhere. Without this check, "FROM secret AS
            # secret" would put "secret" into valid_aliases and whitelist itself.
            # NOTE(review): aliases are collected query-wide, not per scope, so a
            # source that reuses an alias declared in another (sub)query is still
            # accepted; confirm whether the AST can express that.
            if node.get("alias") is not None or table_name not in valid_aliases:
                raise Exception("Illegal table access: %s" % table_name)
            # Reference to an alias declared on another node: leave it alone.
            return

        location_info = dataiku.Dataset(table_name).get_location_info()
        if location_info.get("locationInfoType") != "SQL":
            raise ValueError("Can only execute query on an SQL dataset")

        info = location_info.get("info")
        catalog_name = info.get("catalog")
        schema_name = info.get("schema")
        physical_table = info.get("table")
        tables_used.add(physical_table)

        # Only overwrite the parts the location info actually provides.
        if catalog_name is not None:
            node["catalog"] = catalog_name
        if schema_name is not None:
            node["schema"] = schema_name
        if physical_table is not None:
            node["name"] = physical_table

    def _replace_tables_rec_list(node: List) -> None:
        """Recurse into every dict/list item of ``node``."""
        for item in node:
            if isinstance(item, dict):
                _replace_tables_rec_dict(item)
            elif isinstance(item, list):
                _replace_tables_rec_list(item)

    def _replace_tables_rec_dict(node: Dict) -> None:
        """Expand TABLE nodes in table positions and recurse into everything else."""
        for key, value in node.items():
            if _is_table_ref(key, value):
                # Mutates the child dict only; node's own keys are unchanged,
                # so iterating node.items() stays safe.
                _expand_table_node(value)
            elif isinstance(value, dict):
                _replace_tables_rec_dict(value)
            elif isinstance(value, list):
                _replace_tables_rec_list(value)

    # The collection pass must complete before any node is validated, so every
    # alias and CTE name in the query is known up front.
    _collect_aliases_rec_dict(ast._query)
    _replace_tables_rec_dict(ast._query)
    return list(tables_used)
