from typing import Dict, List, Optional, cast

from common.backend.constants import KEYS_TO_REMOVE_FROM_LOGS
from common.backend.models.base import (
    ConversationType,
    ExtractedQueryInfo,
    GeneratedMedia,
    LLMContext,
    MediaSummary,
    QuestionData,
)
from common.backend.utils.auth_utils import get_auth_user
from common.backend.utils.context_utils import get_main_trace_dict
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.image_generation_utils import handle_images
from common.backend.utils.json_utils import mask_keys
from common.backend.utils.rag_utils import rm_image_data_from_sources_to_store
from common.llm_assist.logging import logger
from common.solutions.chains.title.media_conversation_title import SummaryTitler
from portal.backend.db.conversations import conversation_sql_manager
from portal.backend.db.user_profile import user_profile_sql_manager
from portal.backend.models import PortalSource
from portal.solutions.service import MEDIA_CONVERSATION_START_TAG, LLM_Question_Answering


def _generate_conversation_name(
    conversation_type: Optional[str],
    raw_query,
    answer,
    user_profile,
    llm_context: LLMContext,
) -> Optional[str]:
    """Generate a title for a brand-new conversation, dispatching on its type.

    Parameters:
    - conversation_type: the conversation's type; only GENERAL and MEDIA_QA are supported.
    - raw_query: the query exactly as received (not stripped of the media start tag).
    - answer: the LLM answer for this turn.
    - user_profile: the caller's user profile, or None.
    - llm_context: the LLM context; its "media_qa_context" entry is used for MEDIA_QA titles.

    Returns:
    - The generated title.

    Raises:
    - Exception: if the conversation type is neither GENERAL nor MEDIA_QA.
    """
    if conversation_type == ConversationType.GENERAL:
        return LLM_Question_Answering.get_conversation_title(raw_query, answer, user_profile)
    if conversation_type == ConversationType.MEDIA_QA:
        media_qa_context: List[MediaSummary] = llm_context.get("media_qa_context", [])
        return SummaryTitler().generate_summary_title(media_qa_context)
    raise Exception("Unknown Conversation type.")


def _strip_image_data_from_sources(raw_sources) -> List[PortalSource]:
    """Return the sources with image data removed from each nested item list, for storage.

    NOTE: the nested "items" lists are replaced IN PLACE on the given source objects, so the
    caller's sources are modified too. The returned list holds those same source objects.
    """
    sources: List[PortalSource] = []
    for source in cast(List[PortalSource], raw_sources):
        for item in source.get("items") or []:
            if item.get("items"):
                item["items"] = rm_image_data_from_sources_to_store(item["items"])  # type: ignore
        sources.append(source)
    return sources


def record_conversation(info: ExtractedQueryInfo):
    """
    Persist one question/answer turn of a conversation and return the stored data.

    For a new conversation, a title is generated first. GENERAL conversations get an
    LLM-generated title. MEDIA_QA conversations get a title summarised from the media
    context. Any generated images are processed via ``handle_images``. Image data is
    stripped from the sources before the record is stored through
    ``conversation_sql_manager.add_record``.

    Parameters:
    - info (ExtractedQueryInfo): the request payload. It is read for "query", "answer",
      "conversation_name", "conversation_id", "is_new_conversation", "conversation_type",
      "sources", "history" and, optionally, "llm_context", "user_profile" and
      "generated_images".

    Returns:
    - A dict with keys "record_id", "conversation_infos", "images", "user_profile" and
      "new_record" on success.
    - None if any step raises. The exception is logged via ``logger.exception`` and not
      propagated.

    Side effects:
    - When ``info["llm_context"]`` is truthy, a "trace" entry is written into that same
      dict.
    - Nested "items" lists inside ``info["sources"]`` are replaced in place with
      image-stripped versions.
    """
    try:
        auth_identifier = get_auth_user()
        logger.debug(f"auth_identifier: {auth_identifier}")
        # The media start tag is a sentinel rather than user text, so it is stored as an empty query.
        query = info["query"] if info["query"] != MEDIA_CONVERSATION_START_TAG else ""
        answer = info["answer"]
        conversation_name: Optional[str] = info["conversation_name"]
        conversation_id: Optional[str] = info["conversation_id"]
        is_new_conversation: Optional[bool] = info["is_new_conversation"]
        conversation_type: Optional[str] = info["conversation_type"]
        llm_context: LLMContext = info.get("llm_context") or {}
        logger.debug(f"llm_context: {mask_keys(llm_context, KEYS_TO_REMOVE_FROM_LOGS)}")

        if is_new_conversation:
            conversation_name = _generate_conversation_name(
                conversation_type,
                info["query"],
                answer,
                info.get("user_profile", None),
                llm_context,
            )

        llm_context["trace"] = get_main_trace_dict(query, answer)

        # llm_id may be absent from the webapp config; None is then forwarded to add_record as-is.
        config: Dict[str, str] = dataiku_api.webapp_config
        llm_id = config.get("llm_id", None)

        images: List[GeneratedMedia] = []
        user_profile = info.get("user_profile", {})
        generated_images = info.get("generated_images", [])
        if generated_images:
            images, user_profile = handle_images(
                generated_images, auth_identifier, user_profile, user_profile_sql_manager
            )

        sources = _strip_image_data_from_sources(info["sources"])

        # The record id is the turn's position in the existing history.
        new_history_record = QuestionData(
            id=str(len(info["history"])),
            query=query,
            answer=answer,
            sources=sources,  # type: ignore
            feedback=None,
            llm_context=llm_context,
            generated_media={"images": images},
        )
        record_id, conversation_infos = conversation_sql_manager.add_record(
            dict(new_history_record),
            auth_identifier,
            conversation_id,
            conversation_name,
            is_new_conversation,
            llm_id,
        )
        return {
            "record_id": record_id,
            "conversation_infos": conversation_infos,
            "images": images,
            "user_profile": user_profile,
            "new_record": new_history_record,
        }

    except Exception as e:
        logger.exception(f"Error storing conversation in DB {e}")
        return None
