from __future__ import annotations

import hashlib
import logging
import os
import threading
from pathlib import Path
from typing import List, Optional

import kuzu
from dataiku import Folder
from dataiku.core.base import is_container_exec
from networkx import MultiDiGraph
from pandas import DataFrame

from solutions.graph.dataiku.utils import get_s3_credentials
from solutions.graph.graph_db_instance_manager import (
    SUPPORTED_STORAGES,
    AbstractDbInstance,
    LocalDbInstance,
    LocalReplicaDbInstance,
    S3RemoteDbInstance,
)
from solutions.graph.kuzu.query_result_extensions import get_as_networkx
from solutions.graph.models import EdgeGroupDefinition, ExplorerMetadata, GraphId, NodeGroupDefinition

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Folder-relative directory under which each saved configuration has its own sub-directory.
PATH_TO_SNAPSHOT_JSON = "built-graphs/"
# JSON file, inside a saved configuration sub-directory, read to rebuild the ExplorerMetadata
# (id, name, comment, nodes, edges, views, cypher queries).
SNAPSHOT_JSON_FILE_NAME = "configuration.json"
# JSON file, inside a saved configuration sub-directory, read for its "epoch_ms" field.
BUILD_INFO_JSON_FILE_NAME = "buildInfo.json"

# Lock shared by all threads of this process; ExplorerDbInstanceFactory.get_db_instance holds it
# while it creates a LocalReplicaDbInstance.
EXPLORER_DB_DOWNLOAD_LOCK = threading.Lock()


class DbInstanceFactory:
    """
    Creates read-only database instances for databases stored in one of the configured folders.

    The concrete instance class is chosen from the folder type and from whether the code is
    running inside a container (see `get_db_instance`).
    """

    def __init__(self, db_folders: List[Folder]):
        """
        Args:
            db_folders (List[Folder]): The folders searched when a database is looked up by saved configuration ID.
        """
        self.db_folders = db_folders

    @staticmethod
    def _resolve_db_path(db_folder: Folder, *relative_parts) -> Path:
        """
        Builds the path handed to the database instance for a database stored in `db_folder`.

        For a "Filesystem" folder used outside of a container, the parts are joined to the folder's
        path on disk. In every other case the parts are joined on their own, so the result stays
        relative to the folder root.

        Args:
            db_folder (Folder): The folder holding the database.
            *relative_parts: Path components, relative to the folder root, leading to the database.

        Returns:
            Path: The resolved path.
        """
        db_folder_type = db_folder.get_info(sensitive_info=True)["type"]
        if db_folder_type == "Filesystem" and not is_container_exec():
            return Path(os.path.join(db_folder.get_path(), *relative_parts))
        return Path(os.path.join(*relative_parts))

    def get_db_instance_from_db_path(self, db_folder: Folder, relative_path_to_db: str) -> AbstractDbInstance:
        """
        Returns a database instance corresponding to the given folder and relative database path.

        Depending on the type of the provided `db_folder`, this method will instantiate and return
        either a `LocalDbInstance` (for filesystem folders) or an `S3RemoteDbInstance` (for S3 folders)
        or a `LocalReplicaDbInstance` (for container based processes).
        The database instance is always opened in read-only mode.

        Args:
            db_folder (Folder): The folder object representing the root location of the database.
            relative_path_to_db (str): The relative path to the database file within the folder.

        Returns:
            AbstractDbInstance: An instance of the appropriate database class.

        Raises:
            Exception: If the folder type is not one of `SUPPORTED_STORAGES`.
        """
        path_to_db = self._resolve_db_path(db_folder, relative_path_to_db)
        return self.get_db_instance(path_to_db, db_folder)

    def get_db_folder_by_snapshot_id(self, snapshot_id: str) -> Folder:
        """
        Retrieves the first database folder found corresponding to the given snapshot ID.
        In case there are other folders matching, log a warning.

        A folder matches when its listed paths contain exactly "/built-graphs/<snapshot_id>/db.kz".

        Args:
            snapshot_id (str): The identifier of the saved configuration for which to retrieve the database folder.

        Returns:
            Folder: The folder containing the database files for the specified snapshot ID.

        Raises:
            Exception: If none of the configured folders contains the saved configuration.
        """
        expected_db_path = f"/built-graphs/{snapshot_id}/db.kz"
        candidate_folder = None
        for folder in self.db_folders:
            if expected_db_path not in folder.list_paths_in_partition():
                continue
            # Compare against None rather than relying on the truthiness of a folder object.
            if candidate_folder is None:
                candidate_folder = folder
            else:
                logger.warning(
                    "Multiple folders were found containing same saved configuration id. The first configuration found will be used."
                )

        if candidate_folder is None:
            raise Exception("No folder contains the given saved configuration id")
        return candidate_folder

    def get_db_instance_from_snapshot_id(self, snapshot_id: str) -> AbstractDbInstance:
        """
        Returns a database instance corresponding to the given saved configuration ID and the type of the configured database folder.

        The folder is looked up with `get_db_folder_by_snapshot_id`, the database is expected at
        "built-graphs/<snapshot_id>/db.kz" inside that folder, and the instance class is chosen by
        `get_db_instance`.

        Args:
            snapshot_id (str): The identifier of the saved configuration for which to retrieve the database instance.

        Returns:
            AbstractDbInstance: An instance of the database corresponding to the saved configuration.

        Raises:
            Exception: If no folder contains the saved configuration, or if the folder type is not
                one of `SUPPORTED_STORAGES`.
        """
        db_folder_containing_snapshot = self.get_db_folder_by_snapshot_id(snapshot_id)
        relative_path_to_db = Path(f"built-graphs/{snapshot_id}")
        # Log through the module logger (not the root logger) so the record carries this module's name.
        logger.info(f"Retrieve db from path : {relative_path_to_db}")

        path_to_db = self._resolve_db_path(db_folder_containing_snapshot, relative_path_to_db, "db.kz")
        return self.get_db_instance(path_to_db, db_folder_containing_snapshot)

    def get_db_instance(self, path_to_db: Path, db_folder: Folder) -> AbstractDbInstance:
        """
        Returns a database instance based on the type of the provided folder.

        - "Filesystem" folder outside of a container: `LocalDbInstance`.
        - "S3" folder: `S3RemoteDbInstance`, built with the credentials of the folder.
        - Anything else that is supported: `LocalReplicaDbInstance`.

        Args:
            path_to_db (Path): The path to the database.
            db_folder (Folder): The folder object representing the database storage location.

        Returns:
            AbstractDbInstance: An instance of the appropriate database class (LocalDbInstance, S3RemoteDbInstance, or LocalReplicaDbInstance).

        Raises:
            Exception: If the folder type is not one of `SUPPORTED_STORAGES`.
        """
        db_folder_type = db_folder.get_info(sensitive_info=True)["type"]
        if db_folder_type not in SUPPORTED_STORAGES:
            raise Exception(
                f"Output folder {db_folder_type} not supported. It should be {', '.join(SUPPORTED_STORAGES)}."
            )

        if db_folder_type == "Filesystem" and not is_container_exec():
            return LocalDbInstance(path_to_db, readonly=True)
        elif db_folder_type == "S3":
            return S3RemoteDbInstance(get_s3_credentials(db_folder), path_to_db, readonly=True)
        else:
            return LocalReplicaDbInstance(path_to_db, db_folder, readonly=True)


class ExplorerDbInstanceFactory(DbInstanceFactory):
    """
    Database instance factory dedicated to the explorer.

    Unlike its parent, this factory never returns an S3RemoteDbInstance (avoided because of known
    concurrency issues in the explorer context): every supported folder that is not a local
    filesystem folder is served through a LocalReplicaDbInstance.
    """

    def get_db_instance(self, path_to_db: Path, db_folder: Folder) -> AbstractDbInstance:
        """
        Returns a read-only database instance suited to the explorer.

        Args:
            path_to_db (Path): The path to the database.
            db_folder (Folder): The folder object representing the database storage location.

        Returns:
            AbstractDbInstance: A LocalDbInstance for a "Filesystem" folder used outside of a container,
            a LocalReplicaDbInstance otherwise.

        Raises:
            Exception: If the folder type is not one of `SUPPORTED_STORAGES`.
        """
        storage_type = db_folder.get_info(sensitive_info=True)["type"]
        if storage_type not in SUPPORTED_STORAGES:
            supported = ", ".join(SUPPORTED_STORAGES)
            raise Exception(f"Output folder {storage_type} not supported. It should be {supported}.")

        use_local_files = storage_type == "Filesystem" and not is_container_exec()
        if use_local_files:
            return LocalDbInstance(path_to_db, readonly=True)

        # The first time a graph is visited, several concurrent requests may arrive before any
        # replica has been downloaded. The first request triggers the download; without
        # serialization, the following ones used to read a malformed or missing database and
        # returned empty results. Holding the lock makes them wait for the replica to be created.
        # This only protects threads of a single process, not a multi-process deployment, which
        # was accepted for the v1.0.0 release of Visual Graph.
        with EXPLORER_DB_DOWNLOAD_LOCK:
            return LocalReplicaDbInstance(path_to_db, db_folder, readonly=True)


class ExplorerMetadataBuilder:
    """
    Rebuilds an ExplorerMetadata by inspecting the content of a database.

    Used when the metadata cannot be read from the saved configuration files: the node and edge
    definitions are then derived from the tables of the database and from one sample element per
    group definition.
    """

    def __init__(self, db_instance: AbstractDbInstance, snapshot_id: str):
        """
        Args:
            db_instance (AbstractDbInstance): The database instance that provides the connections used for the queries.
            snapshot_id (str): The identifier of the saved configuration, copied into the built metadata.
        """
        self.db_instance = db_instance
        self.snapshot_id = snapshot_id

    def build(self) -> ExplorerMetadata:
        """
        Builds and returns an ExplorerMetadata object by extracting node and edge definitions
        from the database.

        Lists the tables with `CALL SHOW_TABLES()`, then processes each returned row (one per table)
        to populate the node and edge group definitions. Fields that cannot be derived from the
        database get placeholder values (name "Unknown", empty comment, epoch 0, no cypher queries).

        Returns:
            ExplorerMetadata: The constructed metadata object containing graph schema details.

        Raises:
            Exception: If a query fails or if its result is not of the expected type.
        """
        node_definitions: List[NodeGroupDefinition] = []
        edge_definitions: List[EdgeGroupDefinition] = []

        tables_df = self.__execute_kuzu_query_df("CALL SHOW_TABLES() RETURN *;")
        for table in tables_df.to_dict(orient="records"):
            self._process_col(table, node_definitions, edge_definitions)

        return ExplorerMetadata(
            snapshot_id=self.snapshot_id,
            name="Unknown",
            comment="",
            epoch_ms=0,
            node_definitions=node_definitions,
            edge_definitions=edge_definitions,
            nodes_view=self.get_nodes_view(node_definitions),
            edges_view=self.get_edges_view(edge_definitions),
            cypher_queries=[],
        )

    def get_nodes_view(self, node_definitions: list) -> dict:
        """
        Generates a view dictionary for nodes based on provided node definitions.

        Args:
            node_definitions (list): A list of mappings, each containing at least a "node_id" key.

        Returns:
            dict: A dictionary where each key is a node ID and each value is a dictionary with the following keys:
                - "color": A color string generated from the node ID using ExplorerColorizer.
                - "size": The string "normal" (default size).
                - "icon": An empty string (placeholder for icon).
        """
        return {
            node["node_id"]: {
                "color": ExplorerColorizer.string_to_beautiful_color(node["node_id"]),
                "size": "normal",
                "icon": "",
            }
            for node in node_definitions
        }

    def get_edges_view(self, edge_definitions: list) -> dict:
        """
        Generates a view dictionary for edges based on provided edge definitions.

        Args:
            edge_definitions (list): A list of mappings, each containing an "edge_id" key.

        Returns:
            dict: A dictionary where each key is an edge ID and the value is a dictionary with a "size" attribute set to 1.
        """
        return {edge["edge_id"]: {"size": 1} for edge in edge_definitions}

    def _process_col(
        self, col: dict, node_definitions: List[NodeGroupDefinition], edge_definitions: List[EdgeGroupDefinition]
    ):
        """
        Processes one table description to extract node or edge group definitions from the graph database.

        A table whose type is "NODE" is handled as a node table; any other type is handled as a
        relationship table. For each distinct `_dku_grp_def_id` found in the table, one sample
        element is fetched, the result is converted to a NetworkX graph, and the corresponding
        group definition is appended to the provided lists.

        Args:
            col (dict): One row returned by `CALL SHOW_TABLES()`, expected to contain at least "name" and "type" keys.
            node_definitions (List[NodeGroupDefinition]): List to append discovered NodeGroupDefinition objects to.
            edge_definitions (List[EdgeGroupDefinition]): List to append discovered EdgeGroupDefinition objects to.

        Raises:
            Exception: If the query fails or if its result is not of the expected kuzu.QueryResult type.

        Side Effects:
            Modifies the node_definitions or edge_definitions lists in place by appending new group definitions.
        """
        name = col["name"]

        if col["type"] == "NODE":
            # Group the nodes of the table by _dku_grp_def_id and keep one sample node per group:
            # each group is a node definition.
            query = f"MATCH (o:`{name}`) WITH o._dku_grp_def_id AS groupId, collect(o) AS objs RETURN objs[1] AS obj;"
            graph = self.__execute_kuzu_query_graph(query)
            for node, data in graph.nodes(data=True):
                # NOTE(review): the graph node key presumably has the form "<node_id>~<...>" - confirm
                # against get_as_networkx.
                node_id = node.split("~")[0]
                node_definitions.append(
                    NodeGroupDefinition(
                        definition_id=data["_dku_grp_def_id"],
                        node_group=data["_label"],
                        node_id=node_id,
                        source_dataset="",
                        primary_col=data["_dku_reserved_user_defined_pk_prop"],
                        label_col=data["_dku_reserved_user_defined_pk_prop"],
                        # Properties prefixed with "_" are treated as internal and not exposed.
                        property_list=[k for k in data if not k.startswith("_")],
                        filters_association="and",
                        filters_stored=[],
                    )
                )
        else:
            # Group the edges of the table by _dku_grp_def_id and keep one sample
            # (source, edge, destination) triple per group.
            query = f"MATCH (o)-[e:`{name}`]->(d) WITH e._dku_grp_def_id AS groupId, collect({{source: o, rel: e, destination: d}}) AS paths WITH paths[1] AS firstPath RETURN firstPath.source AS o, firstPath.rel AS e, firstPath.destination AS d;"
            graph = self.__execute_kuzu_query_graph(query)
            for source, target, data in graph.edges(data=True):
                source_data = graph.nodes[source]
                target_data = graph.nodes[target]
                edge_definitions.append(
                    EdgeGroupDefinition(
                        definition_id=data["_dku_grp_def_id"],
                        edge_dataset="",
                        edge_group=data["_label"],
                        edge_id=data["_dku_grp_id"],
                        source_node_id=source.split("~")[0],
                        target_node_id=target.split("~")[0],
                        source_column=source_data["_dku_reserved_user_defined_pk_prop"],
                        target_column=target_data["_dku_reserved_user_defined_pk_prop"],
                        source_node_group=source_data["_label"],
                        target_node_group=target_data["_label"],
                        property_list=[k for k in data if not k.startswith("_")],
                        filters_association="and",
                        filters_stored=[],
                    )
                )

    def __execute_kuzu_query_graph(self, query: str) -> MultiDiGraph:
        """Executes `query` and returns its result converted to a NetworkX graph."""
        return self.__execute_kuzu_query(query, get_as_networkx)

    def __execute_kuzu_query_df(self, query: str) -> DataFrame:
        """Executes `query` and returns its result as a pandas DataFrame."""
        return self.__execute_kuzu_query(query, lambda query_result: query_result.get_as_df())

    def __execute_kuzu_query(self, query: str, extract):
        """
        Executes `query` on a new connection and returns `extract(result)`.

        Args:
            query (str): The query to execute.
            extract: Callable receiving the kuzu.QueryResult and returning the value to hand back.
                It is called while both the connection and the result are still open.

        Returns:
            Whatever `extract` returns.

        Raises:
            Exception: If the execution or the extraction fails, or if the result is not a single
                kuzu.QueryResult. The raised exception carries the message of the original error,
                which is attached as its cause.
        """
        result: kuzu.QueryResult | List[kuzu.QueryResult] | None = None
        with self.db_instance.get_new_conn() as conn:
            try:
                result = conn.connection.execute(query)
                if not isinstance(result, kuzu.QueryResult):
                    raise Exception("Unexpected kuzu result type.")
                extracted = extract(result)
            except Exception as ex:
                logger.info(f"Failed to execute query {query}.")
                # Same exception type and message as before; "from ex" keeps the original error
                # (type and traceback) attached as the explicit cause.
                raise Exception(str(ex)) from ex
            finally:
                # Still unsure of the origin: if the QueryResult is not explicitly closed here,
                # it fails with an error regarding semaphores. Might be due to the order in which
                # the connection and the QueryResult are closed when the variables holding them
                # go out of scope. Closing here guarantees the result is closed before the connection.
                self.__close_query_results(result)
        return extracted

    @staticmethod
    def __close_query_results(result) -> None:
        """Closes a query result, or every result of a list; failures are logged and never raised."""
        if result is None:
            return

        if isinstance(result, kuzu.QueryResult):
            query_results = [result]
            failure_message = "Failed to close the query result."
        else:
            query_results = result
            failure_message = "Failed to close one of the query result."

        for query_result in query_results:
            try:
                query_result.close()
            except Exception:
                logger.exception(failure_message)


class ExplorerColorizer:
    """Deterministically assigns a color from a fixed palette to a string."""

    # Palette the colors are picked from; the order matters since colors are selected by index.
    _PALETTE = (
        "#1f77b4",  # muted blue
        "#ff7f0e",  # safety orange
        "#2ca02c",  # cooked asparagus green
        "#d62728",  # brick red
        "#9467bd",  # muted purple
        "#8c564b",  # chestnut brown
        "#e377c2",  # raspberry yogurt pink
        "#7f7f7f",  # middle gray
        "#bcbd22",  # curry yellow-green
        "#17becf",  # blue-teal
    )

    @staticmethod
    def hash_to_index(input_string: str, num_colors: int) -> int:
        """
        Maps a string to an index by reducing its MD5 digest modulo `num_colors`.

        Args:
            input_string (str): The string to be hashed.
            num_colors (int): The upper bound (exclusive) for the index range.

        Returns:
            int: An integer index in the range [0, num_colors), determined by the hash of the input string.
        """
        digest = hashlib.md5(input_string.encode()).digest()
        # Reading the digest as a big-endian integer gives the same number as parsing its hex form.
        return int.from_bytes(digest, "big") % num_colors

    @staticmethod
    def string_to_beautiful_color(input_string: str) -> str:
        """
        Picks the palette color associated with a string.

        The same string always yields the same color, which keeps the coloring of items
        (such as nodes) consistent across visualizations.

        Args:
            input_string (str): The string to be mapped to a color.

        Returns:
            str: A hex color code selected from the palette.
        """
        palette = ExplorerColorizer._PALETTE
        return palette[ExplorerColorizer.hash_to_index(input_string, len(palette))]


class ExplorerMetadataManager:
    """
    Gives access to the metadata of the saved configurations stored in a list of folders.

    The metadata is read from the JSON files of the saved configuration when possible and is
    rebuilt from the database content otherwise.
    """

    def __init__(self, db_folders: List[Folder]):
        """
        Args:
            db_folders (List[Folder]): The folders scanned for saved configurations.
        """
        self.db_folders = db_folders
        self.db_instance_factory = ExplorerDbInstanceFactory(db_folders)

    def get_snapshot_metadata(self, snapshot_id: str) -> ExplorerMetadata:
        """
        Retrieves the metadata for a given saved configuration.

        If the metadata can be read from the JSON files of the saved configuration, it is returned directly.
        Otherwise, a new database instance is created for the saved configuration, and the metadata is built using
        the ExplorerMetadataBuilder.

        Args:
            snapshot_id (str): The unique identifier of the saved configuration for which metadata is requested.

        Returns:
            ExplorerMetadata: The metadata associated with the specified saved configuration.

        Raises:
            Exception: If no folder contains the saved configuration.
        """

        if (graph_metadata := self._get_snapshot_data(snapshot_id)) is None:
            db_instance = self.db_instance_factory.get_db_instance_from_snapshot_id(snapshot_id)
            graph_metadata = ExplorerMetadataBuilder(db_instance, snapshot_id).build()

        return graph_metadata

    def get_snapshots_metadata(self) -> List[ExplorerMetadata]:
        """
        Retrieves metadata for all available saved configurations in the database folders.

        This method scans every folder for paths leading to a "db.kz" database, extracts the unique
        saved configuration IDs from those paths, and then gathers metadata for each identified
        saved configuration. The IDs are collected in a set, so the order of the returned list is
        not defined.

        Returns:
            List[ExplorerMetadata]: A list of metadata objects corresponding to each discovered saved configuration.
        """
        snapshot_ids = set()
        for folder in self.db_folders:
            snapshot_ids.update(self._get_snapshot_id_list(folder))

        return [self.get_snapshot_metadata(snapshot_id) for snapshot_id in snapshot_ids]

    def _get_snapshot_data(self, snapshot_id: str) -> Optional[ExplorerMetadata]:
        """
        Retrieves and parses saved configuration metadata from the JSON files of a given saved configuration ID.

        The method reads the configuration file and the build information file of the saved
        configuration, and extracts the metadata fields: saved configuration ID, name, comment,
        epoch timestamp, node and edge definitions, view configurations, and cypher queries.
        If any error occurs during reading or parsing, the error is logged and the method returns None.

        The lookup of the folder holding the saved configuration happens before that error
        handling, so its exception is propagated to the caller.

        Args:
            snapshot_id (str): The unique identifier of the saved configuration to retrieve.

        Returns:
            Optional[ExplorerMetadata]: An ExplorerMetadata object containing the parsed saved configuration data
            if successful, or None if the data could not be retrieved or parsed.

        Raises:
            Exception: If no folder contains the saved configuration.
        """
        snapshot_dir = Path(PATH_TO_SNAPSHOT_JSON) / snapshot_id
        configuration_path = snapshot_dir / SNAPSHOT_JSON_FILE_NAME
        build_info_path = snapshot_dir / BUILD_INFO_JSON_FILE_NAME
        db_folder_containing_snapshot = self.db_instance_factory.get_db_folder_by_snapshot_id(snapshot_id)
        try:
            snapshot_data = db_folder_containing_snapshot.read_json(configuration_path)
            build_info_data = db_folder_containing_snapshot.read_json(build_info_path)
            return ExplorerMetadata(
                snapshot_id=snapshot_data["id"],
                name=snapshot_data.get("name", ""),
                comment=snapshot_data.get("comment", ""),
                epoch_ms=build_info_data.get("epoch_ms", 0),
                # One NodeGroupDefinition per (node, definition) pair.
                node_definitions=[
                    NodeGroupDefinition(
                        node_group=node["node_group"],
                        node_id=node["node_id"],
                        definition_id=node_def["definition_id"],
                        source_dataset=node_def["source_dataset"],
                        primary_col=node_def["primary_col"],
                        label_col=node_def["label_col"],
                        property_list=node_def["property_list"],
                        filters_association=node_def["filters_association"],
                        filters_stored=node_def["filters_stored"],
                    )
                    for node in snapshot_data.get("nodes", {}).values()
                    for node_def in node.get("definitions", [])
                ],
                # One EdgeGroupDefinition per (edge, definition) pair; the node groups of both ends
                # are resolved through the "nodes" section of the configuration.
                edge_definitions=[
                    EdgeGroupDefinition(
                        source_node_group=snapshot_data["nodes"][edge["source_node_id"]]["node_group"],
                        target_node_group=snapshot_data["nodes"][edge["target_node_id"]]["node_group"],
                        source_node_id=edge["source_node_id"],
                        target_node_id=edge["target_node_id"],
                        edge_group=edge["edge_group"],
                        edge_id=edge["edge_id"],
                        definition_id=edge_def["definition_id"],
                        edge_dataset=edge_def["edge_dataset"],
                        source_column=edge_def["source_column"],
                        target_column=edge_def["target_column"],
                        property_list=edge_def["property_list"],
                        filters_association=edge_def["filters_association"],
                        filters_stored=edge_def["filters_stored"],
                    )
                    for edge in snapshot_data.get("edges", {}).values()
                    for edge_def in edge.get("definitions", [])
                ],
                nodes_view=snapshot_data.get("nodes_view", {}),
                edges_view=snapshot_data.get("edges_view", {}),
                cypher_queries=snapshot_data.get("cypher_queries", []),
            )
        except Exception as e:
            # Deliberate best effort: the caller falls back to rebuilding the metadata from the database.
            logger.error(f"Error reading saved configuration metadata file {str(e)}")
            return None

    @staticmethod
    def _get_snapshot_id_list(folder: Folder) -> set[GraphId]:
        """
        Lists the saved configuration IDs found in a folder.

        Only paths having a component exactly equal to "db.kz" are considered. A path that merely
        contains that text (for instance a sidecar file such as "db.kz.wal") is ignored instead of
        making the whole listing fail.

        Args:
            folder (Folder): The folder to scan.

        Returns:
            set[GraphId]: The IDs of the saved configurations stored in the folder.
        """
        return {
            ExplorerMetadataManager._get_snapshot_id_in_path(path)
            for path in folder.list_paths_in_partition()
            if "db.kz" in path.split("/")
        }

    @staticmethod
    def _get_snapshot_id_in_path(path: str) -> str:
        """
        Extracts the saved configuration ID from a given path string.

        The path is split on '/' and the component located right before the first "db.kz"
        component is returned. When "db.kz" is the very first component, an empty string is returned.

        Args:
            path (str): The file path from which to extract the saved configuration ID.

        Returns:
            str: The saved configuration ID found right before the "db.kz" component.

        Raises:
            Exception: If no component of the path is equal to "db.kz".
        """
        parts = path.split("/")
        try:
            db_index = parts.index("db.kz")
        except ValueError:
            raise Exception("Missing Kuzu database file (db.kz).") from None
        return parts[db_index - 1] if db_index > 0 else ""
