from __future__ import annotations

import json
import logging
import os
import shutil
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TypedDict

import kuzu
from dataiku import Folder
from diskcache import Index

from editor.backend.utils.store import get_or_create_kuzu_db_instance_access_diskcache_dir
from editor.backend.utils.webapp_config import webapp_config
from solutions.graph.dataiku.utils import S3Credentials
from solutions.graph.kuzu.connection_manager import KuzuConnectionContextManager

from .models import GraphDBReadInProgressError, GraphDBWriteInProgressError, GraphId

# Storage type names this module knows how to handle ("S3" and local "Filesystem").
# NOTE(review): not referenced in this module itself; presumably compared by callers against
# a managed folder's storage type to choose between the S3 and local instances — confirm.
SUPPORTED_STORAGES: tuple[str, ...] = ("S3", "Filesystem")


logger = logging.getLogger(__name__)


class _DbAccessState(TypedDict):
    """Lock bookkeeping for one database, stored in the diskcache access ``Index``.

    One entry exists per database, keyed by the instance's ``get_db_path()``.
    """

    # Path the entry describes (the same value used as the Index key when it is created).
    db_path: Path | str | None
    # Number of read-only instances currently open on this database.
    read_handles_count: int
    # True while a writable instance holds exclusive access.
    write_locked: bool


class AbstractDbInstance(ABC):
    """
    Reader/writer access control for a kuzu database, with the lock state kept in a diskcache ``Index``.

    Guarantees, as long as every user of the database goes through this class:
      - any number of read-only instances may be open on the same database at once;
      - a writable instance gets exclusive access (no readers, no other writer);
      - the state lives on disk and is updated inside ``Index.transact()``, so it is shared by
        every process that uses the same diskcache directory.

    Instances are meant to be used as context managers: entering takes the lock, exiting releases it.
    """

    def __init__(self, readonly: bool) -> None:
        """
        Args:
            readonly: True to take a shared read lock, False to take the exclusive write lock.
        """
        self.readonly = readonly

        cache_dir = get_or_create_kuzu_db_instance_access_diskcache_dir()
        logger.debug(f"Initializing kuzu db instance access index with diskcache at {cache_dir}.")
        self.kuzu_access_index = Index(cache_dir)

    def __enter__(self) -> AbstractDbInstance:
        """
        Acquire the read or write lock for this database.

        Raises:
            GraphDBWriteInProgressError: a writer currently holds the database.
            GraphDBReadInProgressError: a write lock was requested while readers are active.
        """
        self._open()
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback) -> None:
        """
        Release the lock taken in ``__enter__``; exceptions raised in the ``with`` body propagate.

        Raises:
            RuntimeError: no lock record exists, or a read lock is released while none is held.
        """
        self._close()

    def _open(self) -> None:
        key = self.get_db_path()
        with self.kuzu_access_index.transact():
            current: _DbAccessState | None = self.kuzu_access_index.get(key)

            if not current:
                # No record yet for this database: this is the first access.
                updated = _DbAccessState(
                    db_path=key,
                    read_handles_count=1 if self.readonly else 0,
                    write_locked=not self.readonly,
                )
            else:
                if current["write_locked"]:
                    # An active writer blocks every other access, read or write.
                    raise GraphDBWriteInProgressError()
                if not self.readonly and current["read_handles_count"] > 0:
                    # A writer needs exclusive access, so active readers block it.
                    raise GraphDBReadInProgressError()

                # Update a copy, then store it back below.
                updated = _DbAccessState(**current)
                if self.readonly:
                    updated["read_handles_count"] += 1
                else:
                    updated["write_locked"] = True

            self.kuzu_access_index[key] = updated
            logger.debug(f"Opening AbstractDbInstance, state is now {updated}.")

    def _close(self) -> None:
        key = self.get_db_path()
        with self.kuzu_access_index.transact():
            current: _DbAccessState | None = self.kuzu_access_index.get(key)
            if not current:
                # _open always stores a record before a handle is handed out.
                raise RuntimeError("Should not happen")

            updated = _DbAccessState(**current)
            if not self.readonly:
                updated["write_locked"] = False
            else:
                held = updated["read_handles_count"]
                if held == 0:
                    raise RuntimeError("Tried to release a read lock when none were held.")
                updated["read_handles_count"] = max(0, held - 1)

            self.kuzu_access_index[key] = updated
            logger.debug(f"Closing AbstractDbInstance, state is now {updated}.")

    @abstractmethod
    def get_db_path(self) -> Path:
        """Return the path identifying this database; it is also the key of its lock record."""
        ...

    @abstractmethod
    def create(self) -> None:
        """Create the database."""
        ...

    @abstractmethod
    def exists(self) -> bool:
        """Return whether the database exists."""
        ...

    @abstractmethod
    def get_new_conn(self, timeout_seconds: int | None = None) -> KuzuConnectionContextManager:
        """Return a connection context manager; ``timeout_seconds`` optionally bounds query execution."""
        ...

    @abstractmethod
    def destroy(self) -> None:
        """Delete the database."""
        ...


class LocalDbInstance(AbstractDbInstance):
    """Kuzu database stored at a path on the local filesystem."""

    def __init__(self, db_path: Path | str, readonly: bool) -> None:
        """
        Args:
            db_path: Location of the database; its parent directory is created if missing.
            readonly: Forwarded to the lock manager; read-only instances cannot create or destroy.
        """
        super().__init__(readonly)

        self.__db_path = Path(db_path)
        parent_dir = self.__db_path.parent
        # The existence check is kept on purpose: mkdir(exist_ok=True) would still raise if
        # the parent already exists as a regular file.
        if not parent_dir.exists():
            parent_dir.mkdir(parents=True, exist_ok=True)

    def get_db_path(self) -> Path:
        """Return the database path given at construction, as a ``Path``."""
        return self.__db_path

    def create(self) -> None:
        """
        Initialise a database at ``get_db_path()`` by opening and immediately closing it.

        Raises:
            RuntimeError: the instance is read-only.
        """
        if self.readonly:
            raise RuntimeError("Cannot create a readonly database instance.")

        database = kuzu.Database(self.get_db_path())
        database.close()

    def exists(self) -> bool:
        """Return whether anything exists at ``get_db_path()``."""
        return os.path.exists(self.get_db_path())

    def get_new_conn(self, timeout_seconds: int | None = None) -> KuzuConnectionContextManager:
        """
        Return a connection context manager for the local database.

        Args:
            timeout_seconds: When given, each connection runs ``CALL TIMEOUT`` with this value in ms.
        """
        # We create one database instance for each query because of this issue:
        # https://github.com/kuzudb/kuzu/issues/2934
        def create_connection(db: kuzu.Database) -> kuzu.Connection:
            connection = kuzu.Connection(db)
            if timeout_seconds is not None:
                connection.execute(
                    f"""
                        CALL TIMEOUT={timeout_seconds * 1000};
                    """
                )
            return connection

        db_path = self.get_db_path()
        return KuzuConnectionContextManager(db_path, self.readonly, create_connection)

    def destroy(self) -> None:
        """
        Remove ``get_db_path()`` with ``os.unlink`` (so it must be a file, not a directory).

        Raises:
            RuntimeError: the instance is read-only.
        """
        if self.readonly:
            raise RuntimeError("Cannot destroy a readonly database instance.")

        os.unlink(self.get_db_path())


class EditorWebAppDbInstance(LocalDbInstance):
    """Per-graph local database stored at ``<webapp_config.db_folder_path>/<graph_id>/db.kz``."""

    def __init__(self, graph_id: GraphId, readonly: bool) -> None:
        graph_dir = Path(webapp_config.db_folder_path) / graph_id
        super().__init__(graph_dir / "db.kz", readonly)


class S3RemoteDbInstance(AbstractDbInstance):
    """
    Kuzu database read from S3 through kuzu's ``httpfs`` extension.

    Each connection is made on an in-memory kuzu database that ATTACHes the remote database as ``vge``.
    Creating and destroying the remote database are not supported.
    """

    def __init__(self, credentials: S3Credentials, prefix: Path, readonly: bool) -> None:
        """
        Args:
            credentials: Bucket, root prefix, region and keys used to reach S3.
            prefix: Location of the database below ``credentials["root_prefix"]``.
            readonly: Forwarded to the lock manager.
        """
        super().__init__(readonly)

        self.__full_prefix = Path(os.path.join(credentials["root_prefix"], prefix))
        # Put exactly one "/" between bucket and key. Concatenating the raw prefix only worked when
        # it was absolute; a relative root_prefix produced "s3://bucketprefix/...".
        object_key = self.__full_prefix.as_posix().lstrip("/")
        self.__s3_url = f"s3://{credentials['bucket']}/{object_key}"
        self.__credentials = credentials
        logger.debug(f"Connecting to {self.__s3_url}.")

    def get_db_path(self) -> Path:
        """
        Return the full prefix of the database inside the bucket.

        NOTE(review): the bucket name is not part of this value, so the same prefix in two buckets
        shares one lock record in the access index.
        """
        return self.__full_prefix

    def create(self) -> None:
        """Not supported for S3 databases.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not implemented.")

    def exists(self) -> bool:
        """
        Return True if a connection context for the remote database can be built.

        A RuntimeError while building it is logged at debug level and reported as False.
        NOTE(review): whether this actually reaches S3 depends on when KuzuConnectionContextManager
        opens the database — confirm.
        """
        conn_context_manager: KuzuConnectionContextManager | None = None
        try:
            conn_context_manager = self.get_new_conn()
            return True
        except RuntimeError as ex:
            logger.debug(f"Failed to connect to DB on S3 {self.__s3_url}.", exc_info=ex)
            return False
        finally:
            if conn_context_manager:
                conn_context_manager.close()

    def get_new_conn(self, timeout_seconds: int | None = None) -> KuzuConnectionContextManager:
        """
        Return a connection context manager that attaches the S3 database as ``vge``.

        Args:
            timeout_seconds: When given, each connection runs ``CALL TIMEOUT`` with this value in ms.
        """

        def create_connection(db: kuzu.Database) -> kuzu.Connection:
            conn = kuzu.Connection(db)

            region = self.__credentials["region"]
            session_token = self.__credentials["session_token"]
            access_key_id = self.__credentials["access_key_id"]
            secret_access_key = self.__credentials["secret_access_key"]

            # NOTE(review): credential values are interpolated into the statement unescaped; a value
            # containing a single quote would break the query — confirm inputs can never contain one.
            conn.execute(
                f"""
                    INSTALL httpfs;
                    LOAD EXTENSION httpfs;
                    CALL HTTP_CACHE_FILE=TRUE;
                    CALL s3_region='{region}';
                    CALL s3_access_key_id='{access_key_id}';
                    CALL s3_secret_access_key='{secret_access_key}';
                    {f"CALL s3_session_token='{session_token}';" if session_token else ""}
                    ATTACH '{self.__s3_url}' AS vge (dbtype kuzu);
                    {f'CALL TIMEOUT={timeout_seconds * 1000};' if timeout_seconds is not None else ''}
                """
            )
            return conn

        # In-memory database cannot be opened as readonly: https://docs.kuzudb.com/concurrency/#faqs
        return KuzuConnectionContextManager(None, False, create_connection)

    def destroy(self) -> None:
        """Not supported for S3 databases.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Not implemented.")


class LocalReplicaDbInstance(LocalDbInstance):
    """
    Local copy of a kuzu database stored in a Dataiku managed folder.

    The database file and the ``buildInfo.json`` next to it are mirrored under the system temp
    directory; the database is downloaded again only when the remote ``buildInfo.json`` differs
    from the cached copy, or when the local database file is missing.
    """

    def __init__(self, db_path: Path | str, db_folder: Folder, readonly: bool) -> None:
        """
        Initialize a local replica of the database under a temporary directory.
        Downloads `buildInfo.json` and the DB file only when needed:
          - if files are missing locally, or
          - if the remote `buildInfo.json` differs from the local copy.

        Args:
            db_path: Path of the database file inside `db_folder`; `buildInfo.json` is expected
                in the same directory.
            db_folder: Managed folder holding the database.
            readonly: Forwarded to `LocalDbInstance`.
        """
        db_path = Path(db_path)
        subdir = db_path.parent  # remote directory holding the DB and buildInfo.json

        # Folder paths may be absolute ("/graphs/x/db.kz"). Joining an absolute path onto the temp
        # root would discard the temp root entirely, so mirror it below tmp without its anchor.
        local_subdir = Path(*subdir.parts[1:]) if subdir.anchor else subdir
        self.__temp_dir = (Path(tempfile.gettempdir()) / local_subdir).resolve()
        self.__temp_dir.mkdir(parents=True, exist_ok=True)

        build_info_name = "buildInfo.json"
        remote_build_info_path = subdir / build_info_name
        local_build_info_path = self.__temp_dir / build_info_name

        need_download = self._sync_build_info(local_build_info_path, remote_build_info_path, db_folder)

        local_db_file_path = self.__temp_dir / db_path.name
        if need_download or not local_db_file_path.exists():
            logger.info(f"Downloading {str(db_path)} to {str(local_db_file_path)}...")
            try:
                self._download_atomically(db_folder, str(db_path), local_db_file_path)
            except BaseException:
                # _sync_build_info has already stored the new buildInfo.json. Drop it so the next
                # attempt downloads again instead of treating the old DB file as up to date.
                try:
                    local_build_info_path.unlink()
                except OSError:
                    pass
                raise
        else:
            logger.info(f"Using cached DB at {str(local_db_file_path)} (up to date).")

        super().__init__(local_db_file_path, readonly)

    @staticmethod
    def _download_atomically(db_folder: Folder, remote_path: str, local_path: Path) -> None:
        """
        Stream `remote_path` from `db_folder` into `local_path` through a sibling temporary file.

        The file is only moved into place with ``os.replace`` once fully written, so `local_path`
        is never left truncated or half-written. On failure the temporary file is removed and
        `local_path` is left untouched.
        """
        part_path = local_path.with_name(f".{local_path.name}.{os.getpid()}.{os.urandom(4).hex()}.part")
        try:
            with db_folder.get_download_stream(remote_path) as remote_file, part_path.open("xb") as out:
                # Copy in chunks rather than reading the whole database into memory.
                shutil.copyfileobj(remote_file, out)
            os.replace(part_path, local_path)
        except BaseException:
            try:
                part_path.unlink()
            except OSError:
                pass
            raise

    @staticmethod
    def _safe_load_json(path: Path) -> dict | None:
        """Load JSON from `path`; return None if the file is missing, unreadable, or not valid UTF-8 JSON."""
        if not path.exists():
            return None
        try:
            with path.open("r", encoding="utf-8") as f:
                return json.load(f)  # type: ignore
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError (non-UTF-8 bytes).
            return None

    @staticmethod
    def _sync_build_info(local_path: Path, remote_path: Path, db_folder: Folder) -> bool:
        """
        Ensure local `buildInfo.json` matches remote.
        Returns True if the DB should be (re)downloaded (missing or changed).

        The remote bytes are always written to `local_path`; an unparsable remote or local file,
        or a failed local write, counts as "changed".
        """
        with db_folder.get_download_stream(str(remote_path)) as remote_file:
            remote_bytes = remote_file.read()
        try:
            remote_info = json.loads(remote_bytes)
            logger.debug(f"Folder epoc {remote_info}")
        except ValueError:
            # JSONDecodeError or UnicodeDecodeError: unknown content forces a download.
            remote_info = None

        local_info = LocalReplicaDbInstance._safe_load_json(local_path)
        logger.debug(f"Local epoc {local_info} {str(local_path)}")

        changed = (local_info is None) or (remote_info is None) or (remote_info != local_info)

        try:
            with local_path.open("wb") as f:
                f.write(remote_bytes)
        except OSError:
            changed = True

        return changed
