import asyncio
import logging
import os
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional

from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import watch_stdin
from dataiku.core.knowledge_bank import KnowledgeBank
from dataiku.core.vector_stores.dku_vector_store import VectorStoreFactory
from dataiku.langchain.dku_embeddings import DKUEmbeddings
from dataiku.langchain.document_handler import RetrievalSource
from dataiku.langchain.sources_handler import SourcesHandler
from dataiku.llm.tracing import new_trace
from dataiku.llm.types import VectorStoreQueryToolParams, RetrievableKnowledge

logger = logging.getLogger("rag_query_server")


class VectorStoreQueryToolServer:
    """Serve vector-store search requests for a single knowledge bank.

    Driven by ``handler``: a ``start`` command loads the knowledge bank and builds
    its vector store; each subsequent ``run-tool`` command runs a retrieval and
    returns the matching documents together with a trace and sources.
    """

    #llm: DSSLLM
    #guardrails_llm_id: str
    embeddings: DKUEmbeddings
    #guardrails_embedding_model_id: str
    kb: RetrievableKnowledge
    retrieval_column: Optional[str]
    retrieval_source: RetrievalSource
    #embeddings_model_id: str
    #guardrails: List[Guardrail] = []
    #metrics_to_compute: List[str] = []

    def __init__(self):
        self.started = False
        self.retriever = None

        # While VSQTS exposes an async API, it is currently backed by a ThreadPoolExecutor under the hood, because
        # some of the underlying APIs are synchronous (e.g. DKUEmbeddings, DSSLLM). The nb. of threads in the pool can
        # be much higher than the number of cores because each thread spends most of the time waiting for I/O.
        self.executor = ThreadPoolExecutor(64)

        # The lock is used to serialize all accesses to the vector store. It is not clear if that is necessary, and
        # it may depend on the specific vector store implementation. This may have a performance impact
        self.lock = threading.Lock()

    def start(self, start_command: Dict) -> None:
        """Load the knowledge bank named in *start_command* and build its vector store.

        :param start_command: command dict; ``knowledgeBankFullId`` is read.
        :raises AssertionError: if the server has already been started.
        """
        assert not self.started, "Already started"

        kb_full_id = start_command["knowledgeBankFullId"]
        knowledge_bank = KnowledgeBank(kb_full_id)
        # the isolated folder will be cleaned up when the kernel dies
        isolated_folder = knowledge_bank.load_into_isolated_folder()
        self.kb = isolated_folder.get_kb_desc()

        self.retrieval_source = RetrievalSource.EMBEDDING
        # TODO is there a retrieval column ??

        #self.retrieval_source = RetrievalSource(self.rag_settings.get("retrievalSource", RetrievalSource.EMBEDDING))
        #self.retrieval_column = self.rag_settings.get("retrievalColumn") if self.retrieval_source == RetrievalSource.CUSTOM else None
        self.embeddings = DKUEmbeddings(llm_id=self.kb["embeddingLLMId"])

        self.dku_vector_store = VectorStoreFactory.get_vector_store(self.kb, isolated_folder.folder_path)
        self.started = True

    def process_tool_input(self, command: Dict) -> Dict:
        """Run one vector-store search and package the results.

        :param command: ``run-tool`` command. Reads ``searchQuery``, optional ``params``
            (query settings), optional ``filter`` and, when ``enforceDocumentLevelSecurity``
            is set in the params, ``securityTokens``.
        :returns: ``{"output": {"documents": [{"text", "metadata"}, ...]},
            "trace": <trace.to_dict()>, "sources": [{"items": [...]}]}``.
        :raises AssertionError: if ``start`` has not completed.
        """
        assert self.started, "Not started"

        trace = new_trace("DKU_VECTOR_STORE_SEARCH")
        # Using ``with`` guarantees the trace is exited (and sees the error) even when
        # retrieval raises; previously a failure left the trace entered forever.
        with trace:
            query = command["searchQuery"]
            # ``params`` may be absent or explicitly null
            query_settings: VectorStoreQueryToolParams = command.get("params") or {}
            sources_handler = SourcesHandler(
                query_settings.get("sourcesSettings", {}),
                full_folder_id=self.kb.get("managedFolderId"),
            )

            logger.info("Asking this question to the vector store: %s", query)
            # The full command can carry securityTokens: keep it out of INFO-level logs
            logger.debug("Full command: %s", command)

            docs = self._retrieve_documents(command, query, query_settings, trace)
            logger.info("Retrieved %s documents", len(docs))

            documents = []
            source_items = []
            for doc in docs:
                documents.append({
                    "text": doc.page_content,
                    "metadata": self._context_metadata(doc, query_settings),
                })
                source_items.append(sources_handler.build_role_based_source_from(doc))

        source = {
            # TODO - Generate in Java - "toolCallDescription": "Scraped the page %s" % url,

            "items": source_items,
        }

        return {"output": {"documents": documents}, "trace": trace.to_dict(), "sources": [source]}

    def _retrieve_documents(self, command: Dict, query, query_settings: Dict, trace) -> list:
        """Run the retrieval under ``self.lock`` and record retrieval details on *trace*."""
        with self.lock:
            additional_search_kwargs: Dict = {}

            if query_settings.get("enforceDocumentLevelSecurity"):
                self.dku_vector_store.add_security_filter(additional_search_kwargs, command["securityTokens"])

            if "filter" in command:
                self.dku_vector_store.add_filter(additional_search_kwargs, command["filter"])

            retriever = self.dku_vector_store.as_retriever(self.embeddings, query_settings, additional_search_kwargs)
            docs = retriever.invoke(query)

            trace.attributes["nbRetrievedDocuments"] = len(docs)

            # See comment in DKUEmbeddings about how to retrieve the trace from the last call.
            # Read under the lock so the embeddings trace belongs to the call made just above.
            last_trace = getattr(self.embeddings, "_last_trace", None)
            if hasattr(last_trace, "trace"):
                trace.append_trace(last_trace.trace)
            else:
                logger.info("No last trace found in DKUEmbeddings after rag document retrieval")

        return docs

    @staticmethod
    def _context_metadata(doc, query_settings: Dict) -> Dict:
        """Return the metadata of *doc* to expose, honouring the metadata-in-context settings."""
        if query_settings.get("allMetadataInContext", True):
            return doc.metadata
        # Only the requested keys; missing ones are reported as None
        return {m: doc.metadata.get(m, None) for m in query_settings.get("metadataInContext", [])}

    async def handler(self, command):
        """Dispatch one command from the Java link and yield its result.

        Commands run on ``self.executor`` so the synchronous vector-store and
        embedding calls do not block the event loop.

        :raises Exception: if the command type is unknown.
        """
        loop = asyncio.get_running_loop()
        if command["type"] == "start":
            logger.info("Received start command: %s", command)
            yield await loop.run_in_executor(self.executor, self.start, command)
        elif command["type"] == "run-tool":
            logger.debug("Received command: %s", command)
            yield await loop.run_in_executor(self.executor, self.process_tool_input, command)
        else:
            raise Exception("Unknown command type: %s" % command["type"])


def log_exception(loop, context):
    """Event-loop exception handler: log the error, its context and its stack trace.

    Installed with ``set_exception_handler`` on the server's event loop. When
    *context* holds no exception object, one is synthesized from ``context["message"]``.
    """
    error = context.get("exception")
    if error is None:
        error = Exception(context.get("message"))
    stack = "".join(traceback.format_exception(type(error), error, error.__traceback__))
    logger.error("Caught exception: %s\nContext: %s\nStack trace: %s", error, context, stack)


if __name__ == "__main__":
    LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()  # Set LOGLEVEL=DEBUG to debug
    logging.basicConfig(level=LOGLEVEL,
                        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s')

    # NOTE(review): presumably ties this process's lifetime to its stdin (parent process) — see dataiku.base.utils
    watch_stdin()

    async def start_server():
        """Connect to the Java backend and serve its commands with a VectorStoreQueryToolServer."""
        # Inside a coroutine, get_running_loop() is the explicit way to reach the current loop
        # (and matches VectorStoreQueryToolServer.handler); route unhandled loop errors to our logger.
        asyncio.get_running_loop().set_exception_handler(log_exception)

        port, secret, server_cert = parse_javalink_args()
        link = AsyncJavaLink(port, secret, server_cert=server_cert)
        server = VectorStoreQueryToolServer()

        await link.connect()
        await link.serve(server.handler)

    # NOTE(review): asyncio debug mode is always enabled here, regardless of LOGLEVEL; it adds runtime
    # overhead (slow-callback logging, extra checks) — consider tying it to LOGLEVEL == "DEBUG".
    asyncio.run(start_server(), debug=True)
