import json
import uuid
from typing import Any, Dict, List, Optional, Union

from common.backend.constants import KEYS_TO_REMOVE_FROM_LOGS
from common.backend.models.base import LLMContext, LlmHistory, LLMStep, MediaSummary
from common.backend.models.source import AggregatedToolSources
from common.backend.utils.auth_utils import get_auth_agent, get_auth_user
from common.backend.utils.context_utils import reset_conv_id, set_conv_id
from common.backend.utils.json_utils import mask_keys
from common.backend.utils.streaming_utils import TaggedTextStreamHandler
from common.backend.utils.uploaded_files_utils import (
    extract_logging_uploaded_files_info,
    extract_uploaded_docs_from_history,
)
from common.llm_assist.llm_api_handler import llm_setup
from common.llm_assist.logging import logger
from flask_socketio import disconnect, emit, send
from portal.backend.db.conversations import conversation_sql_manager
from portal.backend.models import PortalExtractedQueryInfo
from portal.backend.routes.api.api_response_handler import APIResponseProcessor
from portal.backend.utils.conversation_utils import record_conversation
from portal.solutions.service import LLM_Question_Answering

from ..utils import before_request, return_ko, return_ok


def extract_request_info(request_json: Dict[str, Union[str, List[str]]]) -> PortalExtractedQueryInfo:
    """Build the per-query info dictionary from a front-end JSON payload.

    Payload fields are read through ``APIResponseProcessor``. If the payload carries
    a conversation id, the conversation history, name and per-agent uploaded files
    are loaded from ``conversation_sql_manager`` for the authenticated user, and the
    media summaries already present in that history are collected. Otherwise a new
    UUID4 conversation id is generated.

    Parameters:
        request_json (dictionary): Decoded JSON request sent by the front end.

    Returns:
        PortalExtractedQueryInfo: query, history, answer, sources, conversation
        metadata (id, name, type, "is new" flag), user profile, app id and the
        per-agent uploaded files for this request. ``llm_context`` starts empty
        and ``generated_images`` starts as an empty list.
    """
    user_identifier = get_auth_user()
    handler: APIResponseProcessor = APIResponseProcessor(api_response=request_json)
    conv_id: Optional[str] = handler.extract_conversation_id()
    user_query: str = handler.extract_query()
    user_query_index: int = handler.extract_query_index()
    requested_chain_type: Optional[str] = handler.extract_chain_type()
    new_media_summaries = handler.extract_media_summaries()

    # Defaults used when the conversation does not exist yet.
    chat_history: List[LlmHistory] = []
    conv_name = None
    prior_media_summaries = []
    agents_uploads: Optional[Dict[str, Dict[str, MediaSummary]]] = None

    creating_conversation = not conv_id
    if creating_conversation:
        conv_id = str(uuid.uuid4())
        logger.debug(f"The new conversation_id '{conv_id}' will be created")
    else:
        logger.debug(f"New request in the conversation '{conv_id}'")
        chat_history, conv_name, agents_uploads = conversation_sql_manager.get_conversation_history(
            auth_identifier=user_identifier, conversation_id=conv_id
        )
        prior_media_summaries = extract_uploaded_docs_from_history(chat_history)
        logger.debug(f"Previous Media summaries: {extract_logging_uploaded_files_info(prior_media_summaries)}")

    final_answer: str = handler.extract_answer()
    answer_sources: List[AggregatedToolSources] = handler.extract_sources()
    conv_type = handler.extract_conversation_type()
    profile = handler.extract_user_profile()
    application_id = handler.extract_app_id()
    empty_context: LLMContext = {}

    return {
        "query": user_query,
        "query_index": user_query_index,
        "chain_type": requested_chain_type,
        "media_summaries": new_media_summaries,
        "previous_media_summaries": prior_media_summaries,
        "conversation_name": conv_name,
        "history": chat_history,
        "answer": final_answer,
        "sources": answer_sources,
        "conversation_id": conv_id,
        "is_new_conversation": creating_conversation,
        "llm_context": empty_context,
        "user_profile": profile,
        "generated_images": [],
        "conversation_type": conv_type,
        "app_id": application_id,
        "agents_files_uploads": agents_uploads,
    }


def filter_llm_context_for_ui(llm_context: Dict[str, Any]) -> Dict[str, Any]:
    """Return only the LLM-context entries the UI displays.

    Keeps ``llm_kb_selection`` and ``dataset_context`` (in that order) when present;
    every other key is dropped.
    """
    ui_keys = ("llm_kb_selection", "dataset_context")
    filtered: Dict[str, Any] = {}
    for ui_key in ui_keys:
        if ui_key in llm_context:
            filtered[ui_key] = llm_context[ui_key]
    return filtered


def send_error_response():
    """Emit a ``log_query_complete`` failure event (built with ``return_ko``) to the current client."""
    failure_payload = return_ko(message="Error occurred while processing your message")
    emit("log_query_complete", failure_payload)


def process_and_stream_answer(request: str):
    """
    Handles a request by streaming llm answers and logging the query process.

    Answer chunks and processing steps are pushed to the client with ``send``. The
    final dictionary chunk triggers ``finalize_answer``, which records the
    conversation, emits ``log_query_complete`` and disconnects the client.

    Parameters:
    - request (str): The JSON string sent by the front end (query, conversation id,
      chain type, media summaries, ...; see ``extract_request_info``).

    Key Steps:
    1. Parse the JSON request and extract the query details. On failure the client
       receives an error answer plus a ``log_query_complete`` failure event, and the
       function returns immediately.
    2. Bind the conversation id to the logging context for the rest of the call
       (always reset in ``finally``).
    3. Retrieve and stream llm answer parts and steps.
    4. Record the completed query in the database.

    Error handling ensures clients are informed of any issues during the process.
    """

    def finalize_answer(request_info: PortalExtractedQueryInfo):
        """
        Records the conversation, notifies the client, then disconnects it.

        Parameters:
        - request_info (PortalExtractedQueryInfo): The request information including the answer.
        """
        logger.debug("Finalizing answer")
        record_info = record_conversation(info=request_info)
        if not record_info:
            # Not inside an `except` block: logger.exception would have no traceback to attach.
            logger.error("Sending error response")
            send_error_response()
        else:
            logger.debug("Log query complete")
            new_record = record_info["new_record"]
            images = record_info["images"]
            user_profile = record_info["user_profile"]
            record_id = record_info["record_id"]
            conversation_infos = record_info["conversation_infos"]
            emit(
                "log_query_complete",
                return_ok(
                    data={
                        "answer": new_record["answer"],
                        "sources": new_record["sources"],
                        "filters": None,
                        "conversation_infos": conversation_infos,
                        "record_id": record_id,
                        "query": new_record["query"],
                        "generated_images": images,
                        "user_profile": user_profile,
                    }
                ),
            )
        disconnect()

    answer = ""
    request_info: PortalExtractedQueryInfo = {}

    try:
        request_json: Dict[str, Union[str, List[str]]] = json.loads(request)
        logger.debug(f"Payload from the front end is {request_json}")
        request_info = extract_request_info(request_json=request_json)
    except Exception as e:
        # Handle exceptions gracefully. request_info is still empty here (extraction is the
        # last statement of the try), so there is nothing to record: report and stop, instead
        # of falling through to a KeyError on request_info["conversation_id"].
        logger.exception(f"The request could not be parsed. Service Got exception:{e}")
        error_msg = "Error processing your request (The request could not be parsed)"
        send(return_ok(data={"answer": error_msg}))
        send_error_response()
        return

    conversation_token = set_conv_id(request_info["conversation_id"])
    try:
        # The following log makes sure the answer process be tracked from the initial query to the conversation's end
        logger.debug(
            f"User query: {request_info['query']} (query index={request_info['query_index']})", log_conv_id=True
        )
        tagged_text_stream_handler = TaggedTextStreamHandler(["<citation"], ["</citation>"])
        llm_qa = LLM_Question_Answering(llm_setup.get_llm())

        for chunk in llm_qa.get_answer_and_sources(
            query=request_info["query"],
            conversation_type=request_info["conversation_type"],
            chat_history=request_info["history"],
            chain_type=request_info["chain_type"],
            media_summaries=request_info["media_summaries"],
            previous_media_summaries=request_info["previous_media_summaries"],
            user_profile=request_info["user_profile"],
            user=get_auth_user(),
            app_id=request_info.get("app_id"),
            user_agent=get_auth_agent(),
            agents_files_uploads=request_info.get("agents_files_uploads"),
            conversation_id=request_info["conversation_id"],
        ):
            chunk_type = get_answer_chunk_type(chunk)
            if "streaming" in chunk_type:
                # Case streaming: only completion chunks are forwarded, footers are ignored.
                if chunk_type == "streaming_completion_chunk":
                    if chunk.data.get("type", "") == "event":
                        event_message = chunk.data.get("eventKind")
                        event_data = chunk.data.get("eventData", {})
                        if event_data:
                            formatted_data = "\n".join(f"  - {key}: {value}" for key, value in event_data.items())
                            # eventKind may be missing; avoid a `None + str` TypeError.
                            event_message = f"{event_message or ''}\nDetails:\n{formatted_data}"
                        send(return_ok(data={"step": event_message}))
                    else:
                        text_chunk = chunk.data.get("text", "")
                        processed_text = tagged_text_stream_handler.process_text_chunk(text_chunk)
                        answer += processed_text
                        send(return_ok(data={"answer": processed_text}))
            elif chunk_type == "dictionary":
                step = chunk.get("step")
                if step:
                    if (
                        step == LLMStep.COMPUTING_PROMPT_WITHOUT_RETRIEVAL
                        or step == LLMStep.QUERYING_LLM_WITHOUT_RETRIEVAL
                    ):
                        send(return_ok(data={"step": "analyzing_no_retriever"}))
                    elif step == LLMStep.STREAMING_START:
                        send(return_ok(data={"step": "streaming"}))
                    elif step == LLMStep.GENERATING_IMAGE:
                        send(return_ok(data={"step": "generating_image"}))
                    else:
                        send(return_ok(data={"step": step.name}))
                else:
                    # Final dictionary chunk: carries the answer/sources; stream it and record the conversation.
                    text_chunk = chunk.get("answer")
                    if text_chunk:
                        answer += text_chunk
                    llm_context = chunk.get("llm_context", {})
                    logger.debug(f"LLM context: {mask_keys(llm_context, KEYS_TO_REMOVE_FROM_LOGS)}", log_conv_id=True)
                    ui_llm_context = filter_llm_context_for_ui(llm_context) if llm_context else None
                    sources = chunk.get("sources")
                    if sources and len(sources) == 1:
                        # Here we handle the case where the answer is directly sent to the user from one agent.
                        # In that case sources do not contain the full answer,
                        # so we append the answer to the sources in order to always have it,
                        # regardless of whether one agent or multiple agents contain the agent answer.
                        sources[0]["answer"] = answer
                    send(
                        return_ok(
                            data={
                                "answer": text_chunk,
                                "sources": sources,
                                "llm_context": ui_llm_context,
                                "generated_images": chunk.get("generated_images", None),
                                "user_profile": chunk.get("user_profile", None),
                            }
                        )
                    )
                    request_info.update(
                        {
                            "answer": answer,
                            "sources": sources,
                            "llm_context": chunk.get("llm_context"),
                            "generated_images": chunk.get("generated_images", None),
                            "user_profile": chunk.get("user_profile", None),
                        }
                    )
                    finalize_answer(request_info)
    except Exception as e:
        # Handle exceptions gracefully
        logger.exception(f"Service Got exception:{e}", log_conv_id=True)
        error_msg = "Error processing your request"
        send(return_ok(data={"answer": error_msg}))
        if request_info:
            request_info["answer"] = error_msg
            finalize_answer(request_info)
        else:
            send_error_response()
    finally:
        reset_conv_id(conversation_token)

def process_log_query(request: str):
    """Record a query sent by the client on the ``log_query`` event, without streaming.

    Any failure while decoding, extracting or recording is logged with its traceback
    and reported to the client through ``send_error_response``.

    Parameters:
    - request (str): JSON string payload sent by the front end.
    """
    try:
        logger.debug("Client sent a message to log")
        payload: Dict[str, Union[str, List[str]]] = json.loads(request)
        logger.debug(f"Payload from the front end is {payload}")
        record_conversation(extract_request_info(request_json=payload))
    except Exception as err:
        logger.exception(f"Failed to log query:{err}")
        send_error_response()


def setup_socketio_answer_event_handlers(socketio):
    """
    Register the answer-streaming socket.io event handlers on ``socketio``.

    Events handled (registered in this order):
    - ``connect``: runs ``before_request`` and logs the connection.
    - ``message``: runs ``before_request`` then streams an answer via ``process_and_stream_answer``.
    - ``log_query``: runs ``before_request`` then records the query via ``process_log_query``.
    - ``disconnect``: runs ``before_request`` and logs the disconnection.
    - ``connect_error``: logs the connection failure.

    Parameters:
    - socketio: The socket.io server instance to which the event handlers are registered.
    """

    def handle_connect():
        before_request()
        logger.debug("Client connected")

    def handle_message(msg):
        before_request()
        process_and_stream_answer(msg)

    def handle_log_query(request: str):
        before_request()
        process_log_query(request)

    def handle_disconnect():
        before_request()
        logger.info("Client disconnected")

    def handle_error():
        logger.error("Failed to connect to client")

    event_handlers = (
        ("connect", handle_connect),
        ("message", handle_message),
        ("log_query", handle_log_query),
        ("disconnect", handle_disconnect),
        ("connect_error", handle_error),
    )
    for event_name, handler in event_handlers:
        # Equivalent to decorating `handler` with `@socketio.on(event_name)`.
        socketio.on(event_name)(handler)


def get_answer_chunk_type(answer_chunk):
    """Classify a chunk yielded by the answer generator.

    Returns:
        "dictionary" for ``dict`` chunks; otherwise "streaming_completion_chunk" or
        "streaming_completion_footer" when the chunk's ``str()`` form contains
        "completion-chunk:" or "completion-footer:" (checked in that order);
        "None_chunk" for ``None`` and for anything else.
    """
    if answer_chunk is None:
        return "None_chunk"
    if isinstance(answer_chunk, dict):
        return "dictionary"

    rendered = str(answer_chunk)
    markers = (
        ("completion-chunk:", "streaming_completion_chunk"),
        ("completion-footer:", "streaming_completion_footer"),
    )
    for marker, label in markers:
        if marker in rendered:
            return label
    return "None_chunk"
