from __future__ import annotations

import json
import logging
import traceback
from typing import List

import kuzu
from dataikuapi.dss.llm import DSSLLM
from pandas import DataFrame

from solutions.graph.graph_db_instance_manager import AbstractDbInstance
from solutions.graph.models import EdgeGroupDefinition, NodeGroupDefinition
from solutions.graph.views.cypher_graph_view import CypherGraphView, CypherQueryExecutionException
from solutions.llm.prompts import get_retry_system_prompt, get_system_prompt, get_user_prompt

logger = logging.getLogger(name=__name__)

# Default retry budget for LLMCypher.execute_llm_with_retry: how many further attempts
# (each with a new LLM-generated query) are allowed after the first attempt fails.
LLM_RETRY = 3


class AIQueryExecutionException(Exception):
    """
    Exception raised when an AI-generated Cypher query fails to execute.

    The exception message (``str(exc)``) is the error text reported by the Cypher engine.

    Attributes:
        query (str): The Cypher query that caused the exception.
        title (str): The title associated with the failing query.
        cypher_error (str): The error message returned by the Cypher engine.

    Args:
        query (str): The generated Cypher query that failed.
        title (str): The title associated with the query.
        cypher_error (str): The error message from the Cypher engine.

    Example:
        raise AIQueryExecutionException(query=query, title=title, cypher_error=cypher_error)
    """

    def __init__(self, query: str, title: str, cypher_error: str):
        self.query = query
        self.title = title
        self.cypher_error = cypher_error

        super().__init__(cypher_error)

    def __reduce__(self):
        # The default Exception pickling re-creates the instance as ``cls(*self.args)``, i.e. with
        # the message only, which does not match this three-argument constructor and fails on
        # unpickling. Pass every constructor argument explicitly, plus the instance state.
        return (type(self), (self.query, self.title, self.cypher_error), self.__dict__)


class LLMCypher:
    """
    Generates Cypher queries from natural-language questions with an LLM and executes them.

    Queries run through a ``CypherGraphView``. When a query fails, the failing query and its error
    trace are sent back to the LLM to obtain a corrected query, up to a bounded number of retries.
    """

    def __init__(
        self,
        db_instance: AbstractDbInstance,
        llm: DSSLLM,
        node_definitions: List[NodeGroupDefinition],
        edge_definitions: List[EdgeGroupDefinition],
    ):
        """
        Args:
            db_instance: Database instance used for schema introspection and query execution.
            llm: LLM used to generate Cypher queries.
            node_definitions: Node group definitions; their ``node_id`` values select returned nodes.
            edge_definitions: Edge group definitions; their ``edge_id`` values select returned edges.
        """
        self.__db_instance = db_instance
        self.__llm = llm

        self.__node_definitions = node_definitions
        self.__edge_definitions = edge_definitions
        self.__graph_view = CypherGraphView(db_instance, node_definitions, edge_definitions)

        # Computed lazily by generate_cypher on first use, then reused.
        self.__graph_schema: str | None = None

    def execute_llm_with_retry(
        self,
        question: str,
        cypher_query: str | None = None,
        title: str | None = None,
        retry: int = LLM_RETRY,
        timeout_seconds: int | None = None,
        last_cypher_exception: AIQueryExecutionException | None = None,
    ) -> dict:
        """
        Executes a Cypher query for a question, asking the LLM to repair the query on failure.

        If ``cypher_query`` is empty, a query and title are first generated from ``question``. When
        execution (or result formatting) fails, the failing query and its Python traceback are sent
        to the LLM to obtain a corrected query, and the method recurses with ``retry - 1``. Errors
        raised by the LLM call other than malformed responses are not caught and propagate.

        Args:
            question (str): The natural language question.
            cypher_query (str | None, optional): Query to execute; generated from ``question`` when empty.
            title (str | None, optional): Title for the query; "User-provided query" is used when no
                title is available.
            retry (int, optional): Number of further attempts allowed after this one. Defaults to LLM_RETRY.
            timeout_seconds (int | None, optional): Query timeout in seconds, forwarded to the graph view.
            last_cypher_exception (AIQueryExecutionException | None, optional): The Cypher error of the
                previous attempt; re-raised if retries are exhausted. Defaults to None.

        Returns:
            dict: The formatted result with ``success``, ``title``, ``nodes``, ``edges``, ``table``
            and ``cypher_query`` keys.

        Raises:
            AIQueryExecutionException: If retries are exhausted and the last failure was a Cypher error.
            Exception: If retries are exhausted for any other reason.
        """
        if retry < 0:
            self.__raise_retries_exhausted(last_cypher_exception)

        current_title = title
        current_cypher_query = cypher_query

        if not current_cypher_query:
            try:
                llm_response = self.generate_cypher(question)
                current_cypher_query = llm_response["cypher_query"]
                current_title = llm_response["title"]
            except (json.JSONDecodeError, KeyError) as e:
                logger.error(f"Failed to parse LLM response: {e}. Retrying...")
                # Retry without a "previous query" since the response was malformed
                return self.execute_llm_with_retry(question=question, retry=retry - 1, timeout_seconds=timeout_seconds)

        if not current_title:
            current_title = "User-provided query"

        try:
            logger.debug(f"Execute cypher query: {current_cypher_query}")
            try:
                self.__graph_view.execute(current_cypher_query, timeout_seconds=timeout_seconds)
            except CypherQueryExecutionException as e:
                raise AIQueryExecutionException(
                    query=current_cypher_query, title=current_title, cypher_error=str(e)
                ) from e

            return self.__format_graph_view_result(current_cypher_query, current_title)

        except Exception as e:
            if retry <= 0:
                # No attempt left: a corrected query would never be executed, so do not spend an
                # LLM call on it (and do not let an error from that call mask the real failure).
                self.__raise_retries_exhausted(e if isinstance(e, AIQueryExecutionException) else None)

            error_trace = traceback.format_exc()
            logger.debug(f"Cypher query failed, retrying. Retry: {retry}.\n{error_trace}")

            try:
                llm_response = self.generate_cypher(question, current_cypher_query, error_trace)
                new_cypher_query = llm_response["cypher_query"]
                new_title = llm_response["title"]
            except (json.JSONDecodeError, KeyError) as parse_error:
                logger.error(f"Failed to parse LLM response on retry: {parse_error}. Retrying...")
                return self.execute_llm_with_retry(question=question, retry=retry - 1, timeout_seconds=timeout_seconds)

            # Recursively call with the new query and title
            return self.execute_llm_with_retry(
                question=question,
                cypher_query=new_cypher_query,
                title=new_title,
                retry=retry - 1,
                timeout_seconds=timeout_seconds,
                last_cypher_exception=e if isinstance(e, AIQueryExecutionException) else None,
            )

    @staticmethod
    def __raise_retries_exhausted(last_cypher_exception: AIQueryExecutionException | None):
        """
        Raises the error that ends the retry loop; never returns.

        Re-raises ``last_cypher_exception`` when it is an AIQueryExecutionException, otherwise raises a
        generic Exception embedding the traceback of the exception currently being handled (if any).
        """
        if isinstance(last_cypher_exception, AIQueryExecutionException):
            raise last_cypher_exception
        raise Exception(f"Error with query after multiple retries.\n{traceback.format_exc()}")

    def generate_cypher(
        self, llm_request: str, previous_query: str | None = None, error_message: str | None = None
    ) -> dict[str, str]:
        """
        Generates a Cypher query and a title using an LLM.

        The graph schema is fetched on the first call and cached. When both ``previous_query`` and
        ``error_message`` are given, the retry system prompt (which includes them) is used.

        Args:
            llm_request (str): The user's request.
            previous_query (str | None, optional): The previous failed query for context.
            error_message (str | None, optional): The error message from the failed query.

        Returns:
            dict: A dictionary with "title" and "cypher_query" keys.

        Raises:
            json.JSONDecodeError: If the LLM response has no text or is not valid JSON.
            KeyError: If the JSON response is not an object or is missing required keys.
        """
        if not self.__graph_schema:
            self.__graph_schema = self.get_graph_schema()

        if previous_query and error_message:
            system_prompt = get_retry_system_prompt(self.__graph_schema, llm_request, previous_query, error_message)
        else:
            system_prompt = get_system_prompt()

        user_prompt = get_user_prompt(self.__graph_schema, llm_request)
        completion = self.__llm.new_completion()

        logger.debug(f"Generate cypher from system prompt: {system_prompt}")
        completion.with_message(system_prompt, role="system")

        logger.debug(f"Generate cypher from user prompt: {user_prompt}")
        completion.with_message(user_prompt, role="user")

        resp = completion.execute()
        raw_text = resp.text

        logger.debug(f"LLM raw response: {raw_text}")
        if not isinstance(raw_text, str):
            # Report a missing completion text as a malformed response so callers retry
            # instead of failing with a TypeError inside json.loads.
            raise json.JSONDecodeError("LLM response has no text content", "", 0)

        parsed_response = json.loads(self.__strip_code_fence(raw_text))

        # A valid JSON document that is not an object (string, list, number) is also malformed;
        # without the isinstance check, `in` would do substring/element tests or raise TypeError.
        if (
            not isinstance(parsed_response, dict)
            or "title" not in parsed_response
            or "cypher_query" not in parsed_response
        ):
            raise KeyError("LLM response is missing 'title' or 'cypher_query' key.")

        return parsed_response

    @staticmethod
    def __strip_code_fence(text: str) -> str:
        """
        Removes a surrounding Markdown code fence (e.g. ```json ... ```) from an LLM response.

        Text that does not start with a fence is returned unchanged, so valid JSON is unaffected.
        """
        stripped = text.strip()
        if not stripped.startswith("```"):
            return text
        lines = stripped.splitlines()
        # Drop the opening fence line (it may carry a language tag) and the closing fence line.
        body = lines[1:]
        if body and body[-1].strip() == "```":
            body = body[:-1]
        return "\n".join(body)

    def get_graph_schema(self) -> str:
        """
        Retrieves and formats the schema of the graph database as a human-readable string.

        Lists all tables with ``show_tables()``. For node tables, their properties are read with
        ``TABLE_INFO`` (properties whose name starts with "_dku" are hidden). For relationship tables,
        each source/destination pair reported by ``SHOW_CONNECTION`` is listed on its own line.

        Returns:
            str: A formatted string with a "NODE TYPES" section and a "RELATIONSHIPS" section.
        """
        tables = self.__execute_query("CALL show_tables() RETURN *;")

        nodes: List[str] = []
        edges: List[str] = []

        for _, table in tables.iterrows():
            table_name = table["name"]
            table_type = table["type"]

            if table_type == "NODE":
                table_props = self.__execute_query(f"CALL TABLE_INFO('{table_name}') RETURN *;")
                current_props = []
                for _, prop in table_props.iterrows():
                    prop_name = prop["name"]
                    # Hide internal "_dku*" properties.
                    # NOTE(review): the membership test never changes the outcome, since neither
                    # "source" nor "source type" starts with "_dku"; confirm which names were meant.
                    if not prop_name.startswith("_dku") or prop_name in ["source", "source type"]:
                        current_props.append(f"({prop_name}: {prop['type']})")
                nodes.append(f"{table_name}: {','.join(current_props)}")

            elif table_type == "REL":
                rel_conns = self.__execute_query(f"CALL SHOW_CONNECTION('{table_name}') RETURN *;")
                # A relationship table may have several connections; list each one separately
                # instead of gluing them together on a single line.
                for _, conn in rel_conns.iterrows():
                    source = conn["source table name"]
                    destination = conn["destination table name"]
                    edges.append(f"(:{source})-[:{table_name}]->(:{destination})")

        # Indent continuation lines so every entry lines up under its section header below.
        node_lines = "\n            ".join(nodes)
        edge_lines = "\n            ".join(edges)

        formatted_schema = f"""
        NODE TYPES:
            {node_lines}
        RELATIONSHIPS:
            {edge_lines}
        """

        logger.debug(f"Database Schema.\n{formatted_schema}")
        return formatted_schema

    def __format_graph_view_result(self, cypher_query: str, title: str) -> dict:
        """
        Formats the graph view result into a structured dictionary.

        Args:
            cypher_query (str): The Cypher query used.
            title (str): The descriptive title for the query.

        Returns:
            dict: A dictionary containing the query results and metadata.
        """
        # Unique node and edge IDs from the definitions. dict.fromkeys keeps definition order,
        # so the output order does not depend on string hash randomization (as a set's would).
        node_ids = dict.fromkeys(d["node_id"] for d in self.__node_definitions)
        edge_ids = dict.fromkeys(d["edge_id"] for d in self.__edge_definitions)

        nodes = []
        for node_id in node_ids:
            nodes.extend(self.__graph_view.get_nodes(node_id))

        edges = []
        for edge_id in edge_ids:
            edges.extend(self.__graph_view.get_edges(edge_id))

        df = self.__graph_view.get_as_df().fillna("")

        return {
            "success": True,
            "title": title,
            "nodes": nodes,
            "edges": edges,
            "table": {"columns": [{"name": col} for col in df.columns], "rows": df.to_dict("records")},
            "cypher_query": cypher_query,
        }

    def __execute_query(self, query: str) -> DataFrame:
        """
        Executes a Cypher query against the Kuzu database and returns the result as a DataFrame.

        Args:
            query (str): The Cypher query string to execute.

        Returns:
            DataFrame: The result of the query as a pandas DataFrame.

        Raises:
            CypherQueryExecutionException: If the query execution fails or returns an unexpected result type.

        Notes:
            - The QueryResult is closed explicitly, inside the connection context, to avoid semaphore-related errors.
            - Only a single QueryResult is accepted; a list of results is rejected but still closed.
        """
        result: kuzu.QueryResult | List[kuzu.QueryResult] | None = None
        with self.__db_instance.get_new_conn() as conn:
            try:
                result = conn.connection.execute(query)
                if not isinstance(result, kuzu.QueryResult):
                    raise Exception("Unexpected kuzu result type.")
                df_result = result.get_as_df()
            except Exception as ex:
                logger.info(f"Failed to execute query {query}.")
                raise CypherQueryExecutionException(str(ex)) from ex
            finally:
                # Root cause still unknown: if the QueryResult is not closed explicitly, a
                # semaphore-related error occurs. It might be due to the order in which the
                # connection and the QueryResult are destroyed when they go out of scope.
                # `is not None` (not truthiness) so a result object is never skipped.
                if result is not None:
                    if isinstance(result, kuzu.QueryResult):
                        try:
                            result.close()
                        except Exception:
                            logger.exception("Failed to close the query result.")
                    else:
                        for query_result in result:
                            try:
                                query_result.close()
                            except Exception:
                                logger.exception("Failed to close one of the query result.")
        return df_result
