import copy
import logging
import re
from typing import Any, Dict, List

import dataiku
from dataiku.sql import SelectQuery


def is_table_key(key: str) -> bool:
    """Return True when ``key`` is one of the query keys that may hold a table node.

    The recognised keys are ``"table"``, ``"tableLike"`` and ``"from"``; the
    comparison is exact and case-sensitive.
    """
    return key in ("table", "tableLike", "from")


def replace_tables_in_select_query(ast: SelectQuery, datasets: List[str]) -> None:
    """Validate the table references of ``ast`` and expand dataset names in place.

    The function walks ``ast._query`` twice:

    1. A collection pass records the aliases of TABLE nodes whose ``name`` is
       in ``datasets`` and the aliases of the entries found in any ``"with"``
       list (CTEs). A CTE that carries an alias must select from a table whose
       name is in ``datasets``.
    2. A replacement pass visits every TABLE node stored under a table key
       (see ``is_table_key``). A node whose name is in ``datasets`` gets its
       ``catalog`` / ``schema`` / ``name`` overwritten with the values reported
       by ``dataiku.Dataset(name).get_location_info()``. A node whose name is a
       collected alias is left untouched. Anything else is rejected.

    Args:
        ast: Query object whose ``_query`` dict is inspected and mutated.
        datasets: Names of the datasets the query is allowed to read.

    Raises:
        Exception: If a CTE or a TABLE node references a table that is neither
            in ``datasets`` nor a collected alias.
        ValueError: If an allowed dataset's location info type is not ``"SQL"``.
    """
    valid_aliases, cte_aliases = set(), set()

    def _is_table_node(key: str, value: Any) -> bool:
        # A table reference is a dict tagged type == "TABLE" stored under a table key.
        return is_table_key(key) and isinstance(value, dict) and value.get("type") == "TABLE"

    def _collect_valid_aliases(node: Dict) -> None:
        assert node["type"] == "TABLE"

        if "alias" in node and node["name"] in datasets:
            valid_aliases.add(node["alias"])

    def _collect_aliases_rec_list(node: List, is_cte: bool = False) -> None:
        for item in node:
            if isinstance(item, dict):
                # Only dict entries can describe a CTE; checking the type first
                # avoids calling .get() on a list or scalar entry.
                if is_cte and (cte_alias := item.get("alias")):
                    cte_from = item.get("from")
                    # A CTE without a plain named table in "from" is rejected like
                    # any other illegal access instead of surfacing a bare KeyError.
                    cte_table_name = cte_from.get("name") if isinstance(cte_from, dict) else None
                    if cte_table_name not in datasets:
                        raise Exception(f"CTE Illegal table access: {cte_table_name}")
                    cte_aliases.add(cte_alias)
                _collect_aliases_rec_dict(item)
            elif isinstance(item, list):
                # We don't use is_cte here because we can't have recursive CTEs
                # check dip src/main/java/com/dataiku/dip/sql/queries/QuerySQLWriter.java > writeCTEs
                _collect_aliases_rec_list(item)

    def _collect_aliases_rec_dict(node: Dict) -> None:
        for key, value in node.items():
            if _is_table_node(key, value):
                _collect_valid_aliases(value)
            elif isinstance(value, dict):
                _collect_aliases_rec_dict(value)
            elif isinstance(value, list):
                _collect_aliases_rec_list(value, is_cte=key == "with")

    def _expand_table_node(node: Dict) -> None:
        assert node["type"] == "TABLE"
        dataset_name = node["name"]

        if dataset_name not in datasets:
            if dataset_name in cte_aliases or dataset_name in valid_aliases:
                # CTE aliases and aliases of allowed tables are left untouched.
                return
            raise Exception(f"Illegal table access: {dataset_name}")

        location_info = dataiku.Dataset(dataset_name).get_location_info()
        if location_info.get("locationInfoType") != "SQL":
            raise ValueError("Can only execute query on an SQL dataset")

        info = location_info.get("info")
        # Overwrite only the parts the location info actually provides.
        for node_key, info_key in (("catalog", "catalog"), ("schema", "schema"), ("name", "table")):
            resolved = info.get(info_key)
            if resolved is not None:
                node[node_key] = resolved

    def _replace_tables_rec_list(node: List) -> None:
        for item in node:
            if isinstance(item, dict):
                _replace_tables_rec_dict(item)
            elif isinstance(item, list):
                _replace_tables_rec_list(item)

    def _replace_tables_rec_dict(node: Dict) -> None:
        # Only child dicts are mutated, never the dict being iterated.
        for key, value in node.items():
            if _is_table_node(key, value):
                _expand_table_node(value)
            elif isinstance(value, dict):
                _replace_tables_rec_dict(value)
            elif isinstance(value, list):
                _replace_tables_rec_list(value)

    _collect_aliases_rec_dict(ast._query)
    logging.info("Collected valid aliases: %s", valid_aliases)
    _replace_tables_rec_dict(ast._query)


def to_select_query(db_query: Dict[str, Any], hard_sql_limit: int = 200, is_cte: bool = False) -> SelectQuery:
    """Build a ``SelectQuery`` from a dict description of a query.

    The result starts from a deep copy of a fresh ``SelectQuery()._query`` and
    is filled from these keys of ``db_query``:

    - ``selectList``: used when it is a list, otherwise replaced by ``*``.
    - ``from``: mandatory, must be a dict.
    - ``alias``: copied as is (``None`` when absent).
    - ``with``: each entry is converted recursively with ``is_cte=True``.
    - ``join``, ``where``, ``groupBy``, ``having``, ``orderBy``: copied when
      they are non-empty lists.
    - ``limit``: kept only when it is a non-negative integer strictly below
      ``hard_sql_limit``; otherwise ``hard_sql_limit`` is used.

    Args:
        db_query: Dict description of the query.
        hard_sql_limit: Upper bound applied to the query's ``limit``.
        is_cte: True when ``db_query`` describes a CTE; a CTE never carries a
            ``with`` clause of its own.

    Returns:
        A ``SelectQuery`` whose ``_query`` holds the assembled description.

    Raises:
        Exception: If ``db_query`` has no dict under ``from``.
    """
    select_query = SelectQuery()
    q = copy.deepcopy(select_query._query)
    if is_cte:
        q.pop("with", None)  # CTEs cannot have CTEs.

    select_list = db_query.get("selectList")
    if isinstance(select_list, list):
        q["selectList"] = select_list
    else:
        q["selectList"] = [{"expr": {"type": "COLUMN", "name": "*"}}]

    from_ = db_query.get("from")
    if not isinstance(from_, dict):
        raise Exception("A query must have 'from' statement")
    q["from"] = from_

    q["alias"] = db_query.get("alias")  # alias of a _query is mainly used for CTEs

    for key in ["with", "join", "where", "groupBy", "having", "orderBy"]:
        val = db_query.get(key)
        if not (val and isinstance(val, list)):
            continue
        if key != "with":
            q[key] = val
        elif not is_cte:
            # A CTE must not get a nested "with" re-added after it was removed above.
            q["with"] = [
                copy.deepcopy(to_select_query(cte, hard_sql_limit=hard_sql_limit, is_cte=True)._query)
                for cte in val
            ]

    limit = db_query.get("limit")
    # bool is a subclass of int, so it is excluded explicitly; a negative value
    # must not slip under the hard limit either.
    is_int_limit = isinstance(limit, int) and not isinstance(limit, bool)
    if is_int_limit and 0 <= limit < hard_sql_limit:
        q["limit"] = limit
    else:
        if is_int_limit and limit >= hard_sql_limit:
            # Make the clamping visible: results may be truncated.
            logging.warning(
                "Requested limit %s reaches the hard SQL limit of %s; the result may be truncated.",
                limit,
                hard_sql_limit,
            )
        q["limit"] = hard_sql_limit

    select_query._query = q
    return select_query


def is_dataset_access_error(error_message: str) -> bool:
    """Tell whether ``error_message`` reports a dataset/connection permission problem.

    Two shapes are recognised: the fixed sentence saying that queries may not
    be submitted to the connection, and the "User ... does not have credentials
    for connection ... to access ..." pattern.
    """
    forbidden_connection_text = "You may not submit queries to this connection"
    missing_credentials_pattern = r"User \S+ does not have credentials for connection \S+ to access \S+"
    return (
        forbidden_connection_text in error_message
        or re.search(missing_credentials_pattern, error_message) is not None
    )
