import asyncio
import json
import logging
import math
import os
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union, TYPE_CHECKING, Tuple, Any, AsyncGenerator

import dataiku
from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import watch_stdin
from dataiku.core.knowledge_bank import MultipartContext, KnowledgeBank
from dataiku.core.vector_stores.dku_vector_store import VectorStoreFactory, DkuVectorStore
from dataiku.huggingface.types import CompletionSettings, ChatMessage
from dataiku.langchain.base_rag_handler import RetrievalSource
from dataiku.langchain.content_part_types import TextPart, ImageRefPart, ImageRetrieval, CaptionedImageRefPart, MultipartContent
from dataiku.langchain.dku_embeddings import DKUEmbeddings
from dataiku.langchain.document_handler import DocumentHandler
from dataiku.langchain.sources_handler import SourcesHandler
from dataiku.llm.evaluation.genai_metrics_input import GenAIMetricInput
from dataiku.llm.rag.guardrails import AnswerRelevancy, Faithfulness, Guardrail, MultimodalAnswerRelevancy, MultimodalFaithfulness
from dataiku.llm.tracing import SpanBuilder, new_trace
from dataiku.llm.types import LLMErrorType, ProcessSinglePromptCommand, RAGLLMSettings, SimpleCompletionResponseOrError, Source, FilterDesc, \
    SearchInputStrategy, TrustedObject
from dataikuapi.dss.llm import DSSLLM, DSSLLMCompletionQuery, DSSLLMStreamedCompletionChunk, DSSLLMStreamedCompletionFooter

if TYPE_CHECKING:
    from langchain_core.documents import Document
    from dataiku.llm.evaluation.utils.ragas.ragas_utils_0_1_10 import RagasMetricsComputer

logger = logging.getLogger("rag_query_server")

class RAGServer:
    """Answers completion queries with a Retrieval-Augmented LLM, driven by commands from an async Java link.

    Lifecycle: :meth:`handler` first receives a ``start`` command (handled by :meth:`start`). That command
    configures the knowledge bank, embeddings, vector store, augmented LLM and optional RAG guardrails.
    The handler then serves ``process-completion-query`` commands, either in one shot
    (:meth:`process_prompt`) or streamed (:meth:`process_prompt_stream`).
    """
    llm: DSSLLM
    guardrails_llm_id: str
    embeddings: DKUEmbeddings
    guardrails_embedding_model_id: str
    rag_settings: RAGLLMSettings
    search_input_strategy: SearchInputStrategy
    no_retrieval_key: str
    retrieval_source: RetrievalSource
    document_handler: DocumentHandler
    sources_handler: SourcesHandler
    guardrails: List[Guardrail]
    metrics_to_compute: List[str]
    include_content_in_sources: bool
    print_sources: bool
    output_format: str
    filter: FilterDesc
    dku_vector_store: DkuVectorStore
    run_counter: int

    def __init__(self) -> None:
        self.started = False

        # While RAGServer exposes an async API, it is currently backed by a ThreadPoolExecutor under the hood, because
        # some of the underlying APIs are synchronous (e.g. DKUEmbeddings, DSSLLM). The nb. of threads in the pool can be
        # much higher than the number of cores because each thread spends most of the time waiting for I/O.
        self.executor = ThreadPoolExecutor(64)

        # The lock is used to serialize all accesses to the vector store. It is not clear if that is necessary, and
        # it may depend on the specific vector store implementation. This may have a performance impact, but in any
        # case this implementation is more efficient than the previous one, which was processing all prompts in a
        # single thread (including LLM calls)
        self.lock = threading.Lock()

        # Initializing defaults
        self.guardrails = []
        self.metrics_to_compute = []
        self.run_counter = 1

    def start(self, start_command: Dict) -> None:
        """Configure the server from a ``start`` command.

        Loads the knowledge bank into an isolated folder and builds the embeddings client, the document and
        sources handlers, the enabled RAG guardrails, the augmented LLM handle and the vector store.

        :param start_command: the start command. Keys read here: ``ragSettings``, ``knowledgeBankFullId``,
            ``llmRef``, ``defaultCompletionSettings``, ``noRetrievalKey`` and the optional ``filter``.
        :raises AssertionError: if the server was already started.
        """
        assert not self.started, "Already started"
        self.rag_settings = start_command["ragSettings"]

        kb_full_id = start_command["knowledgeBankFullId"]
        knowledge_bank = KnowledgeBank(kb_full_id, context_project_key=dataiku.default_project_key())
        # the isolated folder will be cleaned up when the kernel dies
        llm_ref = start_command["llmRef"]
        trusted_object: TrustedObject = {"smartRef": llm_ref["savedModelSmartId"], "type": "SAVED_MODEL"}
        isolated_folder = knowledge_bank.load_into_isolated_folder(trusted_object=trusted_object)
        kb = isolated_folder.get_kb_desc()
        embedding_llm_id = kb["embeddingLLMId"]

        self.completion_settings = start_command["defaultCompletionSettings"]
        self.search_input_strategy = self.rag_settings["searchInputStrategySettings"]["strategy"]
        self.no_retrieval_key = start_command["noRetrievalKey"]

        self.embeddings = DKUEmbeddings(llm_id=embedding_llm_id)
        self.retrieval_source = RetrievalSource(self.rag_settings.get("retrievalSource", RetrievalSource.EMBEDDING))
        retrieval_columns = self.rag_settings.get("retrievalColumns") if self.retrieval_source != RetrievalSource.MULTIMODAL else None
        self.document_handler = DocumentHandler(kb, self.rag_settings, self.retrieval_source, retrieval_columns)
        self.sources_handler = SourcesHandler(
            self.rag_settings.get("sourcesSettings", {}),
            kb.get("managedFolderId"),
            self.retrieval_source,
            retrieval_columns
        )

        # Guardrails settings initialization. The multimodal guardrail settings are optional, the others are not.
        guardrails_settings = self.rag_settings["ragSpecificGuardrails"]
        disabled: Dict = {"enabled": False}
        if guardrails_settings["faithfulnessSettings"]["enabled"]:
            self.guardrails.append(Faithfulness(guardrails_settings["faithfulnessSettings"]))
        if guardrails_settings["relevancySettings"]["enabled"]:
            self.guardrails.append(AnswerRelevancy(guardrails_settings["relevancySettings"]))
        if guardrails_settings.get("multimodalFaithfulnessSettings", disabled)["enabled"]:
            self.guardrails.append(MultimodalFaithfulness(guardrails_settings["multimodalFaithfulnessSettings"]))
        if guardrails_settings.get("multimodalRelevancySettings", disabled)["enabled"]:
            self.guardrails.append(MultimodalAnswerRelevancy(guardrails_settings["multimodalRelevancySettings"]))
        self.metrics_to_compute = [guardrail.metric_key_name for guardrail in self.guardrails]
        # Guardrails default to the augmented LLM and to the knowledge bank's embedding model
        self.guardrails_llm_id = guardrails_settings.get("llmId", self.rag_settings["llmId"])
        self.guardrails_embedding_model_id = guardrails_settings.get("embeddingModelId", embedding_llm_id)

        # Sources settings
        self.include_content_in_sources = self.rag_settings.get("includeContentInSources", True)
        self.print_sources = self.rag_settings.get("printSources", True)
        self.output_format = self.rag_settings.get("outputFormat", "TEXT")

        self.llm = (
            dataiku.api_client()
            .get_default_project()
            .get_llm(self.rag_settings["llmId"])
        )

        self.filter = start_command.get("filter")

        self.dku_vector_store = VectorStoreFactory.get_vector_store(kb, isolated_folder.folder_path, VectorStoreFactory.get_connection_details_from_env)
        self.started = True

    def process_prompt(self, process_prompt_command: ProcessSinglePromptCommand) -> SimpleCompletionResponseOrError:
        """Answer a completion query in one shot.

        Steps: optionally rewrite the question, retrieve documents (unless the rewritten question contains the
        no-retrieval key), query the augmented LLM with the RAG context, format the answer and its sources
        according to the output format, and finally apply the RAG guardrails, if any.

        :returns: the RAG response. Functional fields of the LLM response are forwarded. Token usage goes in
            ``reportedUsageMetadata``, and ``additionalInformation.sources`` lists the retrieved documents.
        :raises AssertionError: if the server is not started.
        :raises Exception: if the augmented LLM call fails.
        """
        assert self.started, "Not started"

        # NOTE(review): the root trace is entered but never exited, and it is serialized with to_dict() while
        # still open. Confirm this is the intended SpanBuilder usage.
        trace: SpanBuilder = new_trace("DKU_BUILTIN_RAG")
        trace.__enter__()

        (text_messages, relevant_messages, question) = self._extract_messages_and_question(process_prompt_command)
        question = self._query_rewriter(relevant_messages, question, trace)
        docs: List["Document"] = []
        if self._should_query_kb(question):
            docs = self._query_retriever(process_prompt_command, question, trace)

        completion = self._get_completion(docs, text_messages, process_prompt_command)

        with trace.subspan("RAG_QUERY_ORIGINAL_LLM") as complete_span:
            resp = completion.execute()

            raw_resp = resp._raw
            if not resp.success:
                raise Exception("LLM call failed: %s" % raw_resp.get("errorMessage", "Unknown error"))

            if "trace" in raw_resp:
                complete_span.append_trace(raw_resp["trace"])

        # Build the result text and sources
        (result_text, text_content, sources_separated) = self._get_result_text(resp.text, docs)

        rag_response: "SimpleCompletionResponseOrError" = {}

        # forward functional fields from the augmented LLM, only if they have been used
        if "finishReason" in raw_resp:
            rag_response["finishReason"] = raw_resp["finishReason"]
        if "toolCalls" in raw_resp:
            rag_response["toolCalls"] = raw_resp["toolCalls"]
        if "logProbs" in raw_resp:
            rag_response["logProbs"] = raw_resp["logProbs"]
        if "additionalInformation" in raw_resp:
            rag_response["additionalInformation"] = raw_resp["additionalInformation"]

        # but for usage metadata, it's not incurred here, but only reported, so put them in the reported field
        rag_response["reportedUsageMetadata"] = {
            "promptTokens": raw_resp.get("promptTokens", None),
            "completionTokens": raw_resp.get("completionTokens", None),
            "totalTokens": raw_resp.get("totalTokens", None),
            "tokenCountsAreEstimated": raw_resp.get("tokenCountsAreEstimated", None),
            "estimatedCost": raw_resp.get("estimatedCost", None)
        }

        # set the RAG result text field
        rag_response["text"] = result_text
        rag_response["ok"] = True

        # set sources field if we have them
        if self.output_format == "SEPARATED" and sources_separated is not None:
            rag_response["sources"] = sources_separated

        if self.guardrails:
            rag_response = self._check_guardrails(rag_response=rag_response,
                                                  completion_settings=process_prompt_command["settings"],
                                                  question=question,
                                                  answer=text_content,
                                                  sources=self._get_guardrail_sources(docs, sources_separated),
                                                  trace=trace)

        rag_response["trace"] = trace.to_dict()
        # `or {}` also covers an explicit null additionalInformation forwarded from the augmented LLM
        rag_response["additionalInformation"] = rag_response.get("additionalInformation") or {}
        rag_response["additionalInformation"]["sources"] = self._build_sources_information(docs)

        return rag_response

    async def process_prompt_stream(self, process_prompt_command: ProcessSinglePromptCommand) -> AsyncGenerator[
        Union[DSSLLMStreamedCompletionChunk, DSSLLMStreamedCompletionFooter], None]:
        """Answer a completion query as a stream.

        Yields the augmented LLM's chunks as they arrive. For the ``TEXT`` output format with sources printing
        enabled, it then yields one extra chunk holding the sources text. It ends with a footer that carries the
        forwarded LLM fields, the retrieved sources and the trace.

        :raises AssertionError: if the server is not started, if the output format is ``JSON``, if guardrails are
            configured (unsupported in streaming mode), or if the LLM stream has no footer.
        """
        assert self.started, "Not started"
        # Reject unsupported configurations before doing any work or streaming any chunk to the caller
        assert self.output_format != "JSON", "JSON output format is not supported in streaming mode"
        # Streamed guardrails are not supported
        assert len(self.guardrails) == 0, "Streaming guardrails are not supported"

        loop = asyncio.get_running_loop()

        trace: SpanBuilder = new_trace("DKU_BUILTIN_RAG")
        trace.__enter__()

        (text_messages, relevant_messages, question) = self._extract_messages_and_question(process_prompt_command)
        # Blocking steps (LLM rewrite, vector store search) run on the server's executor, like non-streamed queries
        question = await loop.run_in_executor(self.executor, self._query_rewriter, relevant_messages, question, trace)
        docs: List["Document"] = []
        if self._should_query_kb(question):
            docs = await loop.run_in_executor(self.executor, self._query_retriever, process_prompt_command, question, trace)
        completion = self._get_completion(docs, text_messages, process_prompt_command)

        original_footer = None

        with trace.subspan("RAG_QUERY_ORIGINAL_LLM") as complete_span:

            logger.debug("Querying original LLM")

            async for chunk in self._aexecute_streamed_completion(completion):
                if isinstance(chunk, DSSLLMStreamedCompletionChunk):
                    yield chunk

                elif isinstance(chunk, DSSLLMStreamedCompletionFooter):
                    original_footer = chunk

                else:
                    # Weird. Are we querying an underlying agent maybe? Log but still send it
                    logger.info("Unexpected chunk type in RAG original completion call: %s" % type(chunk))
                    yield chunk

            assert original_footer is not None, "No footer received from original LLM completion call"
            if hasattr(original_footer, "trace"):
                complete_span.append_trace(original_footer.trace)

        # Same condition as the non-streamed TEXT output: sources are only printed when enabled
        if self.print_sources and self.output_format == "TEXT":
            sources_text = self.get_sources_as_text(docs)
            yield DSSLLMStreamedCompletionChunk({"text": "\n\n" + sources_text})

        yield self.prepare_footer(original_footer, trace, docs)

    async def _aexecute_streamed_completion(self, completion: DSSLLMCompletionQuery) -> AsyncGenerator[Any, None]:
        """Adapt the blocking ``completion.execute_streamed()`` iterator into an async generator.

        Opening the stream and every ``next()`` call run on the server's executor, so the event loop is not
        blocked while waiting for the LLM.
        """
        loop = asyncio.get_running_loop()
        sync_iterator = await loop.run_in_executor(self.executor, completion.execute_streamed)
        done = object()
        while True:
            item = await loop.run_in_executor(self.executor, next, sync_iterator, done)
            if item is done:
                break
            yield item

    def prepare_footer(self,
                       original_footer: DSSLLMStreamedCompletionFooter,
                       trace: SpanBuilder,
                       docs: List["Document"]) -> DSSLLMStreamedCompletionFooter:
        """Build the RAG stream footer from the augmented LLM's footer.

        Copies the functional, usage and extra-info fields that the original footer contains, adds
        ``additionalInformation.sources`` for *docs*, and attaches the serialized *trace*.
        """
        fields_to_copy = [
            "finishReason",
            "toolCalls",
            "logProbs",
            # Usage metadata
            "promptTokens",
            "completionTokens",
            "totalTokens",
            "tokenCountsAreEstimated",
            "estimatedCost",
            # Extra info
            "additionalInformation",
        ]

        # forward functional fields from the augmented LLM, only if they have been used
        footer_data: Dict[str, Any] = {
            field: original_footer.data[field] for field in fields_to_copy if field in original_footer.data
        }
        # `or {}` also covers an explicit null additionalInformation forwarded from the augmented LLM
        footer_data["additionalInformation"] = footer_data.get("additionalInformation") or {}
        footer_data["additionalInformation"]["sources"] = self._build_sources_information(docs)

        footer_data["trace"] = trace.to_dict()
        return DSSLLMStreamedCompletionFooter(footer_data)

    def _build_sources_information(self, docs: List["Document"]) -> List[Dict]:
        """Build the ``additionalInformation.sources`` payload: a single group with one role-based source per document."""
        return [{"items": [self.sources_handler.build_role_based_source_from(doc) for doc in docs]}]

    @staticmethod
    def _extract_messages_and_question(process_prompt_command: ProcessSinglePromptCommand) -> Tuple[List[ChatMessage], List[ChatMessage], str]:
        """Flatten the query's messages to text and derive the vector store question.

        :returns: ``(text_messages, relevant_messages, question)``. ``text_messages`` are copies of all messages,
            with ``parts`` replaced by the newline-joined text of their TEXT parts. ``relevant_messages`` is the
            subset used for retrieval, and ``question`` is their contents joined with blank lines.
        """
        # strip non-text inputs and don't use parts
        text_messages: List[ChatMessage] = []
        for message in process_prompt_command["query"]["messages"]:
            new_message = message.copy()
            if new_message.get("parts", False):
                new_message["content"] = "\n".join(part["text"] for part in new_message["parts"] if part["type"] == "TEXT")  # type: ignore
                # ignoring because part["text"] is not null if "type" == "TEXT"
                del new_message["parts"]

            text_messages.append(new_message)

        relevant_messages = [m for m in text_messages if _is_relevant_for_vector_store_query(m)]

        logger.info("Will only include these messages in the query: %s", relevant_messages)
        question = "\n\n".join([m["content"] for m in relevant_messages])

        return text_messages, relevant_messages, question

    @staticmethod
    def _extract_caller_security_tokens(process_prompt_command: ProcessSinglePromptCommand) -> List[str]:
        """Read ``query.context.callerSecurityTokens`` for document-level security.

        :raises ValueError: if the tokens are absent, or are not a list of strings.
        """
        try:
            raw_tokens = process_prompt_command["query"]["context"]["callerSecurityTokens"]
        except (KeyError, TypeError):
            raise ValueError("Unable to process query: Retrieval-Augmented LLM with document-level security requires valid security tokens. Please ensure your security tokens are included with your query context.") from None
        if not isinstance(raw_tokens, list) or not all(isinstance(token, str) for token in raw_tokens):
            raise ValueError("Invalid format: 'callerSecurityTokens' must be a JSON Array of Strings")

        return raw_tokens

    def _query_rewriter(self, relevant_messages: List[ChatMessage], question: str, trace: SpanBuilder) -> str:
        """Return the text to search the vector store with, according to the search input strategy.

        ``RAW_QUERY`` returns *question* unchanged. ``REWRITE_QUERY`` asks the augmented LLM to rewrite the
        conversation, using the conditional-retrieval and rewrite system prompts.

        :raises RuntimeError: if the rewrite LLM call fails.
        :raises ValueError: on an unsupported strategy.
        """
        if self.search_input_strategy == "RAW_QUERY":
            return question

        if self.search_input_strategy != "REWRITE_QUERY":
            raise ValueError("Unsupported strategy in searchInputStrategySettings: %s" % self.search_input_strategy)

        logger.info("Asking the LLM to rewrite this question: %s", question)

        strategy_settings = self.rag_settings["searchInputStrategySettings"]
        retrieval_prompt = strategy_settings["conditionalRetrievalPrompt"]
        rewrite_prompt = strategy_settings["rewritePrompt"]

        # self.lock is deliberately not held here. It only guards vector store access, and the main completion
        # call already uses the LLM without it, so holding the lock would needlessly serialize every rewrite.
        with trace.subspan("RAG_REWRITE_QUERY") as rewrite_span:
            completion = self.llm.new_completion()
            completion.with_message(retrieval_prompt, role="system")
            completion.with_message(rewrite_prompt, role="system")
            for message in relevant_messages:
                completion.with_message(message['content'], role=message.get("role", "user"))
            resp = completion.execute()

            if not resp.success:
                raise RuntimeError("LLM call failed: %s" % resp._raw.get("errorMessage", "Unknown error"))
            if "trace" in resp._raw:
                rewrite_span.append_trace(resp._raw["trace"])

            return resp.text

    def _should_query_kb(self, question: str) -> bool:
        """Skip retrieval only when the rewriting LLM answered with the no-retrieval key."""
        return not (self.search_input_strategy == "REWRITE_QUERY" and self.no_retrieval_key in question)

    def _query_retriever(self, command: ProcessSinglePromptCommand, question: str, trace: SpanBuilder) -> List["Document"]:
        """Search the vector store for documents relevant to *question*, under ``self.lock``.

        Combines the document-level security filter (when ``enforceDocumentLevelSecurity`` is set), the static
        filter from the start command, and the caller's dynamic ``callerFilters`` from the query context.

        :returns: the retrieved documents. The list is empty when document-level security is enforced and the
            caller supplied no tokens.
        :raises ValueError: if document-level security is enforced and the caller's tokens are missing or malformed.
        """
        logger.info("Asking this question to the vector store: %s", question)
        with self.lock:
            with trace.subspan("RAG_RETRIEVE_DOCUMENTS") as retrieve_span:
                additional_search_kwargs: Dict = {}
                if self.rag_settings.get("enforceDocumentLevelSecurity"):
                    caller_security_tokens = self._extract_caller_security_tokens(command)
                    if not caller_security_tokens:
                        logger.warning("No caller security tokens provided: returning 0 documents")
                        return []
                    self.dku_vector_store.add_security_filter(additional_search_kwargs, caller_security_tokens)
                if self.filter is not None:
                    self.dku_vector_store.add_filter(additional_search_kwargs, self.filter)

                # `or {}` tolerates an explicit null context, as _extract_caller_security_tokens does
                caller_filters = (command["query"].get("context") or {}).get("callerFilters")
                if caller_filters:
                    self.dku_vector_store.add_dynamic_filter(additional_search_kwargs, caller_filters, None)

                retriever = self.dku_vector_store.as_retriever(self.embeddings, self.rag_settings, additional_search_kwargs)
                docs = retriever.invoke(question)

                retrieve_span.attributes["nbRetrievedDocuments"] = len(docs)

                # See comment in DKUEmbeddings about how to retrieve the trace from the last call
                if hasattr(self.embeddings, "_last_trace") and hasattr(self.embeddings._last_trace, "trace"):
                    retrieve_span.append_trace(self.embeddings._last_trace.trace)
                else:
                    logger.info("No last trace found in DKUEmbeddings after rag document retrieval")

        logger.info("Retrieved %s documents", len(docs))

        return docs

    def _get_completion(self, docs: List["Document"],
                        text_messages: List[ChatMessage],
                        process_prompt_command: ProcessSinglePromptCommand) -> DSSLLMCompletionQuery:
        """Build the augmented LLM completion.

        The RAG system and context messages are inserted into the conversation. The default completion
        settings are applied first, then overridden by the query's own settings.
        """
        rag_system_message, rag_context_message = self._get_rag_prompts(docs)
        final_messages = get_messages_to_send(text_messages,
                                              rag_system_message,
                                              rag_context_message)
        completion = self.llm.new_completion()
        _set_chat_messages(completion, final_messages)

        completion.settings.update(self.completion_settings)
        # re-use the completion query settings (from the command)
        completion.settings.update(process_prompt_command["settings"])
        return completion

    def _get_result_text(self, resp_text: str, docs: List["Document"]) -> Tuple[Union[str, Dict], str, Optional[List[Source]]]:
        """Format the answer according to ``printSources`` and ``outputFormat``.

        :returns: ``(result_text, text_content, sources_separated)``.
            - ``result_text`` is the final text. For the JSON format it is a JSON string
              ``{"result": ..., "sources": ...}``.
            - ``text_content`` is the raw LLM text, ``""`` when absent.
            - ``sources_separated`` is only set for the SEPARATED format.
        :raises ValueError: on an unknown output format (only when sources are printed).
        """
        # There can be no-text when the response contains tool calls
        text_content = resp_text or ""

        # Eventual structured sources
        sources_separated: Optional[List[Source]] = None
        result_text: Union[str, Dict]
        if not self.print_sources:
            result_text = text_content
        elif self.output_format == "TEXT":
            sources_text = self.get_sources_as_text(docs)
            result_text = f"{text_content}\n\n{sources_text}" if text_content else sources_text
        elif self.output_format == "JSON":
            result_text = {
                "result": text_content,
                "sources": self.get_sources_as_json(docs),
            }
        elif self.output_format == "SEPARATED":
            result_text = text_content
            sources_separated = self.get_sources_separated(docs)
        else:
            raise ValueError(f"Unknown output format: {self.output_format}")

        if not isinstance(result_text, str):
            result_text = json.dumps(result_text)

        return result_text, text_content, sources_separated

    def get_sources_separated(self, docs: List["Document"], *, include_content: Optional[bool] = None) -> List[Source]:
        """Build one structured source per retrieved document.

        Each source holds a copy of the document's metadata columns. When content is included, it also holds
        an ``excerpt`` in one of two forms:
        - ``{"type": "TEXT", "text": ...}`` for a single text part (classic knowledge banks);
        - ``{"type": "IMAGE_REF", "images": [{"fullFolderId": ..., "path": ...}, ...]}`` otherwise.

        :param include_content: whether to add the ``excerpt``. Defaults to the ``includeContentInSources``
            setting. Guardrails force it, because they need the retrieved content.
        :raises AssertionError: if a document has several parts and one of them is not an image reference.
        """
        if include_content is None:
            include_content = self.include_content_in_sources

        sources_separated: List[Source] = []
        for sd in docs:
            source: Source = {"metadata": self.document_handler.get_metadata_columns(sd).copy()}

            if include_content:
                parts = self.document_handler.get_multipart_content(sd, image_retrieval=ImageRetrieval.IMAGE_REF)

                excerpt: Dict
                # We should stay compatible with classic KB
                if len(parts) == 1 and isinstance(parts[0], TextPart):
                    excerpt = {  # Use 'excerpt' rather than 'content' to partially align on Dataiku Answer's format
                        "type": "TEXT",
                        "text": parts[0].to_text()
                    }
                else:
                    # Can't have text in multiple parts, and we cannot have parts with different types for a given document
                    excerpt = {"type": "IMAGE_REF", "images": []}
                    for part in parts:
                        assert isinstance(part, (ImageRefPart, CaptionedImageRefPart)), f"Unknown part type: {type(part)}"
                        excerpt["images"].append({
                            "fullFolderId": part.full_folder_id,
                            "path": part.path
                        })
                source["excerpt"] = excerpt

            sources_separated.append(source)
        return sources_separated

    def _get_guardrail_sources(self, docs: List["Document"],
                               sources_separated: Optional[List[Source]]) -> Union[List[str], List[Source]]:
        """Build the retrieved context that the guardrail metrics are evaluated against.

        The already-built separated sources are reused only when they carry excerpts. Otherwise, for example
        when ``includeContentInSources`` is disabled, the sources are rebuilt with their content. For
        multimodal retrieval the structured sources are returned as-is. For other sources, the text of each
        TEXT excerpt is returned.
        """
        if sources_separated and self.include_content_in_sources:
            guardrail_sources = sources_separated
        else:
            guardrail_sources = self.get_sources_separated(docs, include_content=True)

        if self.retrieval_source == RetrievalSource.MULTIMODAL:
            return guardrail_sources  # parsing will be done in ragas_utils

        return [
            source["excerpt"]["text"]
            for source in guardrail_sources
            if source.get("excerpt", {}).get("type") == "TEXT"
        ]

    def get_sources_as_json(self, docs: List["Document"]) -> List[Dict]:
        """Return, per document, its metadata columns plus (if enabled) a ``content`` string of its parts joined by ", "."""
        sources = []
        for sd in docs:
            source = self.document_handler.get_metadata_columns(sd).copy()
            if self.include_content_in_sources:
                parts = self.document_handler.get_multipart_content(sd, image_retrieval=ImageRetrieval.IMAGE_REF)
                source["content"] = ", ".join(part.to_text() for part in parts)
            sources.append(source)
        return sources

    def get_sources_as_text(self, docs: List["Document"]) -> str:
        """Render the sources as a ``Sources:`` header followed by one ``* key: value, ...`` bullet per document."""
        lines = []
        for sd in docs:
            formatted_metadata = [
                "%s: %s" % (k, v) for (k, v) in self.document_handler.get_metadata_columns(sd).items()
            ]
            if self.include_content_in_sources:
                parts = self.document_handler.get_multipart_content(sd, image_retrieval=ImageRetrieval.IMAGE_REF)
                formatted_parts = [part.to_text() for part in parts]
                formatted_metadata.append("content: %s" % (", ".join(formatted_parts)))
            lines.append("* %s" % (", ".join(formatted_metadata)))
        return "Sources:\n%s" % "\n".join(lines)

    def _get_rag_prompts(self, docs: List["Document"]) -> Tuple[ChatMessage, ChatMessage]:
        """Build the RAG system message and the context message.

        The context message is a ``user`` message holding the multipart content of all *docs*. When there are
        no documents, it is a ``system`` message asking for a direct answer instead.
        """
        system_message = self.rag_settings["contextMessage"]
        rag_context_message: ChatMessage
        if docs:
            multipart_content = MultipartContext()
            for d in docs:
                # Build multipart content
                parts = self.document_handler.get_multipart_content(d, image_retrieval=ImageRetrieval.IMAGE_INLINE)
                multipart_content.parts.extend(parts)
            logger.info("Context: %s", multipart_content.to_text())
            rag_context_message = {"role": "user", "multipart_content": multipart_content}
        else:
            context_message = "No additional context is needed, answer directly to the following message:"
            logger.info("Context: %s", context_message)
            rag_context_message = {"role": "system", "content": context_message}

        rag_system_message: ChatMessage = {"role": "system", "content": system_message}
        return rag_system_message, rag_context_message

    def _check_guardrails(self,
                          rag_response: "SimpleCompletionResponseOrError",
                          completion_settings: CompletionSettings,
                          question: str,
                          answer: str,
                          sources: Union[List[str], List[Source]],
                          trace: SpanBuilder) -> "SimpleCompletionResponseOrError":
        """Compute the guardrail metrics with ragas and apply each guardrail's threshold to *rag_response*.

        Below threshold, a guardrail either overwrites the answer (and drops ``sources``) or turns the response
        into a ``REFUSAL`` guardrail error. A metric that could not be computed yields an ``ERROR`` guardrail
        error. *rag_response* is mutated and returned.
        """
        ragas_metrics_computer = self._get_ragas_metrics_computer(completion_settings)
        mapped_input: GenAIMetricInput = GenAIMetricInput.from_single_entry(question, answer, None, sources, None, None)

        with trace.subspan("RAG_GUARDRAILS_METRICS_COMPUTE") as metrics_span:
            overall_metrics, _ = ragas_metrics_computer.compute_llm_metrics(metric_inputs=mapped_input, metrics=self.metrics_to_compute, trace=metrics_span)

        for guardrail in self.guardrails:
            # A metric missing from the results is handled like a failed computation
            metric_value = overall_metrics.get(guardrail.metric_key_name)
            if metric_value is None or math.isnan(metric_value):  # can happen when the metrics computation failed
                return self._return_guardrails_error(rag_response, f"RAG guardrail failure: failed to compute '{guardrail.metric_key_name}' guardrail metric.",
                                                     "ERROR")
            if metric_value < guardrail.threshold:
                if guardrail.below_threshold_handling == "OVERWRITE_ANSWER":
                    logger.warning(
                        f"RAG guardrail: '{guardrail.metric_key_name}' = '{metric_value}' below '{guardrail.threshold}'. Replacing '{rag_response['text']}' answer with '{guardrail.answer_overwrite}'.")
                    rag_response["text"] = guardrail.answer_overwrite
                    rag_response.pop("sources", None)  # we unset the sources so they don't appear alongside the overwritten answer
                else:  # guardrail.below_threshold_handling == "FAIL":
                    return self._return_guardrails_error(rag_response, f"Rejected by RAG guardrail: {guardrail.get_error_message(metric_value)}")
        return rag_response

    def _return_guardrails_error(self, rag_response: "SimpleCompletionResponseOrError", error_message: str,
                                 error_type: "LLMErrorType" = "REFUSAL") -> "SimpleCompletionResponseOrError":
        """Mark *rag_response* as a failed guardrail response (``ERR_LLM_RESPONSE_GUARDRAIL``) and return it."""
        rag_response["ok"] = False
        rag_response["errorType"] = error_type
        rag_response["errorSource"] = "GUARDRAIL"
        rag_response["errorCode"] = "ERR_LLM_RESPONSE_GUARDRAIL"
        rag_response["errorMessage"] = error_message
        return rag_response

    def _get_ragas_metrics_computer(self, completion_settings: CompletionSettings) -> "RagasMetricsComputer":
        """Instantiate the ragas metrics computer that matches the installed ragas version (0.1.x or 0.2.x).

        :raises Exception: if ragas is missing or its version is unsupported, or on multimodal retrieval with
            ragas 0.1.x.
        """
        from importlib.metadata import PackageNotFoundError, version
        try:
            ragas_version = version("ragas")
        except PackageNotFoundError:
            # version() raises rather than returning a falsy value when the package is not installed
            raise Exception("ragas package is missing.") from None
        if not ragas_version:
            raise Exception("ragas package is missing.")

        can_compute_multimodal_metrics = False
        if ragas_version.startswith("0.1"):
            if self.retrieval_source == RetrievalSource.MULTIMODAL:
                raise Exception("you need ragas >= 0.2 and python >= 3.9 to use multimodal guardrails.")
            from dataiku.llm.evaluation.utils.ragas.ragas_utils_0_1_10 import RagasMetricsComputer
        elif ragas_version.startswith("0.2"):
            from dataiku.llm.evaluation.utils.ragas.ragas_utils_0_2_12 import RagasMetricsComputer
            can_compute_multimodal_metrics = True
        else:
            raise Exception(f"Version of ragas {ragas_version} is not supported (only 0.1.X and 0.2.X)")
        return RagasMetricsComputer(
            completion_llm_id=self.guardrails_llm_id,
            completion_settings=completion_settings,
            embedding_llm_id=self.guardrails_embedding_model_id,
            max_workers=9,
            fail_on_row_level_errors=True,
            can_compute_multimodal_metrics=can_compute_multimodal_metrics)

    async def handler(self, command: Dict) -> AsyncGenerator:
        """Dispatch a command received on the Java link, and yield its response(s).

        - ``start``: configures the server (see :meth:`start`) and yields its result.
        - ``process-completion-query``: when ``command["stream"]`` is true, yields one ``{"chunk": ...}`` per
          streamed chunk followed by ``{"footer": ...}``. Otherwise, yields the full response.

        :raises Exception: on an unknown command type.
        """
        command_type = command["type"]
        if command_type == "start":
            logger.info("Received start command: %s", command)
            yield await asyncio.get_running_loop().run_in_executor(
                self.executor, self.start, command
            )
            return

        if command_type == "process-completion-query":
            # Reserve the run number up front. Queries are served concurrently, so reading the shared counter
            # again after the awaits below could log another query's number.
            run_id = self.run_counter
            self.run_counter += 1
            logger.debug("Received command: %s", command)
            logger.info("\n===============  Start completion query - run %s ===============", run_id)
            if command["stream"]:
                async for c in self.process_prompt_stream(command):
                    if isinstance(c, DSSLLMStreamedCompletionChunk):
                        yield {"chunk": c.data}
                    else:
                        yield {"footer": c.data}
            else:
                yield await asyncio.get_running_loop().run_in_executor(
                    self.executor, self.process_prompt, command
                )
            logger.info("\n===============  End completion query - run %s ===============", run_id)
            return

        raise Exception("Unknown command type: %s" % command_type)

def get_messages_to_send(text_messages: List[ChatMessage], rag_system_message: ChatMessage, rag_context_message: ChatMessage) -> List[ChatMessage]:
    """Return a new list: *text_messages* with the RAG system and context messages spliced in.

    The conversation may contain tool calls and tool outputs, so the RAG instructions and context are placed
    right before the LAST ``user`` message rather than at a fixed position. The input list is not modified.

    :raises Exception: if no message has role ``user``.
    """
    last_user_index = None
    for position, msg in enumerate(text_messages):
        if msg.get("role") == "user":
            last_user_index = position

    if last_user_index is None:
        raise Exception("The chat history must contain at least one message with role: 'user'")

    before = text_messages[:last_user_index]
    after = text_messages[last_user_index:]
    return before + [rag_system_message, rag_context_message] + after

def log_exception(loop: Any, context: Dict) -> None:
    """Event-loop exception handler: log the exception described by *context*, with its stack trace.

    When *context* has no ``exception`` entry, an ``Exception`` is built from its ``message`` entry, so the
    log line always has the same shape.
    """
    exc = context.get("exception")
    exc = Exception(context.get("message")) if exc is None else exc
    stack_trace = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    logger.error(f"Caught exception: {exc}\nContext: {context}\nStack trace: {stack_trace}")

def _is_relevant_for_vector_store_query(message: ChatMessage) -> bool:
    tool_calls = message.get("toolCalls")
    if tool_calls and len(tool_calls) > 0:
        return False

    tool_outputs = message.get("toolOutputs")
    if tool_outputs and len(tool_outputs) > 0:
        return False

    return message["role"] not in ["system", "tool"] \
        and not message.get("partOfExample", False)

def _set_chat_messages(query: DSSLLMCompletionQuery, chat_messages: List[ChatMessage]) -> None:
    """Append *chat_messages* to the completion *query*.

    Each message is dispatched on the first field it carries, in this order: ``toolCalls``, ``toolOutputs``,
    ``multipart_content``, ``content``. Messages with none of them are skipped with a warning.

    :raises Exception: if a tool-call message carries multipart content that is not text-only.
    """
    for message in chat_messages:
        role = message["role"]
        content = message.get("content")
        multipart_content: Optional[MultipartContext] = message.get("multipart_content")
        tool_calls = message.get("toolCalls")
        tool_outputs = message.get("toolOutputs")

        if tool_calls:
            query.with_tool_calls(tool_calls, role=role)

            # Some providers like Anthropic may submit tool calls together with a text content. Keep that text
            # on the tool-call message, whether it came as multipart content or, as for messages from the
            # conversation history, as plain content.
            if multipart_content:
                if not multipart_content.is_text_only():
                    raise Exception("Unsupported image content for tool messages")
                query.cq["messages"][-1]["content"] = multipart_content.to_text()
            elif content:
                query.cq["messages"][-1]["content"] = content

        elif tool_outputs:
            for tool_output in tool_outputs:
                query.with_tool_output(
                    tool_output["output"],
                    tool_call_id=tool_output["callId"],
                    role=role
                )

        elif multipart_content:
            multipart_content.add_to_completion_query(query, role=role)

        elif content:
            query.with_message(content, role=role)

        else:
            # Pass the message as a logging argument. The former `"..." % message` had no placeholder, so a dict
            # message was silently dropped from the log, and any non-mapping value raised TypeError.
            logger.warning("Chat message seems empty, excluding it from the RAG query: %s", message)

if __name__ == "__main__":
    # Set LOGLEVEL=DEBUG to debug
    LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()
    logging.basicConfig(
        level=LOGLEVEL,
        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s',
    )

    watch_stdin()

    async def start_server() -> None:
        """Connect to the Java backend and serve its commands with a new RAGServer."""
        # Route exceptions reported by the event loop to log_exception
        asyncio.get_running_loop().set_exception_handler(log_exception)

        port, secret, server_cert = parse_javalink_args()
        link = AsyncJavaLink(port, secret, server_cert=server_cert)
        rag_server = RAGServer()

        await link.connect()
        await link.serve(rag_server.handler)

    asyncio.run(start_server(), debug=True)
