import json
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

from answers.backend.db.logging import logging_sql_manager
from answers.backend.routes.api.api_response_handler import APIResponseProcessor
from answers.backend.utils.db.conversation_utils import record_conversation
from answers.solutions.service import LLM_Question_Answering
from common.backend.constants import LLM_API_ERROR
from common.backend.models.base import (
    ExtractedQueryInfo,
    LLMContext,
    LlmHistory,
    LLMStep,
    RetrievalSummaryJson,
)
from common.backend.models.source import AggregatedToolSources
from common.backend.utils.auth_utils import get_auth_user
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.streaming_utils import TaggedTextStreamHandler
from common.backend.utils.uploaded_files_utils import (
    extract_logging_uploaded_files_info,
    extract_uploaded_docs_from_history,
)
from common.backend.utils.user_profile_utils import flatten_user_profile_values
from common.llm_assist.llm_api_handler import llm_setup
from common.llm_assist.logging import logger
from dataikuapi.dss.llm import DSSLLMStreamedCompletionChunk
from flask_socketio import disconnect, emit, send

from ..utils import before_request, return_ko, return_ok


def extract_request_info(request_json: Dict[str, Union[str, List[str]]]) -> ExtractedQueryInfo:
    """Build the ExtractedQueryInfo describing one incoming chat request.

    Every request field (conversation id, query, filters, chain type, media summaries,
    answer, sources, retrieval settings, user profile, ...) is read through an
    APIResponseProcessor wrapping ``request_json``. When the request references an
    existing conversation, its stored history and name are loaded from the logging
    database for the authenticated user, together with the documents previously
    uploaded in that history.

    Parameters:
        request_json (dictionary): JSON request data sent by the front end.

    Returns:
        ExtractedQueryInfo: Key information like query, filters, IDs and history.
        ``llm_context`` starts empty and ``generated_images`` starts as an empty list.

    Note:
        Requires global authentication (``get_auth_user``) and APIResponseProcessor for processing.
    """
    auth_identifier = get_auth_user()
    processor: APIResponseProcessor = APIResponseProcessor(api_response=request_json)

    # --- Fields describing the question itself ---
    conversation_id: Optional[str] = processor.extract_conversation_id()
    query: str = processor.extract_query()
    query_index: int = processor.extract_query_index()
    filters: Optional[Dict[str, List[Any]]] = processor.extract_filters()
    chain_type: Optional[str] = processor.extract_chain_type()

    media_summaries = processor.extract_media_summaries()
    logger.debug(f"Media summaries: {extract_logging_uploaded_files_info(media_summaries)}")

    # TODO: remove once we change to new logging structure
    knowledge_bank_id: Optional[str] = processor.extract_knowledge_bank_id()
    retrieval_enabled: bool = processor.extract_retrieval_enabled()

    # --- Conversation state: only an existing conversation has stored history ---
    history: List[LlmHistory] = []
    conversation_name = None
    previous_media_summaries = []
    if conversation_id:
        history, conversation_name = logging_sql_manager.get_conversation_history(
            auth_identifier=auth_identifier, conversation_id=conversation_id
        )
        previous_media_summaries = extract_uploaded_docs_from_history(history)
        logger.debug(f"Previous Media summaries: {extract_logging_uploaded_files_info(previous_media_summaries)}")

    # --- Fields describing an already computed answer (if any) ---
    answer: str = processor.extract_answer()
    sources: List[AggregatedToolSources] = processor.extract_sources()
    retrieval_selection: Dict[str, List[Any]] = processor.extract_retrieval_selection()
    conversation_type = processor.extract_conversation_type()
    user_profile = processor.extract_user_profile()

    llm_context: LLMContext = {}

    return {
        "query": query,
        "query_index": query_index,
        "filters": filters,
        "chain_type": chain_type,
        "media_summaries": media_summaries,
        "previous_media_summaries": previous_media_summaries,
        "conversation_name": conversation_name,
        "history": history,
        "answer": answer,
        "sources": sources,  # type: ignore
        "conversation_id": conversation_id,
        "knowledge_bank_id": knowledge_bank_id,  # TODO: remove once we change to new logging structure
        "retrieval_enabled": retrieval_enabled,
        "retrieval_selection": retrieval_selection,
        "llm_context": llm_context,
        "user_profile": user_profile,
        "generated_images": [],
        "conversation_type": conversation_type,
    }


def filter_llm_context_for_ui(llm_context: LLMContext) -> Dict[str, Any]:
    """Select the parts of an LLM context that are forwarded to the UI.

    Parameters:
        llm_context (LLMContext): Context returned by the QA service with the final answer.
            May be ``None`` or empty, in which case nothing is forwarded.

    Returns:
        Dict[str, Any]: A new dict holding only the ``llm_kb_selection`` and
        ``dataset_context`` entries that are present in ``llm_context``.
    """
    if not llm_context:
        # `chunk.get("llm_context", {})` still yields None when the key is present with a
        # None value; `key in None` would raise TypeError and abort the whole answer stream.
        return {}
    ui_keys = ("llm_kb_selection", "dataset_context")
    return {key: llm_context[key] for key in ui_keys if key in llm_context}  # type: ignore


def send_error_response():
    """Emit a generic failure payload on the ``log_query_complete`` event to the current client."""
    failure_payload = return_ko(message="Error occurred while processing your message")
    emit("log_query_complete", failure_payload)


def _ui_step_label(step: Any) -> Optional[str]:
    """Translate an ``LLMStep`` yielded by the QA service into the step label sent to the UI.

    Returns ``None`` for steps the UI has no label for; the caller then sends nothing.
    """
    if step in (LLMStep.COMPUTING_PROMPT_WITH_KB, LLMStep.QUERYING_LLM_WITH_KB):
        return "analyzing_kb"
    if step in (LLMStep.COMPUTING_PROMPT_WITH_DB, LLMStep.QUERYING_LLM_WITH_DB):
        return "analyzing_db"
    if step in (LLMStep.COMPUTING_PROMPT_WITHOUT_RETRIEVAL, LLMStep.QUERYING_LLM_WITHOUT_RETRIEVAL):
        return "analyzing_no_retriever"
    if step == LLMStep.STREAMING_START:
        return "streaming"
    if step == LLMStep.GENERATING_IMAGE:
        return "generating_image"
    if step == LLMStep.USING_FALLBACK_LLM:
        return "using_fallback_llm"
    return None


def process_and_stream_answer(request: str):
    """
    Handles a request by streaming llm answers and logging the query process.

    Parameters:
    - request (str): The JSON string of the request, which should include the query, and may include
      filters, knowledge bank ID, and file path.

    Key Steps:
    1. Parse the JSON request.
    2. Extract request details (query, history, filters, user profile, ...).
    3. Stream the llm answer: text chunks and progress steps are forwarded with `send`.
    4. On the final (dict, non-step) chunk, record the exchange in the logging database,
       emit `log_query_complete` and disconnect the client.

    Error handling: any exception sends `LLM_API_ERROR` to the client and, when the request
    was parsed and the exchange was not yet recorded, records it with that error as the answer.
    """

    def finalize_answer(request_info: ExtractedQueryInfo):
        """
        Records the exchange in the Logging DB, emits `log_query_complete`, then disconnects the client.

        Parameters:
        - request_info (ExtractedQueryInfo): The request information including the answer.
        """
        record_info = record_conversation(user=get_auth_user(), info=request_info)
        if not record_info:
            # Not necessarily inside an `except` block here: `logger.exception` would append a
            # bogus "NoneType: None" traceback on the normal path.
            logger.error("Sending error message")
            send_error_response()
        else:
            logger.debug("Log query complete")
            new_record = record_info["new_record"]
            emit(
                "log_query_complete",
                return_ok(
                    data={
                        "answer": new_record["answer"],
                        "sources": new_record["sources"],
                        "filters": new_record["filters"],
                        "conversation_infos": record_info["conversation_infos"],
                        "record_id": record_info["record_id"],
                        "query": new_record["query"],
                        "generated_images": record_info["images"],
                        "user_profile": record_info["user_profile"],
                    }
                ),
            )

        disconnect()

    answer = ""
    request_info = None
    # Set once the exchange is recorded and the client disconnected, so the error path below
    # never records the same query a second time.
    finalized = False

    try:
        request_json: Dict[str, Union[str, List[str]]] = json.loads(request)
        logger.debug(f"Payload from the front end is {request_json}")
        request_info = extract_request_info(request_json=request_json)

        # The 'user_profile' key is altered in the 'request_info': save of the initial 'user_profile' to access it later
        initial_user_profile: Dict[str, Any] = deepcopy(request_info.get("user_profile", {}))  # type: ignore
        tagged_text_stream_handler = TaggedTextStreamHandler(["<citation"], ["</citation>"])
        # Loop-invariant: read the webapp setting once per request, not once per streamed chunk.
        enable_llm_citations = dataiku_api.webapp_config.get("enable_llm_citations")

        llm_qa = LLM_Question_Answering(llm_setup.get_llm())

        for chunk in llm_qa.get_answer_and_sources(
            query=request_info["query"],
            conversation_type=request_info["conversation_type"],
            chat_history=request_info["history"],
            filters=request_info["filters"],
            chain_type=request_info["chain_type"],
            media_summaries=request_info["media_summaries"],
            previous_media_summaries=request_info["previous_media_summaries"],
            retrieval_enabled=request_info["retrieval_enabled"],
            user_profile=request_info["user_profile"],
        ):
            if isinstance(chunk, DSSLLMStreamedCompletionChunk):
                # Streaming case: forward each text delta as soon as it arrives.
                text_chunk = chunk.data.get("text", "")
                # Only look for citations if they are enabled
                processed_text = (
                    tagged_text_stream_handler.process_text_chunk(text_chunk) if enable_llm_citations else text_chunk
                )
                answer += processed_text
                send(return_ok(data={"answer": processed_text}))

            elif isinstance(chunk, dict):
                if chunk.get("step"):
                    # Progress notification; steps without a UI label are not forwarded.
                    step_label = _ui_step_label(chunk.get("step"))
                    if step_label is not None:
                        send(return_ok(data={"step": step_label}))
                else:
                    # Final chunk: forward the summary, then record the exchange and disconnect.
                    text_chunk = chunk.get("answer")
                    if text_chunk:
                        answer += text_chunk
                    llm_context: LLMContext = chunk.get("llm_context", {})  # type: ignore
                    logger.debug(f"llm_context: {llm_context}")
                    filtered_context = filter_llm_context_for_ui(llm_context)
                    send(
                        return_ok(
                            data={
                                "answer": text_chunk,
                                "sources": chunk.get("sources"),
                                "filters": chunk.get("filters"),
                                "llm_context": filtered_context,
                                "generated_images": chunk.get("generated_images", None),
                                "user_profile": chunk.get("user_profile", None),
                            }
                        )
                    )
                    typed_chunk: RetrievalSummaryJson = chunk  # type: ignore
                    request_info["answer"] = answer
                    request_info["sources"] = typed_chunk.get("sources", [])  # type: ignore
                    request_info["filters"] = typed_chunk.get("filters")
                    request_info["llm_context"] = typed_chunk.get("llm_context")
                    request_info["generated_images"] = typed_chunk.get("generated_images", None)
                    if edited_user_profile := typed_chunk.get("user_profile"):
                        if isinstance(edited_user_profile, dict):
                            initial_user_profile.update(edited_user_profile)
                    request_info["user_profile"] = flatten_user_profile_values(initial_user_profile)
                    finalize_answer(request_info)
                    finalized = True
    except Exception as e:
        # Handle exceptions gracefully
        logger.exception(f"Service Got exception:{e}")
        if finalized:
            # Already recorded and disconnected: recording again would duplicate the query.
            return
        send(return_ok(data={"answer": LLM_API_ERROR}))
        if request_info:
            request_info["answer"] = LLM_API_ERROR
            try:
                finalize_answer(request_info)
            except Exception as finalize_error:
                # Don't let a failure while recording the error escape the socket handler.
                logger.exception(f"Failed to record the failed query:{finalize_error}")
                send_error_response()
        else:
            send_error_response()


def process_log_query(request: str):
    """Record a query/answer pair sent by the client on the ``log_query`` event.

    The payload is parsed like a chat request (via ``extract_request_info``) and stored with
    ``record_conversation``. Nothing is emitted on success. On any failure (an exception
    or a falsy result from ``record_conversation``), the client receives the generic error
    from ``send_error_response``.

    Parameters:
    - request (str): The JSON string of the request.
    """
    try:
        logger.debug("Client sent a message to log")
        request_json: Dict[str, Union[str, List[str]]] = json.loads(request)
        logger.debug(f"Payload from the front end is {request_json}")
        info = extract_request_info(request_json=request_json)
        record_info = record_conversation(user=get_auth_user(), info=info)
    except Exception as e:
        logger.exception(f"Failed to log query:{e}")
        send_error_response()
        return
    if not record_info:
        # Same failure signal the streaming path checks in `finalize_answer`: report it
        # instead of silently dropping the unrecorded query.
        logger.error("Failed to log query: the conversation could not be recorded")
        send_error_response()


def setup_socketio_answer_event_handlers(socketio):
    """
    Register the socket.io event handlers used to stream answers to clients.

    Handlers are registered, in order, for: ``connect``, ``message`` (stream an answer),
    ``log_query`` (record an already computed answer), ``disconnect`` and ``connect_error``.
    Every handler except ``connect_error`` first runs ``before_request()``.

    Parameters:
    - socketio: The socket.io server instance to which the event handlers are registered.
    """

    def handle_connect():
        before_request()
        logger.debug("Client connected")

    def handle_message(msg):
        before_request()
        process_and_stream_answer(msg)

    def handle_log_query(request: str):
        before_request()
        process_log_query(request)

    def handle_disconnect():
        before_request()
        logger.info("Client disconnected")

    # NOTE(review): socket.io clients raise `connect_error` locally; confirm the server ever
    # dispatches this event to a handler.
    def handle_error():
        logger.error("Failed to connect to client")

    # Event name -> handler, registered in this exact order.
    event_table = (
        ("connect", handle_connect),
        ("message", handle_message),
        ("log_query", handle_log_query),
        ("disconnect", handle_disconnect),
        ("connect_error", handle_error),
    )
    for event_name, handler in event_table:
        socketio.on(event_name)(handler)
