import os
from typing import Any, Dict, List, Optional
from uuid import uuid4

from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

from dataiku.base.utils import package_is_at_least
from dataiku.core.vector_stores.dku_vector_store import DkuLocalVectorStore, UpdateMethod, logger, DkuVectorStore
from dataiku.core.vector_stores.vector_store_document_filter import FAISSVectorStoreDocumentFilter
from dataiku.langchain.dataset_loader import VectorStoreLoader
from dataiku.llm.types import RetrievableKnowledge, BaseVectorStoreQuerySettings
from dataikuapi.dss.langchain import DKUEmbeddings


class FAISSVectorStore(DkuLocalVectorStore):
    """Local vector store backed by a langchain FAISS index persisted in ``exec_folder``.

    The store lives on disk as the two files written by ``FAISS.save_local``:
    ``index.faiss`` and ``index.pkl``. Every method in this class that may modify
    the store saves it back to ``exec_folder`` before returning it.
    """

    # File names produced by FAISS.save_local / read by FAISS.load_local (default index name "index").
    _INDEX_FAISS_FILE = "index.faiss"
    _INDEX_PKL_FILE = "index.pkl"
    _INDEX_FILES = (_INDEX_FAISS_FILE, _INDEX_PKL_FILE)

    def __init__(self, kb: RetrievableKnowledge, exec_folder: str):
        # FAISS doesn't have a concept of collections, hence the empty collection name.
        super(FAISSVectorStore, self).__init__(kb, exec_folder, collection_name="")
        self.document_filter = FAISSVectorStoreDocumentFilter(self.metadata_column_type_and_meaning)
        # Not-equal operator spelled the way langchain's FAISS metadata filter expects it.
        self.document_filter.NEQ_OPERATOR = "$neq"

    def get_db(self, embeddings: DKUEmbeddings, allow_creation: bool = False, **kwargs: Any) -> FAISS:
        """Load the FAISS store from ``exec_folder``, creating an empty one on disk first if missing.

        :param embeddings: embedding model attached to the loaded store.
        :param allow_creation: currently NOT honoured: an empty index is created whenever
            ``index.faiss`` is absent, regardless of this flag.
        :param kwargs: accepted for interface compatibility; unused.
        :return: the FAISS vector store read from ``exec_folder``.
        """
        # FAISS doesn't have a concept of collections, but the files must exist before load_local can read them.
        # TODO: should check allow_creation first; as written, a missing index is always created.
        if not os.path.exists(os.path.join(self.exec_folder, self._INDEX_FAISS_FILE)):
            # Build a store from a single placeholder document and delete it again, leaving an empty
            # index (presumably so FAISS can size the index from a real embedding), then persist it.
            placeholder_id = str(uuid4())
            empty_store = FAISS.from_documents([Document(page_content="Example document")], embeddings, ids=[placeholder_id])
            empty_store.delete(ids=[placeholder_id])
            empty_store.save_local(self.exec_folder)

        import langchain_community
        # load_local unpickles index.pkl. Newer langchain_community versions require an explicit opt-in
        # for that; it is acceptable only because these files are written by this class. SECURITY:
        # never point exec_folder at untrusted content.
        if package_is_at_least(langchain_community, "0.0.27"):
            return FAISS.load_local(self.exec_folder, embeddings, allow_dangerous_deserialization=True)
        return FAISS.load_local(self.exec_folder, embeddings)

    def _persist_faiss_store(self, vector_store: Any) -> FAISS:
        """Check that ``vector_store`` is a FAISS store, save it to ``exec_folder`` and return it."""
        assert isinstance(vector_store, FAISS), "Expected FAISS vector store"
        vector_store.save_local(self.exec_folder)
        logger.info("Saved FAISS vector store to disk")
        return vector_store

    def load_documents(self, documents_loader: VectorStoreLoader, embeddings: DKUEmbeddings, update_method: UpdateMethod = UpdateMethod.OVERWRITE) -> FAISS:
        """Load documents through the base implementation, then persist the resulting FAISS store.

        :return: the updated FAISS vector store, already saved to ``exec_folder``.
        """
        vector_store = super(FAISSVectorStore, self).load_documents(documents_loader, embeddings, update_method)
        return self._persist_faiss_store(vector_store)

    def load_documents_to_add(self, documents_loader: VectorStoreLoader, embeddings: DKUEmbeddings,
                              to_add_documents_uuids: Optional[List[str]] = None) -> Optional[FAISS]:
        """Add documents through the base implementation and persist the store if it changed.

        :return: the updated FAISS store (saved to ``exec_folder``), or ``None`` when the base
            implementation reports that the store was not modified (nothing is written then).
        """
        vector_store = super(FAISSVectorStore, self).load_documents_to_add(documents_loader, embeddings, to_add_documents_uuids)
        if vector_store is None:
            return None  # vector store not modified, nothing to persist
        return self._persist_faiss_store(vector_store)

    def load_documents_to_delete(self, embeddings: DKUEmbeddings,
                                 to_delete_documents_uuids: Optional[List[str]] = None) -> Optional[FAISS]:
        """Delete documents through the base implementation and persist the store if it changed.

        :return: the updated FAISS store (saved to ``exec_folder``), or ``None`` when the base
            implementation reports that the store was not modified (nothing is written then).
        """
        vector_store = super(FAISSVectorStore, self).load_documents_to_delete(embeddings, to_delete_documents_uuids)
        if vector_store is None:
            return None  # vector store not modified, nothing to persist
        return self._persist_faiss_store(vector_store)

    def clear_files(self, folder_path: str) -> None:
        """Remove the FAISS index files from ``folder_path``; files that are absent are skipped."""
        for file_name in self._INDEX_FILES:
            file_path = os.path.join(folder_path, file_name)
            if os.path.isfile(file_path):
                os.remove(file_path)

        logger.info("Cleared FAISS vector store files at {}".format(folder_path))

    def get_file_size(self) -> int:
        """Return the combined size in bytes of the FAISS index files in ``exec_folder``.

        A missing file contributes 0 bytes and logs a warning.
        """
        size = 0
        for file_name in self._INDEX_FILES:
            file_path = os.path.join(self.exec_folder, file_name)
            if os.path.isfile(file_path):
                size += os.path.getsize(file_path)
            else:
                logger.warning("File {} not found while calculating FAISS vector store size".format(file_path))

        return size

    @staticmethod
    def _get_search_kwargs(query_settings: BaseVectorStoreQuerySettings) -> Dict:
        """Return the base search kwargs with ``fetch_k`` forced to equal ``k``."""
        kwargs = DkuLocalVectorStore._get_search_kwargs(query_settings)
        # override the default (20) to allow retrieving more documents before applying post-filtering
        # see https://github.com/langchain-ai/langchain/blob/v0.1.16/libs/community/langchain_community/vectorstores/faiss.py#L290-L291
        # NOTE(review): for k < 20 this fetches *fewer* candidates than langchain's default before
        # filtering, and it overwrites any fetch_k the base may have set. If the intent is only to
        # lift the cap when k > 20, max(k, 20) may be what's wanted — confirm.
        kwargs["fetch_k"] = kwargs["k"]
        return kwargs
