import copy
from typing import Any, Dict, List, Tuple, Union

import dataiku
from answers.backend.utils.db.sql_query_utils import replace_tables_in_ast
from common.backend.db.sql.queries import columns_in_uppercase
from common.backend.utils.dataiku_api import dataiku_api
from common.llm_assist.logging import logger
from dataiku import SQLExecutor2
from dataiku.core.dataset import Dataset
from dataiku.sql import SelectQuery, toSQL
from langchain.tools import BaseTool

# Project key taken from dataiku_api.default_project_key; all Dataset lookups in this module use it.
PROJECT_KEY = dataiku_api.default_project_key
# Name of the SQL connection chosen in the webapp settings ("sql_retrieval_connection");
# None when the setting is absent.
SELECTED_CONN = dataiku_api.webapp_config.get("sql_retrieval_connection")
# Public-API handle on the project (resolved at import time), used to list its datasets.
PROJECT = dataiku.api_client().get_project(PROJECT_KEY)


def format_table_description(
    dataset_name: str,
    connection_type: str,
    short_desc: str,
    schema: List,
    cols_uppercase: bool
    ) -> str:
    """Render a dataset and its columns as a plain-text description block for the LLM prompt.

    Args:
        dataset_name: Written on the "Name:" line.
        connection_type: Written on the "SQL Database type:" line.
        short_desc: Written on the "Datasource description:" line.
        schema: Column dicts; each needs "name" and "type" and may carry a "comment".
        cols_uppercase: When True, column names are upper-cased in the output.

    Returns:
        The description text, terminated by a "-- End of Description --" line.
    """
    parts = [
        f"Name: {dataset_name}\n"
        f"SQL Database type: {connection_type}\n"
        f"Datasource description: {short_desc}\n\n"
        "## Columns description\n"
    ]
    description_missing = False
    for col in schema:
        col_name = col["name"].upper() if cols_uppercase else col["name"]
        col_type = col["type"]
        # `or ""` also covers an explicit None comment, which would otherwise be rendered as 'None'.
        comment = col.get("comment") or ""
        if not comment:
            description_missing = True
        # The leading newline and 8-space indents are part of the prompt layout.
        parts.append(
            f"\n        - Name: '{col_name}' | Type: '{col_type}' | Description: '{comment}'\n        "
        )
    parts.append("-- End of Description --\n")
    if description_missing:
        logger.warning(
            "It is recommended that all columns have a description. click dataset > settings > schema then click on a column and edit the DESCRIPTION on the right hand side"
        )
    return "".join(parts)


def get_dataset_description(dataset_name: str, reduced_schema: Union[List[str], None] = None) -> str:
    """Build the prompt description of one dataset from its Dataiku configuration.

    Args:
        dataset_name: Name of the dataset in project PROJECT_KEY.
        reduced_schema: Optional column names; when non-empty, only those columns are described.

    Returns:
        The text produced by ``format_table_description``.
    """
    dataset = dataiku.Dataset(dataset_name, project_key=PROJECT_KEY)
    # Fetch the config once: it supplies the type, both descriptions and the schema.
    dataset_config = dataset.get_config()
    connection_type = dataset_config.get("type")
    cols_uppercase = columns_in_uppercase(dataset)
    short_desc = dataset_config.get("shortDesc")
    long_desc = dataset_config.get("description")
    # Empty strings are treated like missing descriptions.
    if not short_desc and not long_desc:
        logger.warning(
            "Dataset must have description. Add one under details > short description / details > Long description"
        )
    schema = dataset_config.get("schema", {}).get("columns", [])
    if reduced_schema:
        wanted_columns = set(reduced_schema)
        schema = [col for col in schema if col.get("name", "") in wanted_columns]
    description = long_desc or short_desc
    return format_table_description(dataset_name, connection_type, description, schema, cols_uppercase)


def get_all_dataset_descriptions() -> str:
    """Concatenate the descriptions of every table listed in "sql_retrieval_table_list".

    Each description is followed by a dashed separator line.

    Raises:
        Exception: if the setting lists no table.
    """
    table_names: List[str] = list(dataiku_api.webapp_config.get("sql_retrieval_table_list", []))
    if not table_names:
        raise Exception("At least one SQL table is required.")
    separator = "---------------------------------- \n"
    return "".join(get_dataset_description(table_name) + separator for table_name in table_names)


def _is_dataset_column_ref(reference: Any) -> bool:
    """Return True when ``reference`` is a non-empty "dataset.column" string (split on the first dot)."""
    if not isinstance(reference, str):
        return False
    dataset, dot, column = reference.partition(".")
    return bool(dataset and dot and column)


def get_all_suggested_joins() -> str:
    """Render the configured join suggestions, one "left = right" line each.

    Reads the "sql_retrieval_suggested_joins" webapp setting, a list of dicts whose
    "sql_left_column" / "sql_right_column" values are "dataset.column" references.
    A string value (presumably the settings form's empty value; TODO confirm) or a missing/None
    value means there are no suggestions. Malformed entries are logged and skipped instead of
    raising an opaque unpacking error. Column names may themselves contain dots.

    Returns:
        The joined lines, each ending with " \\n", or "None" when no usable suggestion exists.
    """
    raw_joins: Union[List[Dict[str, str]], str, None] = dataiku_api.webapp_config.get("sql_retrieval_suggested_joins", [])
    if not raw_joins or isinstance(raw_joins, str):
        return "None"
    join_lines = []
    for suggestion in raw_joins:
        if not isinstance(suggestion, dict):
            logger.warning(f"Ignoring malformed suggested join: {suggestion!r}")
            continue
        left = suggestion.get("sql_left_column", "")
        right = suggestion.get("sql_right_column", "")
        if not (_is_dataset_column_ref(left) and _is_dataset_column_ref(right)):
            logger.warning(f"Ignoring suggested join with invalid 'dataset.column' references: {suggestion!r}")
            continue
        join_lines.append(f"{left} = {right} \n")
    return "".join(join_lines) or "None"


def get_dataset_descriptions_from_table_names(tables_used: List[str]) -> str:
    """Describe each dataset on the selected connection whose physical table is in ``tables_used``.

    Args:
        tables_used: Physical table names, e.g. as returned by ``replace_tables_in_ast``.

    Returns:
        Concatenated dataset descriptions, each followed by a dashed separator line, ordered by
        ``tables_used`` and, within one table, by ``PROJECT.list_datasets()`` order. Empty when
        nothing matches.
    """
    if not tables_used:
        return ""
    # Resolve each dataset's table with a single get_location_info call, rather than once per
    # (table, dataset) pair.
    datasets_by_table: Dict[str, List[str]] = {}
    for dataset_item in PROJECT.list_datasets():
        if dataset_item["params"].get("connection", "") != SELECTED_CONN:
            continue
        dataset_name = dataset_item["name"]
        location = Dataset(dataset_name, project_key=PROJECT_KEY).get_location_info()
        datasets_by_table.setdefault(location["info"].get("table", ""), []).append(dataset_name)

    separator = "---------------------------------- \n"
    return "".join(
        get_dataset_description(dataset_name) + separator
        for table in tables_used
        for dataset_name in datasets_by_table.get(table, [])
    )

def get_user_query_examples() -> str:
    """Render the configured example question/SQL pairs as a prompt section.

    Reads the "db_query_examples" webapp setting, a mapping from user question to expected SQL.

    Returns:
        "" when no examples are configured. Otherwise a "# Typical Expected Queries" section that
        numbers the examples in mapping order and closes with "# --- END OF TYPICAL QUERIES ---".

    Raises:
        Exception: if no SQL connection is selected (SELECTED_CONN is empty).
    """
    if not SELECTED_CONN:
        raise Exception("A SQL connection is expected to be able to add examples")
    # `or {}` guards against the setting being stored as None.
    db_query_examples: Dict[str, str] = dataiku_api.webapp_config.get("db_query_examples", {}) or {}  # type: ignore
    if not db_query_examples:
        return ""
    # The whitespace inside these fragments is part of the prompt layout.
    parts = [
        "# Typical Expected Queries\n"
        "    Below are some typical user questions and their expected SQL queries that have previously been asked about this dataset:\n"
        "    "
    ]
    for position, (question, query) in enumerate(db_query_examples.items(), start=1):
        parts.append(
            f"\n        {position}.\n"
            f"        user Questions:{question}\n"
            f"        Expected SQL query:{query}\n"
            "        "
        )
    # Closing marker; it must not start with a stray quote character.
    parts.append("\n    # --- END OF TYPICAL QUERIES ---\n    ")
    return "".join(parts)

def to_select_query(db_query: Dict[str, Any], hard_sql_limit: int = 200, is_cte: bool = False) -> SelectQuery:
    """Build a dataiku ``SelectQuery`` from a JSON-like query description.

    Args:
        db_query: Query description. Keys read: "selectList", "from", "alias", "with", "join",
            "where", "groupBy", "having", "orderBy", "limit".
        hard_sql_limit: Maximum row count for the outermost query. A missing, invalid or larger
            "limit" is replaced by this value.
        is_cte: True when building a common table expression (an entry of an outer "with").

    Returns:
        A ``SelectQuery`` whose ``_query`` dict is populated from ``db_query``.

    Raises:
        Exception: if ``db_query["from"]`` is not a dict.
    """
    select_query = SelectQuery()
    q = copy.deepcopy(select_query._query)
    if is_cte and "with" in q:
        del q["with"]  # CTEs cannot have CTEs.

    select_list = db_query.get("selectList")
    if isinstance(select_list, list):
        q["selectList"] = select_list
    else:
        # No explicit projection: select every column.
        q["selectList"] = [{"expr": {"type": "COLUMN", "name": "*"}}]

    from_ = db_query.get("from")
    if not isinstance(from_, dict):
        raise Exception("A query must have 'from' statement")
    q["from"] = from_

    q["alias"] = db_query.get("alias")  # alias of a _query is mainly used for CTEs

    for key in ["with", "join", "where", "groupBy", "having", "orderBy"]:
        val = db_query.get(key)
        if not val or not isinstance(val, list):
            continue
        if key != "with":
            q[key] = val
        elif is_cte:
            # "with" was removed above; a nested CTE list must not re-add it.
            logger.warning("Ignoring a 'with' clause nested inside a CTE: CTEs cannot have CTEs.")
        else:
            # Deep copies keep later AST rewrites from mutating the caller's db_query.
            q["with"] = [
                copy.deepcopy(to_select_query(cte, hard_sql_limit=hard_sql_limit, is_cte=True)._query)
                for cte in val
            ]

    limit = db_query.get("limit")
    # bool is a subclass of int. Negative limits are rejected because some dialects
    # (e.g. SQLite) treat them as "no limit", which would bypass the hard cap.
    valid_limit = isinstance(limit, int) and not isinstance(limit, bool) and limit >= 0
    if is_cte:
        # Capping a CTE would silently truncate its intermediate rows and skew anything computed
        # from it. The row cap is enforced on the outermost query only.
        if valid_limit:
            q["limit"] = limit
    elif valid_limit and limit <= hard_sql_limit:
        q["limit"] = limit
    else:
        if valid_limit:
            logger.warning(
                f"The query exceeded the query limit of {hard_sql_limit} (requested {limit}). Some information may be missing."
            )
        q["limit"] = hard_sql_limit
    select_query._query = q
    return select_query

class SqlRetrieverTool(BaseTool):
    """LangChain tool that runs a structured query description against the configured SQL connection.

    ``_run`` turns the description into SQL via ``to_select_query``/``toSQL``, executes it with
    ``SQLExecutor2`` and returns the LLM context together with the raw results.
    """

    def __init__(self):
        name = "SQL retriever"
        description = "The SqlRetrieverTool is designed for retrieving data from SQL databases. It leverages filtering criteria to fetch specific records. It utilizes SQL queries to extract the first matching record based on the given filters."
        super().__init__(name=name, description=description)

    def _run(self, db_query: dict) -> Tuple[str, Any, List[Dict[str, Any]], List[str]]:
        """Execute ``db_query`` and build the context handed back to the LLM.

        Args:
            db_query: Query description accepted by ``to_select_query``.

        Returns:
            ``(context, sql_query, first_records, tables_used)``:
            - ``context``: text with the SQL, the used-table descriptions and all result rows.
            - ``sql_query``: the SQL that was run.
            - ``first_records``: at most the first 10 result rows, as dicts.
            - ``tables_used``: the table names reported by ``replace_tables_in_ast``.

        Raises:
            Exception: if no SQL connection or no SQL table is configured.
        """
        connection_name = dataiku_api.webapp_config.get("sql_retrieval_connection")
        # An empty string is as unusable as a missing setting, so fail here with a clear message.
        if not connection_name:
            raise Exception("A SQL connection selection is required to run a query.")
        if not (sql_retrieval_table_list := dataiku_api.webapp_config.get("sql_retrieval_table_list", [])):
            raise Exception("At least one SQL table is required to run a query.")
        # The first dataset is used by toSQL to get the dialect information for the connection.
        # It is resolved in PROJECT_KEY, like the datasets in get_dataset_description.
        first_dataset = Dataset(sql_retrieval_table_list[0], project_key=PROJECT_KEY)
        hard_sql_limit = int(dataiku_api.webapp_config.get("hard_sql_limit", 200))
        logger.debug(f"hard_sql_limit is set to {hard_sql_limit}")

        executor = SQLExecutor2(connection=connection_name)
        select_query = to_select_query(db_query, hard_sql_limit)

        tables_used: List[str] = replace_tables_in_ast(select_query)
        logger.debug(f"Replaced ast: {select_query}")

        sql_query = toSQL(select_query, dataset=first_dataset)
        logger.debug(f"Running SQL Query: {sql_query}")
        df = executor.query_to_df(query=sql_query)
        df.fillna("", inplace=True)
        records = df.to_dict("records")

        used_dataset_descriptions = get_dataset_descriptions_from_table_names(tables_used)
        context = f"""
        SQL query executed
        {sql_query}
        Dataset Descriptions
        {used_dataset_descriptions}
        
        >response
        {records}
        """
        return context, sql_query, records[:10], tables_used

    def _arun(self, filter_dict: dict) -> dict:  # type: ignore
        """Async execution is not supported by this tool."""
        raise NotImplementedError("This tool does not support async")
