import asyncio
import logging
import os
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import Dict

import dataiku
from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import watch_stdin
from dataiku.core.knowledge_bank import KnowledgeBank
from dataiku.core.vector_stores.dku_vector_store import VectorStoreFactory, DkuVectorStore
from dataiku.langchain.base_rag_handler import DKU_TEXT_EMBEDDING_COLUMN, AugmentationFallbackStrategy
from dataiku.langchain.content_part_types import ImageRetrieval, ImageRefPart, CaptionedImageRefPart
from dataiku.langchain.dku_embeddings import DKUEmbeddings
from dataiku.langchain.document_handler import RetrievalSource
from dataiku.langchain.metadata_generator import DKU_MULTIMODAL_CONTENT, DKU_DOCUMENT_INFO
from dataiku.langchain.multimodal_content import from_doc
from dataiku.langchain.sources_handler import SourcesHandler
from dataiku.llm.rag.reranking import Reranking, DocumentWithScores, ScoreType
from dataiku.llm.tracing import new_trace
from dataiku.llm.types import VectorStoreQueryToolParams, RetrievableKnowledge, TrustedObject, RAGRerankingSettings

logger = logging.getLogger("rag_query_server")

class VectorStoreQueryToolServer:
    """Serve vector store search requests on a knowledge bank for an agent tool.

    Commands are dispatched by `handler`:

    * ``"start"`` loads the knowledge bank into an isolated folder and builds the
      embeddings and vector store objects (see `start`);
    * ``"run-tool"`` runs one search and returns the documents, output parts,
      trace and sources (see `process_tool_input`).
    """
    # Embedding model wrapper built from the knowledge bank's "embeddingLLMId"
    embeddings: DKUEmbeddings
    # Knowledge bank description obtained from the isolated folder
    kb: RetrievableKnowledge
    retrieval_source: RetrievalSource
    # Reference of the agent tool, taken from the start command
    toolRef: str
    dku_vector_store: DkuVectorStore
    # 1-based counter of "run-tool" commands, only used in log messages
    run_counter: int = 1
    # Identifies the agent tool when loading the knowledge bank and resolving sources
    trusted_object: TrustedObject

    def __init__(self):
        """Create an unstarted server; `start` must be called before any search."""
        self.started = False
        self.retriever = None
        self.run_counter = 1

        # While VSQTS exposes an async API, it is currently backed by a ThreadPoolExecutor under the hood, because
        # some of the underlying APIs are synchronous (e.g. DKUEmbeddings, DSSLLM). The nb. of threads in the pool
        # can be much higher than the number of cores because each thread spends most of the time waiting for I/O.
        self.executor = ThreadPoolExecutor(64)

        # The lock is used to serialize all accesses to the vector store. It is not clear if that is necessary, and
        # it may depend on the specific vector store implementation. This may have a performance impact
        self.lock = threading.Lock()

    def start(self, start_command):
        """Load the knowledge bank and prepare the embeddings and the vector store.

        :param start_command: dict with the keys ``"knowledgeBankFullId"`` and ``"toolRef"``
        :raises AssertionError: if the server has already been started
        """
        assert not self.started, "Already started"

        kb_full_id = start_command["knowledgeBankFullId"]
        self.toolRef = start_command["toolRef"]
        knowledge_bank = KnowledgeBank(kb_full_id, context_project_key=dataiku.default_project_key())
        # the isolated folder will be cleaned up when the kernel dies
        self.trusted_object: TrustedObject = {"smartRef": self.toolRef, "type": "AGENT_TOOL"}
        isolated_folder = knowledge_bank.load_into_isolated_folder(trusted_object=self.trusted_object)
        self.kb = isolated_folder.get_kb_desc()

        self.retrieval_source = RetrievalSource.EMBEDDING
        # TODO is there a retrieval column ??

        #self.retrieval_source = RetrievalSource(self.rag_settings.get("retrievalSource", RetrievalSource.EMBEDDING))
        #self.retrieval_column = self.rag_settings.get("retrievalColumn") if self.retrieval_source == RetrievalSource.CUSTOM else None
        self.embeddings = DKUEmbeddings(llm_id=self.kb["embeddingLLMId"])

        self.dku_vector_store = VectorStoreFactory.get_vector_store(self.kb, isolated_folder.folder_path, VectorStoreFactory.get_connection_details_from_env)
        # Only flag the server as started once every step above has succeeded
        self.started = True

    def process_tool_input(self, command: Dict) -> Dict:
        """Run one search against the vector store and format the result.

        :param command: dict with the key ``"searchQuery"`` and the optional keys ``"params"`` (query settings),
            ``"filter"`` and ``"callerFilters"``. ``"securityTokens"`` is required when
            ``params["enforceDocumentLevelSecurity"]`` is truthy.
        :return: dict with the keys ``"output"`` (``{"documents": [...]}``), ``"parts"`` (multimodal output parts),
            ``"trace"`` (serialized trace) and ``"sources"`` (a single-element list)
        :raises AssertionError: if the server has not been started
        """
        assert self.started, "Not started"

        trace = new_trace("DKU_VECTOR_STORE_SEARCH")
        # Using the trace as a context manager guarantees it is closed even when the search fails
        with trace:
            query = command["searchQuery"]
            query_settings: VectorStoreQueryToolParams = command.get("params", {})

            augmentation_fallback_strategy: AugmentationFallbackStrategy = AugmentationFallbackStrategy(
                query_settings.get("augmentationFallbackStrategy", AugmentationFallbackStrategy.USE_EMBEDDING))
            logger.info("Using fallback strategy: %s", augmentation_fallback_strategy)

            include_multimodal_content = query_settings.get("includeMultimodalContent", False)

            sources_handler = SourcesHandler(
                query_settings.get("sourcesSettings", {}),
                full_folder_id=self.kb.get("managedFolderId"),
                retrieval_columns=query_settings.get("retrievalColumns", None),
                # Only honour the configured fallback strategy when multimodal content is requested; otherwise keep
                # USE_EMBEDDING so that all sources are retrieved (backward compatibility)
                augmentation_fallback_strategy=augmentation_fallback_strategy if include_multimodal_content else AugmentationFallbackStrategy.USE_EMBEDDING,
                trusted_object=self.trusted_object
            )
            reranking_settings: RAGRerankingSettings = query_settings.get("reranking", {})

            logger.info("Asking this question to the vector store: %s", query)
            # NOTE(review): the full command is logged at INFO level and may include "securityTokens" - confirm
            # that this is acceptable
            logger.info("Full command: %s", command)
            with self.lock:

                additional_search_kwargs: Dict = {}

                if query_settings.get("enforceDocumentLevelSecurity"):
                    self.dku_vector_store.add_security_filter(additional_search_kwargs, command["securityTokens"])

                if "filter" in command:
                    self.dku_vector_store.add_filter(additional_search_kwargs, command["filter"])

                self.dku_vector_store.add_dynamic_filter(additional_search_kwargs, command.get("callerFilters", []), self.toolRef)

                # When explicitly allowed, a missing or blank query is replaced by a placeholder string
                # (presumably because the search needs a non-empty input - TODO confirm)
                if query_settings.get("allowEmptyQuery", False) and (not query or not query.strip()):
                    query = "test"

                if self.dku_vector_store.include_similarity_score(query_settings):
                    search_result = self.dku_vector_store.search_with_scores(query, self.embeddings, query_settings, additional_search_kwargs)
                    docs = [DocumentWithScores(document, score) for (document, score) in search_result] if search_result else []
                else:
                    retriever = self.dku_vector_store.as_retriever(self.embeddings, query_settings, additional_search_kwargs)
                    docs = [DocumentWithScores(document, None) for document in retriever.invoke(query)]

                trace.attributes["nbRetrievedDocuments"] = len(docs)

                # See comment in DKUEmbeddings about how to retrieve the trace from the last call
                if hasattr(self.embeddings, "_last_trace") and hasattr(self.embeddings._last_trace, "trace"):
                    trace.append_trace(self.embeddings._last_trace.trace)
                else:
                    logger.info("No last trace found in DKUEmbeddings after rag document retrieval")

            logger.info("Retrieved %s documents", len(docs))

            reranking_enabled = reranking_settings.get("enabled", False)
            if reranking_enabled:
                reranker = Reranking(reranking_settings)
                docs = reranker.rerank_with_scores(query, docs, trace, logger)

            # Settings that do not depend on the document are read once, outside the loop
            retrieval_columns = query_settings.get("retrievalColumns") or []
            include_score = query_settings.get("includeScore", False)

            ret = {"documents": []}
            parts = []
            source = {
                # TODO - Generate in Java - "toolCallDescription": "Scraped the page %s" % url,

                "items": []
            }
            for doc in docs:
                prepared_doc = {"metadata": {}}
                for column in retrieval_columns:
                    if column == DKU_TEXT_EMBEDDING_COLUMN:
                        prepared_doc["text"] = doc.page_content
                    else:
                        prepared_doc["metadata"][column] = self.dku_vector_store.get_document_metadata(doc, column)

                if include_multimodal_content:
                    multimodal_content = doc.metadata.get(DKU_MULTIMODAL_CONTENT, None)
                    if multimodal_content is not None:
                        logger.info("part multimodal content")
                        prepared_doc["metadata"][DKU_MULTIMODAL_CONTENT] = multimodal_content
                        doc_parts = self.get_parts_from_multimodal_content(doc, augmentation_fallback_strategy)
                        if not doc_parts:
                            continue  # skip the document
                        parts.extend(doc_parts)
                    document_info = doc.metadata.get(DKU_DOCUMENT_INFO, None)
                    if document_info is not None:
                        prepared_doc["metadata"][DKU_DOCUMENT_INFO] = document_info

                # Scores are written on the prepared document itself rather than through a positional index into
                # ret["documents"]: documents skipped above would otherwise shift the positions
                if include_score:
                    retrieval_score = doc.get_score(ScoreType.RETRIEVAL)
                    if retrieval_score is not None:
                        prepared_doc["score"] = retrieval_score
                if reranking_enabled:
                    reranking_score = doc.get_score(ScoreType.RERANKING)
                    if reranking_score is not None:
                        prepared_doc["reranking_score"] = reranking_score

                ret["documents"].append(prepared_doc)

                source_item = sources_handler.build_role_based_source_from(doc)
                if source_item is not None:
                    source["items"].append(source_item)

        # The trace is serialized after it has been closed
        return {"output": ret, "parts": parts, "trace": trace.to_dict(), "sources": [source]}

    async def handler(self, command):
        """Dispatch one command and yield its result.

        The synchronous work is run on the thread pool so that the event loop is not blocked.

        :param command: dict whose ``"type"`` is either ``"start"`` or ``"run-tool"``
        :raises Exception: if the command type is unknown
        """
        if command["type"] == "start":
            logger.info("Received start command: %s", command)
            yield await asyncio.get_running_loop().run_in_executor(
                self.executor, self.start, command
            )
        elif command["type"] == "run-tool":
            logger.info("\n===============  Start running tool - run %s ===============", self.run_counter)
            try:
                logger.debug("Received command: %s", command)
                yield await asyncio.get_running_loop().run_in_executor(
                    self.executor, self.process_tool_input, command
                )
                logger.info("\n=============== End running tool - run %s ===============", self.run_counter)
            finally:
                # The run is counted whether it succeeded or failed
                self.run_counter += 1
        else:
            raise Exception("Unknown command type: %s" % command["type"])

    def get_parts_from_multimodal_content(self, doc, augmentation_fallback_strategy):
        """ Turn parts from a document's multimodal content field into a list of tool output parts (see `LLMClient.ToolOutputPart`)

        :param doc: retrieved document holding the multimodal content
        :param augmentation_fallback_strategy: what to do when an image part cannot be retrieved: fall back to the
            document text (USE_EMBEDDING), drop the whole document (SKIP), or fail (any other value)
        :return: list of ``{"type": "TEXT", ...}`` / ``{"type": "IMAGE_REF", ...}`` dicts. The list is empty when
            there is no multimodal content, when its type is not handled, or when the document must be skipped
        :raises Exception: if an image part has errors and the strategy is neither USE_EMBEDDING nor SKIP
        """
        multimodal_content = from_doc(doc)
        if multimodal_content is None:
            return []

        parts = []
        if multimodal_content.type == "text":
            for text_part in multimodal_content.get_parts(None, None, None):
                parts.append({
                    "type": "TEXT",
                    "text": text_part.text,
                })
        elif multimodal_content.type in ("images", "captioned_images"):
            managed_folder_id = self.kb.get("managedFolderId")
            for image_part in multimodal_content.get_parts(None, ImageRetrieval.IMAGE_REF, managed_folder_id, self.trusted_object):
                if image_part.has_errors():
                    if augmentation_fallback_strategy == AugmentationFallbackStrategy.USE_EMBEDDING:
                        # Fall back to the text of the document in place of the failed image
                        parts.append({
                            "type": "TEXT",
                            "text": doc.page_content,
                        })
                    elif augmentation_fallback_strategy == AugmentationFallbackStrategy.SKIP:
                        logger.warning("Skipping document from results: {err}.".format(err=image_part.error))
                        # An empty list tells the caller to drop the document, including parts already collected
                        return []
                    else:
                        raise Exception("Failed to retrieve data for document : {err} ".format(err=image_part.error))
                elif type(image_part) in (ImageRefPart, CaptionedImageRefPart):
                    parts.append({
                        "type": "IMAGE_REF",
                        "folderId": image_part.full_folder_id,
                        "path": image_part.path,
                    })
        return parts

def log_exception(loop, context):
    """Log an error reported through the asyncio exception handler hook.

    :param loop: event loop that reported the error (not used)
    :param context: dict describing the error; its "exception" entry is logged when present,
        otherwise an exception is built from its "message" entry
    """
    error = context.get("exception")
    if error is None:
        # No exception object was provided: wrap the message so both cases are formatted the same way
        error = Exception(context.get("message"))

    stack_lines = traceback.format_exception(type(error), error, error.__traceback__)
    stack = ''.join(stack_lines)
    logger.error(
        f"Caught exception: {error}\n"
        f"Context: {context}\n"
        f"Stack trace: {stack}"
    )

if __name__ == "__main__":
    # The log level comes from the LOGLEVEL environment variable (set LOGLEVEL=DEBUG to debug)
    log_level = os.environ.get('LOGLEVEL', 'INFO').upper()
    logging.basicConfig(
        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s',
        level=log_level,
    )

    watch_stdin()

    async def start_server():
        """Connect the link described by the command-line arguments and serve commands on it."""
        # Errors reported to the loop's exception handler are logged through log_exception
        running_loop = asyncio.get_running_loop()
        running_loop.set_exception_handler(log_exception)

        port, secret, server_cert = parse_javalink_args()
        link = AsyncJavaLink(port, secret, server_cert=server_cert)
        server = VectorStoreQueryToolServer()

        await link.connect()
        await link.serve(server.handler)

    # The event loop is run in asyncio debug mode
    asyncio.run(start_server(), debug=True)
