from typing import Any, Dict, Generator, List, Optional

import dataiku
from common.backend.constants import MEDIA_CONVERSATION_START_TAG
from common.backend.models.base import LLMContext, MediaSummary, RetrievalSummaryJson
from common.backend.utils.context_utils import (
    LLMStepName,
    add_llm_step_trace,
    get_main_trace,
    init_user_trace,
)
from common.backend.utils.dataiku_api import dataiku_api
from common.llm_assist.logging import logger


class MediaQAChain:
    """Chain that turns uploaded media summaries into an initial QA response.

    The chain does not call an LLM itself: it replays the LLM step traces
    stored next to each uploaded document and yields a single
    ``RetrievalSummaryJson`` whose ``llm_context`` describes the documents.
    """

    # Keys copied from every media summary into the LLM context.
    _CONTEXT_FIELDS = (
        "original_file_name",
        "metadata_path",
        "file_path",
        "chain_type",
        "topics",
    )

    def __init__(
        self,
        media_summaries: Optional[List[MediaSummary]] = None,
    ):
        """Store the media summaries this chain will work on.

        Args:
            media_summaries: Summaries of the uploaded media, or ``None``.
        """
        logger.debug(f"Media QA Chain initialized {media_summaries}", log_conv_id=True)
        self._media_summaries: Optional[List[MediaSummary]] = media_summaries

    @property
    def media_summaries(self) -> Optional[List[MediaSummary]]:
        """Media summaries given at construction time (may be ``None``)."""
        return self._media_summaries

    @property
    def chain_purpose(self) -> str:
        """Value of the LLM step name identifying this chain."""
        return LLMStepName.START_MEDIA_QA_CONVERSATION.value

    @classmethod
    def _summaries_as_context(cls, summaries: List[MediaSummary]) -> List[Dict[str, Any]]:
        """Project each summary onto the keys listed in ``_CONTEXT_FIELDS``.

        Keys missing from a summary are kept with a ``None`` value, since
        they are read with ``summary.get``.
        """
        return [{field: summary.get(field) for field in cls._CONTEXT_FIELDS} for summary in summaries]

    def get_as_json(
        self,
        user_profile: Optional[Dict[str, Any]] = None,
        is_new: Optional[bool] = True,
    ) -> RetrievalSummaryJson:
        """Build the response payload describing the uploaded media.

        Args:
            user_profile: Added to the LLM context as ``"user_profile"`` when
                truthy.
            is_new: When truthy, the documents are stored under
                ``"media_qa_context"`` (conversation started with documents);
                otherwise under ``"uploaded_docs"`` (documents uploaded during
                an existing conversation).

        Returns:
            A ``RetrievalSummaryJson`` with an empty answer, no sources, no
            filters, no knowledge bank selection and the LLM context built
            here. No document entry is added when there are no media
            summaries.
        """
        llm_context: LLMContext = {}  # type: ignore
        if user_profile:
            llm_context["user_profile"] = user_profile
        if self.media_summaries:
            docs_context = self._summaries_as_context(self.media_summaries)
            if is_new:
                # This is for the case when the user starts a new conversation with docs
                llm_context["media_qa_context"] = docs_context
            else:
                # This is for the case when the user uploads docs during an existing conversation
                # TODO consider removing media_qa_context and just use uploaded_docs
                logger.debug("media_qa_context already exists in llm_context", log_conv_id=True)
                llm_context["uploaded_docs"] = docs_context

        response: RetrievalSummaryJson = {
            "answer": "",
            "sources": [],
            "filters": None,
            "knowledge_bank_selection": [],
            "llm_context": llm_context,
        }
        return response

    def read_llm_steps_from_json(self) -> None:
        """Replay the LLM step traces stored in each document's metadata file.

        For every media summary, the JSON file at ``metadata_path`` is read
        from the folder configured as ``"upload_folder"`` in the webapp
        config. If no main trace exists yet, a user trace is initialised
        with the ``begin_time`` found in that file (0 when absent) and its
        ``query`` attribute is set to ``MEDIA_CONVERSATION_START_TAG``. Any
        non-empty ``"trace"`` entry is then added as an LLM step trace.

        Raises:
            Exception: If no upload folder is configured, if the media
                summaries are ``None``, or if a summary has no
                ``metadata_path``.
        """
        folder_id: Optional[str] = dataiku_api.webapp_config.get("upload_folder")
        if folder_id is None:
            raise Exception("An upload document folder is required. Please add in the edit tab.")
        folder = dataiku.Folder(folder_id)
        if self.media_summaries is None:
            raise Exception("Expected media summaries but none is present.")
        for summary in self.media_summaries:
            metadata_path = summary.get("metadata_path")  # type: ignore
            if not metadata_path:
                logger.error("metadata_path is not provided for document", log_conv_id=True)
                raise Exception("metadata_path is not provided for document")
            extract_summary = folder.read_json(metadata_path)
            trace: List[Dict[str, Any]] = extract_summary.get("trace", [])
            begin_time = extract_summary.get("begin_time", 0)
            if not get_main_trace():
                # Only the first document without an existing main trace
                # creates one; later iterations reuse it.
                init_user_trace(LLMStepName.DKU_ANSWERS_UPLOAD_FILE.name, begin_time)
                # NOTE(review): assumes init_user_trace makes get_main_trace()
                # return a trace object - confirm against context_utils.
                main_trace = get_main_trace()
                main_trace.attributes["query"] = MEDIA_CONVERSATION_START_TAG
            if trace:
                add_llm_step_trace(trace)

    def start_media_qa_chain(
        self,
        user_profile: Optional[Dict[str, Any]] = None,
        is_new: Optional[bool] = True,
    ) -> Generator[RetrievalSummaryJson, Any, None]:
        """Replay stored traces, then yield the single response payload.

        Args:
            user_profile: Forwarded to ``get_as_json``.
            is_new: Forwarded to ``get_as_json``.

        Yields:
            The ``RetrievalSummaryJson`` built by ``get_as_json``.

        Raises:
            Exception: Propagated from ``read_llm_steps_from_json``; being a
                generator, this is raised on first iteration, not on call.
        """
        self.read_llm_steps_from_json()
        logger.debug("Running media qa chain", log_conv_id=True)
        yield self.get_as_json(user_profile, is_new)