from datetime import datetime
from typing import Any, Dict, Generator, List, Optional, Union

from answers.backend.utils.parameter_helpers import load_n_top_sources_to_log
from answers.backend.utils.rag_sources import filter_chat_logs_rag_sources
from common.backend.constants import LOGGING_DATES_FORMAT
from common.backend.db.sql.tables_managers import GenericConversationSQL, GenericMessageSQL, GenericUserProfileSQL
from common.backend.models.base import (
    ConversationInfo,
    ConversationInsertInfo,
    ConversationType,
    GeneratedMedia,
    LlmHistory,
    LLMStep,
    MediaSummary,
    MessageInsertInfo,
    RetrievalSummaryJson,
    RetrieverMode,
)
from common.backend.models.source import AggregatedToolSources
from common.backend.services.llm_question_answering import BaseLLMQuestionAnswering
from common.backend.services.sources.sources_formatter import (
    format_sources_to_store,
    serialize_aggregated_sources_for_api,
)
from common.backend.utils.context_utils import get_main_trace_dict
from common.backend.utils.image_generation_utils import handle_images
from common.backend.utils.streaming_utils import TaggedTextStreamHandler
from common.backend.utils.uploaded_files_utils import (
    extract_logging_uploaded_files_info,
    extract_uploaded_docs_from_history,
)
from common.llm_assist.logging import logger
from dataikuapi.dss.llm import (
    DSSLLMStreamedCompletionChunk,
    DSSLLMStreamedCompletionFooter,
)
from typing_extensions import Literal


class AnswersSSE:
    """Event object pairing an event name with a payload dictionary.

    The event name is one of ``"completion-chunk"`` or ``"completion-end"``;
    both constructor arguments are stored unchanged as ``event`` and ``data``.
    """

    def __init__(self, event: Literal["completion-chunk", "completion-end"], data: Dict) -> None:
        # Both attributes are plain references to the arguments (no copy is made).
        self.event, self.data = event, data


class ErrorSSE(AnswersSSE):
    """Terminal event reporting a failure.

    Always uses the ``"completion-end"`` event name and nests the supplied
    dictionary under the ``"error"`` key of the payload.
    """

    def __init__(self, data: Dict) -> None:
        error_payload = {"error": data}
        super().__init__(event="completion-end", data=error_payload)


class AskService:
    @staticmethod
    def retrieve_history(
        platform: str, conversation_id: Optional[str], user: str, messages_sql_manager: GenericMessageSQL
    ) -> List[LlmHistory]:
        """Load the stored messages of a conversation as LLM history entries.

        Args:
            platform: Forwarded to the messages manager as ``platform``.
            conversation_id: Conversation to read; when falsy no lookup is made.
            user: Forwarded to the messages manager as ``user``.
            messages_sql_manager: Manager queried through ``get_all_conversation_messages``.

        Returns:
            A list of ``{"input": ..., "output": ...}`` entries built from each
            message's ``query`` and ``answer``. Any exception raised during the
            lookup or conversion is logged and whatever was collected so far
            (possibly an empty list) is returned.
        """
        logger.debug(f"retrieve_history with conversation_id: {conversation_id}", log_conv_id=True)
        past_turns: List[LlmHistory] = []
        if not conversation_id:
            # Without a conversation there is nothing to fetch.
            return past_turns
        try:
            stored_messages = messages_sql_manager.get_all_conversation_messages(
                platform=platform, user=user, conversation_id=conversation_id, only_present=True
            )
            for stored in stored_messages:
                # TODO: handled uploads ?
                past_turns.append({"input": stored["query"], "output": stored["answer"]})
        except Exception as e:
            # Best effort: history retrieval must not break the request.
            logger.exception(f"Error retrieving history: {e}", log_conv_id=True)
        return past_turns

    @staticmethod
    def get_used_retrieval(chunk: RetrievalSummaryJson) -> Dict:
        """Summarise which retrieval was used for an answer.

        Args:
            chunk: Retrieval summary; the keys read are ``llm_context``,
                ``filters`` and ``sources``.

        Returns:
            A dictionary that always contains ``sources`` (the serialized
            aggregated sources) and, depending on the summary content:

            * ``type``, ``generatedSqlQuery``, ``usedTables`` and ``alias`` when
              the LLM context has a truthy ``dataset_context``;
            * otherwise ``type``, ``name`` and ``alias`` when it has a truthy
              ``llm_kb_selection``;
            * ``filters`` when the summary has truthy filters.
        """
        # ``dict.get(key, default)`` only falls back when the key is absent, so a
        # key explicitly set to ``None`` has to be normalised with ``or``.
        llm_context = chunk.get("llm_context") or {}
        selected_retrieval_info: Dict = llm_context.get("selected_retrieval_info") or {}  # type: ignore
        dataset_context = llm_context.get("dataset_context")
        llm_kb_selection = llm_context.get("llm_kb_selection")

        used_retrieval: Dict[str, Any] = {}
        if dataset_context:
            used_retrieval["type"] = RetrieverMode.DB.value
            used_retrieval["generatedSqlQuery"] = dataset_context.get("sql_query")
            used_retrieval["usedTables"] = dataset_context.get("tables_used")
            used_retrieval["alias"] = selected_retrieval_info.get("alias")
        elif llm_kb_selection:
            used_retrieval["type"] = RetrieverMode.KB.value
            # Only the first selected knowledge bank is reported.
            used_retrieval["name"] = llm_kb_selection[0]
            used_retrieval["alias"] = selected_retrieval_info.get("alias")

        filters = chunk.get("filters")
        if filters:
            used_retrieval["filters"] = filters

        # Same normalisation as above: a ``None`` value becomes an empty list.
        aggregated_sources: List[AggregatedToolSources] = chunk.get("sources") or []
        used_retrieval["sources"] = serialize_aggregated_sources_for_api(aggregated_sources)

        return used_retrieval

    @staticmethod
    def create_conversation(
        user: str,
        info: ConversationInsertInfo,
        conversations_sql_manager: GenericConversationSQL,
    ) -> ConversationInfo:
        """Create a conversation for ``user`` and return what the SQL manager reports back.

        Args:
            user: Owner of the conversation, forwarded as ``user``.
            info: Insert information, forwarded as ``conversation_info``.
            conversations_sql_manager: Manager whose ``add_conversation`` is called.

        Returns:
            The value returned by ``add_conversation``.
        """
        created: ConversationInfo = conversations_sql_manager.add_conversation(
            conversation_info=info,
            user=user,
        )
        return created

    @staticmethod
    def store_message(
        user: str,
        platform: str,
        message: MessageInsertInfo,
        messages_sql_manager: GenericMessageSQL,
        new_conv: bool,
        conversations_sql_manager: Optional[GenericConversationSQL] = None,
    ) -> str:
        """Store a message and, for an existing conversation, refresh its ``updated_at``.

        Args:
            user: User the message belongs to.
            platform: Forwarded to the conversation/message managers as ``platform``.
            message: Message to insert; its ``conversation_id`` is read for the
                metadata update and for cleanup.
            messages_sql_manager: Manager used to insert the message (and to
                delete conversation messages on failure).
            new_conv: Whether the conversation was created for this message. When
                true the conversation metadata is not updated.
            conversations_sql_manager: Optional manager used to update the
                conversation's ``updated_at`` column.

        Returns:
            The identifier returned by ``add_message``.

        Raises:
            Exception: ``"Messaged not stored"`` when inserting the message or
                updating the conversation metadata fails. The original error is
                attached as the cause.
        """
        try:
            timestamp = datetime.now().strftime(LOGGING_DATES_FORMAT)
            message_id = messages_sql_manager.add_message(user=user, message_info=message, timestamp=timestamp)
            if conversations_sql_manager and not new_conv:
                conversations_sql_manager.update_conversation_metadata(
                    user=user,
                    conversation_id=message.get("conversation_id"),  # type: ignore
                    column_updates={"updated_at": timestamp},
                    platform=platform,
                )
            return message_id
        except Exception as e:
            logger.exception(f"Error storing message: {e}", log_conv_id=True)
            conversation_id = message.get("conversation_id")
            if new_conv and conversation_id:
                # The conversation was created for this message: remove its messages.
                # NOTE(review): only the messages are deleted here, not the
                # conversation itself - confirm whether that is intended.
                try:
                    messages_sql_manager.delete_conversation_messages(
                        platform=platform, user=user, conversation_id=conversation_id  # type: ignore
                    )
                except Exception as cleanup_error:
                    # A failing cleanup must not hide the storage error raised below.
                    logger.exception(
                        f"Error cleaning up conversation messages: {cleanup_error}", log_conv_id=True
                    )
            raise Exception("Messaged not stored") from e

    @staticmethod
    def process_response(
        answer: str,
        request_data: Dict,
        completion_footer: Union[DSSLLMStreamedCompletionFooter, None],
        retrieval_summary: RetrievalSummaryJson,
        timestamp: float,
        llm_qa: BaseLLMQuestionAnswering,
        user_profile_sql_manager: GenericUserProfileSQL,
        messages_sql_manager: GenericMessageSQL,
        chat_history: Optional[List[LlmHistory]],
        conversations_sql_manager: Optional[GenericConversationSQL] = None,
        message_index: int = 0,
    ) -> Dict:
        """Persist the question/answer exchange and build the final response payload.

        Steps, in order:

        1. When the request has no ``conversationId``, ``chatSettings.createConversation``
           is truthy and a conversation manager is given, create a conversation
           (titled by the LLM when ``chatSettings.withTitle`` is truthy). A failure
           here is logged and the message is then stored without a conversation.
        2. Handle generated images (if any), prepare the sources to store and
           store the message.
        3. Assemble the response dictionary.

        Args:
            answer: Complete answer text.
            request_data: Incoming request; ``query`` and ``user`` are required,
                ``conversationId``, ``chatSettings`` and ``context`` are optional.
            completion_footer: Accepted for interface compatibility; not read here.
            retrieval_summary: Summary produced with the answer (sources, filters,
                LLM context, user profile, generated images).
            timestamp: Value echoed in the response as ``timestamp``.
            llm_qa: Used for the conversation title and the LLM name.
            user_profile_sql_manager: Passed to ``handle_images``.
            messages_sql_manager: Used to store the message.
            chat_history: History stored alongside the message.
            conversations_sql_manager: Optional manager for conversations.
            message_index: Value echoed in the response as ``messageIndex``.

        Returns:
            The response dictionary, with the optional keys ``usedRetrieval``,
            ``generatedMedia``, ``conversationInfo`` and ``llmContext`` added when
            the corresponding data is present.

        Raises:
            KeyError: If ``request_data`` lacks ``query``.
            Exception: ``"Error storing message"`` when image handling, source
                preparation or message storage fails; the original error is
                attached as the cause.
        """
        query = request_data["query"]
        conv_id = request_data.get("conversationId", None)
        chat_settings = request_data.get("chatSettings", {})
        context = request_data.get("context", {})
        platform = context.get("applicationId")
        new_conv_info: Optional[ConversationInfo] = None
        new_conv = False
        if not conv_id and chat_settings.get("createConversation", False) and conversations_sql_manager:
            conv_name = ""
            if chat_settings.get("withTitle", False):
                # generate a title
                # TODO: should we handle conv of type media ?
                conv_name = llm_qa.get_conversation_title(query, answer, retrieval_summary.get("user_profile", None))
            try:
                new_conv_info = AskService.create_conversation(
                    user=request_data["user"],
                    info=ConversationInsertInfo(name=conv_name, platform=platform),
                    conversations_sql_manager=conversations_sql_manager,
                )
                conv_id = new_conv_info["id"] if new_conv_info else None
                new_conv = True
            except Exception as e:
                logger.exception(f"Error creating conversation: {e}", log_conv_id=True)
                # TODO: How should we handle this error?
        message_id = None
        images: List[GeneratedMedia] = []
        generated_media = None
        try:
            user_profile = retrieval_summary.get("user_profile", {})
            generated_images = retrieval_summary.get("generated_images", [])
            if generated_images:
                images, user_profile = handle_images(generated_images, request_data["user"], user_profile, user_profile_sql_manager)  # type: ignore
                generated_media = {"images": images}

            # Filter the sources first, then strip image data from the FILTERED
            # result before storage. Stripping runs after the filter so it still
            # works when every item has been filtered out.
            # The previous code formatted the unfiltered list and discarded the
            # result, so the stored sources were never stripped.
            aggregated_sources: List[AggregatedToolSources] = retrieval_summary.get("sources", [])
            n_top_sources_to_log = load_n_top_sources_to_log()
            filtered_sources, __ = filter_chat_logs_rag_sources(
                aggregated_sources, n_top_sources_to_log, None
            )
            sources_to_store = format_sources_to_store(filtered_sources)  # type: ignore

            message_id = AskService.store_message(
                user=request_data["user"],
                platform=platform,
                message=MessageInsertInfo(
                    query=query,
                    answer=answer,
                    platform=platform,
                    llm_name=llm_qa.get_llm_name(),
                    filters=retrieval_summary.get("filters"),
                    sources=sources_to_store,  # type: ignore
                    llm_context=retrieval_summary.get("llm_context"),
                    generated_media=generated_media,
                    conversation_id=conv_id,
                    history=chat_history,
                ),
                messages_sql_manager=messages_sql_manager,
                new_conv=new_conv,
                conversations_sql_manager=conversations_sql_manager,  # type: ignore
            )  # type: ignore
        except Exception as e:
            logger.exception(f"Error storing message: {e}", log_conv_id=True)
            # TODO: How should we handle this error?
            # Chain the cause so the original traceback is not lost.
            raise Exception("Error storing message") from e

        resp = {
            "id": message_id,
            "messageIndex": message_index,  # TODO do we need this?
            "answer": answer,
            "query": query,
            "timestamp": timestamp,
            "user_profile": user_profile,
            "context": {
                "applicationId": platform,
                "applicationType": context.get("applicationType"),
            },
            "trace": get_main_trace_dict(query, answer),
        }
        if retrieval_summary.get("sources") or retrieval_summary.get("filters"):
            resp["usedRetrieval"] = AskService.get_used_retrieval(chunk=retrieval_summary)
        if generated_media:
            resp["generatedMedia"] = generated_media
        if new_conv_info:
            resp["conversationInfo"] = {
                "id": new_conv_info["id"],
                "title": new_conv_info["name"],
                "createdAt": new_conv_info["created_at"],
                "lastMessageAt": new_conv_info["updated_at"],
            }
        if llm_context := retrieval_summary.get("llm_context"):
            resp["llmContext"] = llm_context
        return resp

    @staticmethod
    def __build_conv_history__(
        data: Dict,
        messages_sql_manager: GenericMessageSQL,
    ) -> Any:
        """Collect the chat history and media summaries for a request.

        Args:
            data: Request payload. Reads ``context.history``, ``context.applicationId``,
                ``files``, ``conversationId`` and ``user``.
            messages_sql_manager: Used to load the stored history when the request
                does not provide one.

        Returns:
            A dictionary with three keys:

            * ``chat_history``: built from ``context.history`` when provided,
              otherwise loaded through ``AskService.retrieve_history``;
            * ``previous_media_summaries``: extracted from the provided history,
              ``None`` when the history comes from storage;
            * ``media_summaries``: one entry per item of ``files``, or ``None``
              when the request carries no files.
        """
        request_context = data.get("context", {})
        provided_history = request_context.get("history")

        # Describe the files attached to this request, if any.
        media_summaries: Optional[List[MediaSummary]] = None
        if data.get("files"):
            media_summaries = []
            for uploaded in data["files"]:
                media_summaries.append(
                    {
                        "original_file_name": uploaded["name"],
                        "file_path": uploaded["path"],
                        "chain_type": uploaded["chainType"],
                        "metadata_path": uploaded.get("jsonFilePath"),
                        "preview": uploaded.get("preview"),
                    }
                )

        previous_media_summaries = None
        chat_history: List[LlmHistory] = []
        if not provided_history:
            # No history in the request: fall back to what is stored.
            chat_history = AskService.retrieve_history(
                platform=request_context.get("applicationId"),
                conversation_id=data.get("conversationId"),
                user=data["user"],
                messages_sql_manager=messages_sql_manager,
            )
        else:
            chat_history = [{"input": turn["query"], "output": turn["answer"]} for turn in provided_history]
            previous_media_summaries = extract_uploaded_docs_from_history(chat_history)
            logger.debug(f"Previous Media summaries: {extract_logging_uploaded_files_info(previous_media_summaries)}", log_conv_id=True)

        return {
            "chat_history": chat_history,
            "previous_media_summaries": previous_media_summaries,
            "media_summaries": media_summaries,
        }

    @staticmethod
    def process_query(
        data: Dict,
        llm_qa: BaseLLMQuestionAnswering,
        user_profile_sql_manager: GenericUserProfileSQL,
        messages_sql_manager: GenericMessageSQL,
        conversations_sql_manager: Optional[GenericConversationSQL]=None,
    ) -> Dict:
        """Answer a query synchronously by draining the streaming pipeline.

        Runs ``process_query_streaming`` with ``force_non_streaming=True`` and
        returns the payload of the first ``"completion-end"`` event.

        Raises:
            Exception: ``"Error when processing query."`` when the end event
                carries a truthy ``error``; ``"End of processing never reached."``
                when the stream finishes without an end event.
        """
        events = AskService.process_query_streaming(
            request_data=data,
            llm_qa=llm_qa,
            user_profile_sql_manager=user_profile_sql_manager,
            messages_sql_manager=messages_sql_manager,
            conversations_sql_manager=conversations_sql_manager,
            force_non_streaming=True,
        )
        for event in events:
            # Intermediate chunks are irrelevant for the synchronous caller.
            if event.event != "completion-end":
                continue
            if event.data.get("error"):
                raise Exception("Error when processing query.")
            return event.data

        logger.error("Done processing query but there is no end to completion.", log_conv_id=True)
        raise Exception("End of processing never reached.")

    @staticmethod
    def process_query_streaming(
        request_data: Dict,
        llm_qa: BaseLLMQuestionAnswering,
        user_profile_sql_manager: GenericUserProfileSQL,
        messages_sql_manager: GenericMessageSQL,
        conversations_sql_manager: Optional[GenericConversationSQL]=None,
        force_non_streaming=False,
    ) -> Generator[AnswersSSE, Any, Any]:
        """Answer a query and yield the result as a sequence of events.

        Args:
            request_data: Request payload. ``query`` is required; ``selectedRetrieval.filters``
                and ``userPreferences`` are read when present, and the payload is
                also used to build the conversation history.
            llm_qa: Question-answering service whose ``get_answer_and_sources`` is iterated.
            user_profile_sql_manager: Forwarded to ``AskService.process_response``.
            messages_sql_manager: Used for the history lookup and forwarded to
                ``AskService.process_response``.
            conversations_sql_manager: Forwarded to ``AskService.process_response``.
            force_non_streaming: Forwarded to the LLM service. When falsy, the
                ``answer`` field of the final event is blanked.

        Yields:
            ``AnswersSSE("completion-chunk", {"text": ...})`` for each text chunk,
            ``AnswersSSE("completion-end", <processed response>)`` once the chunk
            containing ``answer`` is received, and an ``ErrorSSE`` when a
            ``LLMStep.STREAMING_ERROR`` step is received.

        Raises:
            Exception: Anything raised by ``AskService.process_response`` propagates
                to the consumer of the generator.
        """
        # conv_type = data["conversation_type"] TODO should we keep this in the api?
        # Conv_type will always be general in case of API.
        # chat with media : every question will be passed with the extracted information to llm
        # chat without media : every question will be passed without the extracted information to llm except for the first one
        conv_type = ConversationType.GENERAL
        #  will be moved to be no need to send
        # chain_type = data["chain_type"] TODO should we keep this in the api?
        # TODO media_summaries ? why is it being sent from the frontend. it can be handled directly by the backend

        history = AskService.__build_conv_history__(request_data, messages_sql_manager)
        chat_history = history["chat_history"]
        media_summaries = history["media_summaries"]
        previous_media_summaries = history["previous_media_summaries"]

        # Consumer can ask for streaming but the underlying LLM may not support it.
        # In this case, yield once the complete answer.
        complete_answer = ""
        completion_footer: Union[DSSLLMStreamedCompletionFooter, None] = None
        chunk_with_retrieval_summary: RetrievalSummaryJson = {}
        # Taken before the LLM call; reported as the response timestamp.
        currentTime = datetime.now().timestamp()
        # Streamed text goes through this handler before being accumulated and emitted.
        tagged_text_stream_handler = TaggedTextStreamHandler(["<citation"], ["</citation>"])
        for chunk in llm_qa.get_answer_and_sources(
            query=request_data["query"],
            conversation_type=conv_type,
            chat_history=chat_history,
            filters=request_data.get("selectedRetrieval", {}).get("filters"),
            chain_type="",
            media_summaries=media_summaries,
            previous_media_summaries=previous_media_summaries,
            retrieval_enabled=True,  # TODO should we have the option to disable retrieval?
            user_profile=request_data.get("userPreferences"),
            force_non_streaming=force_non_streaming,
        ):
            if isinstance(chunk, DSSLLMStreamedCompletionChunk):
                # Streaming text chunk: accumulate the processed text and forward it.
                answer_chunk = chunk.data.get("text", "")
                processed_text = tagged_text_stream_handler.process_text_chunk(answer_chunk)
                complete_answer += processed_text
                yield AnswersSSE("completion-chunk", {"text": processed_text})
            elif isinstance(chunk, DSSLLMStreamedCompletionFooter):
                # Footer: kept so it can be passed on to process_response.
                logger.debug(f"Streaming footer received {chunk.data}.", log_conv_id=True)  # type: ignore
                completion_footer = chunk.data  # type: ignore
            elif isinstance(chunk, dict):
                if "answer" in chunk:
                    # Once we arrive at this point, the answer is ready and this chunk contains all the information regarding retrieval, sources...
                    chunk_with_retrieval_summary = chunk  # type: ignore

                    logger.debug(f"Final iteration with summary .", log_conv_id=True)

                    answer_chunk = chunk_with_retrieval_summary["answer"]
                    if answer_chunk:
                        # This is for non streaming LLM.
                        # Debatable if it is necessary to send this event even though the complete answer will be in the next event.
                        # Let's say from an API implementer point of view it helps make no difference between streaming and non streaming LLM.
                        # Note: this replaces (does not extend) anything accumulated from streamed chunks.
                        complete_answer = answer_chunk
                        logger.debug(f"Streaming end received and footer is {completion_footer}.", log_conv_id=True)
                        yield AnswersSSE("completion-chunk", {"text": answer_chunk})

                    # NOTE(review): the message_index expression below parses as
                    # `(1 + len(chat_history)) if chat_history else 0`, i.e. 0 for an
                    # empty history and n + 1 for a history of n entries - confirm
                    # this is the intended indexing.
                    processed_response = AskService.process_response(
                        answer=complete_answer,
                        retrieval_summary=chunk_with_retrieval_summary,
                        completion_footer=completion_footer,
                        request_data=request_data,
                        chat_history=chat_history if chat_history else None,
                        timestamp=currentTime,
                        llm_qa=llm_qa,
                        user_profile_sql_manager=user_profile_sql_manager,
                        messages_sql_manager=messages_sql_manager,
                        conversations_sql_manager=conversations_sql_manager,
                        message_index=1 + len(chat_history) if chat_history else 0,
                    )
                    # The answer is blanked in the final event unless non-streaming was
                    # forced (the text was already sent through completion-chunk events).
                    if not force_non_streaming:
                        processed_response["answer"] = ""

                    yield AnswersSSE(
                        "completion-end",
                        processed_response,
                    )
                elif step := chunk.get("step", None):
                    if step == LLMStep.STREAMING_END:
                        logger.debug(f"Streaming end received.", log_conv_id=True)
                    elif step == LLMStep.STREAMING_ERROR:
                        logger.debug(f"Streaming error received and footer is {completion_footer}.", log_conv_id=True)
                        # The error event echoes the request and the start timestamp.
                        yield ErrorSSE(
                            data={
                                "request_data": request_data,
                                "timestamp": currentTime,
                            },
                        )
                    else:
                        # For the moment ignore steps as it may open up to much the API and, to me at the moment, it is not clear how we want to expose that.
                        logger.debug(f"Ignore step {step}.", log_conv_id=True)
                        pass
                else:
                    logger.warn(f"Unexpected dictionary type chunk: '{chunk}'.", log_conv_id=True)
            else:
                # TODO: handle this case. It shouldn't happen
                logger.warn(f"Unexpected chunk type '{chunk}'.", log_conv_id=True)
