import logging
import threading
import weakref
from typing import TYPE_CHECKING, Tuple, Dict, Optional, List

from dataiku.core.vector_stores.lifecycle.shared_folder import (
    load_into_shared_folder,
    remove_unused_versions_except_latest,
    VectorStoreSharedFolder
)

if TYPE_CHECKING:
    from langchain_core.vectorstores import VectorStore

# Type aliases for the tuples used as dictionary keys by the cache below.
LockKey = Tuple[str, str]        # (project_key, knowledge_bank_id)
CacheKey = Tuple[str, str, str]  # (project_key, knowledge_bank_id, version)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class LangchainVectorStoresCache:
    """
    Caches the Langchain Vectorstore instances, by
    - (project key, knowledge bank id, version)
    - vectorstore_kwargs

    Caching relies on the order in which the vector store kwargs are provided.

    Locking: ``_lock`` guards the two dictionaries held by this object, while
    one reentrant lock per (project key, knowledge bank id) is held for the
    whole duration of ``get_or_create`` and ``_try_release``.
    """

    def __init__(self):
        # protect the cache state
        self._lock = threading.RLock()
        self._kb_locks: Dict[LockKey, threading.RLock] = {}

        # (1) A single cache item holds 1 knowledge bank version, with its
        # shared folder. It is downloaded **once**.
        # (2) A cache item retains multiple Langchain VectorStore instances, one
        # instance per set of **vectorstore_kwargs.
        self._cache: Dict[CacheKey, _CacheItem] = {}

    def _kb_lock_for(self, project_key: str, knowledge_bank_id: str) -> threading.RLock:
        """
        Return the lock dedicated to the given knowledge bank, creating it
        the first time it is requested.
        """
        with self._lock:
            key = (project_key, knowledge_bank_id)
            if key not in self._kb_locks:
                # reentrancy is important here as finalizers may be called
                # during eviction (cf get_or_create)
                self._kb_locks[key] = threading.RLock()

            return self._kb_locks[key]

    def _get(self, key: CacheKey) -> Optional['_CacheItem']:
        """Return the cache item stored under ``key``, or None."""
        with self._lock:
            return self._cache.get(key)

    def _set(self, key: CacheKey, item: '_CacheItem') -> None:
        """Store ``item`` under ``key``, replacing any previous item."""
        with self._lock:
            self._cache[key] = item

    def _discard(self, key: CacheKey) -> None:
        """Remove the item stored under ``key``; no-op when there is none."""
        with self._lock:
            self._cache.pop(key, None)

    def _list_other_version_items(self, latest_key: CacheKey) -> List['_CacheItem']:
        """
        Return the cached items that belong to the same knowledge bank as
        ``latest_key`` but to a different version.

        The result is a new list, so callers can iterate over it while the
        cache itself gets modified.
        """
        (p_key, kb_id, v) = latest_key

        with self._lock:
            return [
                item for key, item in self._cache.items()
                if key[0] == p_key and key[1] == kb_id and key[2] != v
            ]

    def get_or_create(self, project_key: str, knowledge_bank_id: str, current_version: str,
                      vectorstore_kwargs) -> 'VectorStore':
        """
        Return a Langchain vector store for the given knowledge bank version.

        When the version is not in the cache yet, it is loaded into its shared
        folder, a new cache item is registered for it, and the cached items of
        all the other versions of this knowledge bank are invalidated.

        :param project_key: project key of the knowledge bank
        :param knowledge_bank_id: identifier of the knowledge bank
        :param current_version: version of the knowledge bank to serve
        :param vectorstore_kwargs: kwargs forwarded to the creation of the
            vector store; they also take part in the caching (see class doc)
        :return: the cached vector store, or a newly created one
        """
        key = (project_key, knowledge_bank_id, current_version)
        # locking on all versions for a given knowledge bank allows for
        # easier maintenance -- notably: evicting older versions
        with self._kb_lock_for(project_key, knowledge_bank_id):
            item = self._get(key)

            if item is None:
                shared_folder = load_into_shared_folder(
                    project_key, knowledge_bank_id, current_version)

                # try to clean previous unused cached versions, if any
                remove_unused_versions_except_latest(
                    project_key, knowledge_bank_id)

                item = _CacheItem(shared_folder)
                self._set(key, item)

                for outdated_item in self._list_other_version_items(key):
                    outdated_item.invalidate()

            # reuse cached vector store, or create a new one
            langchain_vs, created = item.get_or_create(vectorstore_kwargs)

            if created:
                # when this langchain_vs is garbage collected,
                # try to release the corresponding cache item
                weakref.finalize(langchain_vs, self._try_release, key)

            return langchain_vs

    def _try_release(self, key: CacheKey) -> None:
        """
        Finalizer callback registered by ``get_or_create``.

        Releases one usage of the cache item stored under ``key`` and drops
        that item from the cache once it reports being released.
        """
        (project_key, knowledge_bank_id, _) = key

        with self._kb_lock_for(project_key, knowledge_bank_id):
            item = self._get(key)

            # the item may already be gone from the cache
            if item is not None and item.try_release():
                self._discard(key)


class _CacheItem:
    """
    Holds one shared folder, along with the Langchain VectorStore instances
    created from it (one cached instance per hashable set of kwargs).

    ``_refcount`` is incremented each time ``get_or_create`` creates an
    instance, and decremented by each call to ``try_release``.
    """

    def __init__(self, shared_folder: 'VectorStoreSharedFolder'):
        self._shared_folder = shared_folder
        self._vector_stores: Dict[Tuple, 'VectorStore'] = {}
        self._refcount = 0

    def __repr__(self):
        return "CacheItem(refcount={}, folder_path={})".format(
            self._refcount, self._shared_folder.folder_path)

    def get_or_create(self, vectorstore_kwargs) -> Tuple['VectorStore', bool]:
        """
        Return the vector store matching the given kwargs.

        :param vectorstore_kwargs: kwargs forwarded to the shared folder when
            a vector store has to be created; None is treated as no kwargs
        :return: ``(vector_store, created)`` where ``created`` is False when
            the instance comes from the cache, True when it was just created
        """
        key = _make_key(vectorstore_kwargs)
        use_cache = True

        try:
            langchain_vs = self._vector_stores.get(key)
        except TypeError:
            # the key could not be hashed, most likely because there is a dict
            # in the kwargs --> do not cache in this case
            use_cache = False
        else:
            # compare against None rather than relying on truthiness: a
            # cached instance that evaluates to False must still be reused,
            # otherwise it would be re-created (and counted) on every call
            if langchain_vs is not None:
                return langchain_vs, False

        # `_make_key` accepts missing kwargs, so accept them here as well
        langchain_vs = self._shared_folder.create_langchain_vectorstore(
            **(vectorstore_kwargs or {}))
        self._refcount += 1

        if use_cache:
            self._vector_stores[key] = langchain_vs

        return langchain_vs, True

    def try_release(self):
        """
        Drop one usage of this item.

        Returns whether the item has been released: False while it is still
        in use, True once it is no longer used. When no longer used, attempt
        at removing the underlying storage too.

        :rtype: bool
        """
        self._refcount -= 1
        if self._refcount > 0:
            return False  # still in use

        self._shared_folder.remove_unless_used()
        return True

    def invalidate(self):
        """
        Mark this item as outdated.
        """
        self._shared_folder.set_outdated()
        # evicting all mappings to earlier versions allows to garbage
        # collect the corresponding vector store instances, thus
        # triggering their finalizer code
        self._vector_stores.clear()


# Builds the cache key for a set of kwargs: the (name, value) pairs are
# flattened, in their iteration order, into one single tuple. Missing or
# empty kwargs yield the empty tuple.
#
# The approach is inspired by the way Python creates keys for caching calls
# decorated with `@functools.cache`.
#
# For more details, please refer to
# https://github.com/python/cpython/blob/4f8bb3947cfbc20f970ff9d9531e1132a9e95396/Lib/functools.py#L487-L495
def _make_key(kwargs):
    if not kwargs:
        return tuple()

    return tuple(part for pair in kwargs.items() for part in pair)


# Module-level cache instance, created when this module is imported.
LANGCHAIN_VECTOR_STORES_CACHE = LangchainVectorStoresCache()
