import time
from types import SimpleNamespace
from typing import Any, Dict, List, Optional

from common.backend.constants import CONVERSATION_DEFAULT_NAME, MEDIA_CONVERSATION_START_TAG
from common.backend.models.base import ConversationType, LlmHistory, MediaSummary
from common.backend.utils.context_utils import LLMStepName, get_main_trace, init_user_trace
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.llm_utils import (
    get_alternative_llm,
    get_llm_capabilities,
    get_llm_completion,
    handle_response_trace,
)
from common.backend.utils.sql_timing import log_query_time
from common.llm_assist.logging import logger
from common.solutions.chains.docs.media_qa_chain import MediaQAChain
from common.solutions.chains.image_generation.image_generation_chain import ImageGenerationChain
from common.solutions.chains.image_generation.image_generation_decision_chain import ImageGenerationDecisionChain
from common.solutions.prompts.conversation_title import CONVERSATION_TITLE_PROMPT, TITLE_USER_PROMPT
from dataiku.langchain.dku_llm import DKULLM
from dataikuapi.dss.llm import DSSLLMCompletionQuery, DSSLLMCompletionResponse
from portal.backend.constants import AUG_RAG_TYPE
from portal.backend.db.conversations import CONVERSATION_DEFAULT_NAME
from portal.backend.models import AgentDetails, PortalConversationParams
from portal.backend.utils.agents_utils import filter_agents_per_user, map_agents_id_name, map_aug_llms_id_name
from portal.backend.utils.answers_utils import filter_answers_per_user, get_answers_info
from portal.solutions.agents_orchestrator import AgentsOrchestrator


class LLM_Question_Answering:
    """Facilitate the question-answering process using an LLM.

    A query is routed to one of three flows:

    * image generation, when the LLM supports it and a decision chain asks for it;
    * a media QA conversation, when the query is ``MEDIA_CONVERSATION_START_TAG``;
    * the agents orchestrator, run over the agents, Answers webapps and
      retrieval-augmented LLMs that the current user may access.
    """

    def __init__(self, llm: DKULLM):
        """Store the main LLM and load the decision LLM and the webapp config.

        Args:
            llm: Main LLM, passed to the agents orchestrator and the image-generation chain.
        """
        self.project = dataiku_api.default_project
        self.llm = llm
        # Secondary LLM passed to the image-generation decision chain and the agents orchestrator.
        self.decision_llm = get_alternative_llm("json_decision_llm_id")
        self.webapp_config = dataiku_api.webapp_config
        # Recomputed on every call to get_answer_and_sources.
        self.chat_has_media = False

    def get_available_models_by_type(
        self,
        user,
        selected_agents,
        selected_agents_description,
        agents_ids_names,
        type="agent",
    ) -> List[AgentDetails]:
        """Return the configured models of one kind that ``user`` may access.

        Args:
            user: User whose permissions are checked with ``filter_agents_per_user``.
            selected_agents: Model ids selected in the webapp config. The first
                ``:``-separated segment is treated as the project key.
            selected_agents_description: Config entries pairing an id with a description.
                Keys are ``augmented_llm``/``augmented_llm_description`` when ``type`` is
                ``AUG_RAG_TYPE``, otherwise ``agent_id``/``agent_description``.
            agents_ids_names: Mapping from full model id to display name.
            type: ``AUG_RAG_TYPE`` for retrieval-augmented LLMs; any other value means agents.

        Returns:
            One ``AgentDetails`` dict per accessible model that also has a description
            entry. Returns an empty list when either config value is empty.
        """
        available_agents: List[AgentDetails] = []
        # TODO maybe base it only on the descriptions
        if not (selected_agents and selected_agents_description):
            return available_agents

        is_aug_llm = type == AUG_RAG_TYPE
        agent_id_key = "augmented_llm" if is_aug_llm else "agent_id"
        agent_desc_key = "augmented_llm_description" if is_aug_llm else "agent_description"
        # Skip entries without an id: when a user removes an agent from the list, its id
        # is removed but its description stays in the config.
        selected_agents_map = {
            entry[agent_id_key]: entry[agent_desc_key]
            for entry in selected_agents_description
            if entry.get(agent_id_key)
        }
        user_accessible_agents = filter_agents_per_user(user, selected_agents)

        models_type = "retrieval augmented" if is_aug_llm else "agents"
        logger.debug(f"User accessible {models_type}: {user_accessible_agents}", log_conv_id=True)
        for model_id in user_accessible_agents:
            if model_id not in selected_agents_map:
                continue
            id_parts = model_id.split(":")
            available_agents.append(
                {
                    # Drop the leading project key and keep the next two id segments.
                    "agent_id": ":".join(id_parts[1:3]) if len(id_parts) >= 3 else "invalid_agent_id",
                    # NOTE(review): raises KeyError if the id is missing from agents_ids_names
                    # (e.g. no projects_keys configured) — confirm this cannot happen.
                    "name": agents_ids_names[model_id],
                    "description": selected_agents_map.get(model_id),
                    "project_key": id_parts[0],
                }
            )
        logger.debug(f"Available {models_type}: {available_agents}", log_conv_id=True)
        return available_agents

    def get_available_aug_llms(self, user) -> List[AgentDetails]:
        """Return the retrieval-augmented LLMs from the webapp config that ``user`` may access."""
        selected_projects = self.webapp_config.get("projects_keys")
        selected_aug_llms = self.webapp_config.get("augmented_llms")
        selected_aug_llms_description = self.webapp_config.get("aug_llms_descriptions")
        aug_llms_ids_names = map_aug_llms_id_name(selected_projects) if selected_projects and selected_aug_llms else {}
        return self.get_available_models_by_type(
            user=user,
            type=AUG_RAG_TYPE,
            selected_agents=selected_aug_llms,
            selected_agents_description=selected_aug_llms_description,
            agents_ids_names=aug_llms_ids_names,
        )

    def get_available_agents(self, user) -> List[AgentDetails]:
        """Return the agents from the webapp config that ``user`` may access."""
        selected_projects = self.webapp_config.get("projects_keys")
        selected_agents = self.webapp_config.get("agents_ids")
        selected_agents_description = self.webapp_config.get("agents_descriptions")
        agents_ids_names = map_agents_id_name(selected_projects) if selected_projects and selected_agents else {}
        return self.get_available_models_by_type(
            user=user,
            selected_agents=selected_agents,
            selected_agents_description=selected_agents_description,
            agents_ids_names=agents_ids_names,
        )

    def get_available_answers(self, user, bypass_permission_check) -> List[AgentDetails]:
        """Return the configured Answers webapps that ``user`` may access.

        Args:
            user: User whose permissions are checked with ``filter_answers_per_user``.
            bypass_permission_check: When truthy, every configured Answers webapp is
                treated as accessible and the permission filter is skipped.

        Returns:
            One ``AgentDetails`` dict per accessible webapp. Returns an empty list when
            no ``answers_ids`` are configured.
        """
        available_answers: List[AgentDetails] = []
        selected_answers = self.webapp_config.get("answers_ids")
        selected_projects = self.webapp_config.get("projects_keys")

        if selected_answers:
            user_accessible_answers = (
                selected_answers if bypass_permission_check else filter_answers_per_user(user, selected_answers)
            )
            logger.debug(f"User accessible answers: {user_accessible_answers}", log_conv_id=True)
            available_answers = [
                {
                    "agent_id": answer_id,
                    "name": answer["webapp_name"],
                    "description": answer["webapp_description"],
                    "project_key": answer.get("project_key"),
                }
                for answer_id, answer in get_answers_info(selected_projects).items()
                if answer_id in user_accessible_answers
            ]
            logger.debug(f"Available answers: {available_answers}", log_conv_id=True)
        return available_answers

    def prepare_agents(self, user) -> List[AgentDetails]:
        """Return the Answers webapps, agents and augmented LLMs available to ``user``, in that order."""
        logger.debug(f"answers ids config: {self.webapp_config.get('answers_ids')}")
        # Only bypasses the Answers webapps checks, which require admin rights.
        bypass_permission_check = self.webapp_config.get("bypass_permission_check", False)
        if bypass_permission_check:
            logger.info("Bypassing permission checks for Answers webapps")
        available_answers: List[AgentDetails] = self.get_available_answers(user, bypass_permission_check)
        available_agents: List[AgentDetails] = self.get_available_agents(user)
        available_aug_llms: List[AgentDetails] = self.get_available_aug_llms(user)
        return available_answers + available_agents + available_aug_llms

    def run_with_agents(self, conversation_params: PortalConversationParams, user: str):
        """Yield the agents orchestrator's output for ``conversation_params``.

        When at least one agent is available, a progress step is yielded first.
        """
        available_agents: List[AgentDetails] = self.prepare_agents(user)
        if available_agents:
            yield {"step": SimpleNamespace(name="Assessing available agents before proceeding...")}
        agents_orchestrator = AgentsOrchestrator(self.llm, self.decision_llm, available_agents, user)
        yield from agents_orchestrator.run(conversation_params)

    def get_answer_and_sources(  # noqa: PLR0917 too many positional arguments
        self,
        query: str,
        conversation_type: ConversationType,
        chat_history: Optional[List[LlmHistory]] = None,
        chain_type: Optional[str] = None,
        media_summaries: Optional[List[MediaSummary]] = None,
        previous_media_summaries: Optional[List[MediaSummary]] = None,
        user_profile: Optional[Dict[str, Any]] = None,
        user: Optional[str] = None,
        app_id: Optional[str] = None,
        user_agent: Optional[str] = None,
        agents_files_uploads: Optional[Dict[str, Dict[str, MediaSummary]]] = None,
        conversation_id: Optional[str] = None,
    ) -> Any:
        """Produce the response for one user query.

        The flow is chosen as follows:

        1. If the LLM supports image generation and ``query`` is not
           ``MEDIA_CONVERSATION_START_TAG``, a decision chain decides whether to generate
           an image. If it does, the image-generation result is returned. Otherwise the
           ``media`` and ``generated_media_info`` entries are removed from ``user_profile``
           (in place) before continuing.
        2. If ``query`` is ``MEDIA_CONVERSATION_START_TAG``, a media QA conversation is
           started from ``media_summaries``.
        3. Otherwise the agents orchestrator runs, and the generator returned by
           ``run_with_agents`` is returned.

        Args:
            query: The user query, or ``MEDIA_CONVERSATION_START_TAG`` to start a media QA conversation.
            conversation_type: Only used for logging here.
            chat_history: Previous exchanges. ``None`` (the default) means an empty history.
            chain_type: Forwarded in the conversation params.
            media_summaries: Media of the current message. Used to start a media QA conversation.
            previous_media_summaries: Forwarded in the conversation params.
            user_profile: User profile dict. May be mutated in place (see step 1).
            user: User name for the orchestrator. ``"unknown"`` is used when falsy.
            app_id: Forwarded in the conversation params.
            user_agent: Forwarded in the conversation params.
            agents_files_uploads: Forwarded in the conversation params.
            conversation_id: Forwarded in the conversation params.

        Raises:
            Exception: If the main trace is not initialized.
            ValueError: If an image is requested but ``image_generation_llm_id`` is not configured.
        """
        # Fresh list per call: a shared mutable default would be reused across calls
        # and is handed on to downstream chains through conversation_params.
        if chat_history is None:
            chat_history = []

        logger.debug("Time ===>: starting tracking time: Generating response", log_conv_id=True)
        if query != MEDIA_CONVERSATION_START_TAG:
            init_user_trace(LLMStepName.DKU_AGENT_CONNECT_QUERY.name)
            if not (main_trace := get_main_trace()):
                raise Exception("Main trace is not initialized correctly.")
            main_trace.attributes["query"] = query

        conversation_params: PortalConversationParams = {
            "user_query": query,
            "chat_history": chat_history,
            "chain_type": chain_type,
            "media_summaries": media_summaries,
            "previous_media_summaries": previous_media_summaries,
            "global_start_time": time.time(),  # Capture start time
            "justification": "",
            "user_profile": user_profile,
            "self_service_decision": None,
            "retrieval_enabled": False,
            "app_id": app_id,
            "user_agent": user_agent,
            "agents_files_uploads": agents_files_uploads,
            "conversation_id": conversation_id,
        }
        logger.debug(f"Conversation type is set to: {conversation_type}, chain type is set to: {chain_type}", log_conv_id=True)
        logger.debug(f"Selected agents description: {self.webapp_config.get('agents_descriptions')}", log_conv_id=True)

        self.chat_has_media = any(item.get("output", "") == "generated_media_by_ai" for item in chat_history)
        llm_capabilities = get_llm_capabilities()
        # Image generation should not run when a media conversation starts: there is no user query.
        if llm_capabilities.get("image_generation", False) and query != MEDIA_CONVERSATION_START_TAG:
            logger.debug("Image generation is enabled")
            img_system_prompt = self.webapp_config.get("image_generation_system_prompt")
            img_gen_decision_chain = ImageGenerationDecisionChain(
                llm=self.decision_llm, system_prompt=img_system_prompt
            )
            decision_output = img_gen_decision_chain.get_decision_as_json(user_query=query, chat_history=chat_history)
            logger.debug(f"Image generation decision chain output: {decision_output}", log_conv_id=True)
            generate_image = decision_output.get("decision")
            generated_query = decision_output.get("query")
            referred_image = decision_output.get("referred_image")
            if generate_image:
                image_generation_llm: Optional[DKULLM] = None
                if image_generation_llm_id := dataiku_api.webapp_config.get("image_generation_llm_id", None):
                    image_generation_llm = DKULLM(llm_id=image_generation_llm_id)
                else:
                    error_message = "Image generation LLM ID is not set"
                    logger.error(error_message, log_conv_id=True)
                    raise ValueError(error_message)
                if not (main_trace := get_main_trace()):
                    error_message = "Main trace is not initialized correctly"
                    logger.error(error_message, log_conv_id=True)
                    raise Exception(error_message)
                sub_trace = main_trace.subspan(LLMStepName.IMAGE_GENERATION.name)
                max_images_to_generate = int(dataiku_api.webapp_config.get("max_images_per_user_per_week", 0))
                # Imported lazily — presumably to avoid a circular import or early DB setup; verify.
                from portal.backend.db.user_profile import user_profile_sql_manager
                include_user_profile_in_prompt = bool(self.webapp_config.get("include_user_profile_in_prompt", False))

                return ImageGenerationChain(
                    llm=self.llm,
                    image_generation_llm=image_generation_llm,
                    user_query=generated_query,
                    referred_image=referred_image,
                    user_profile_sql_manager=user_profile_sql_manager,
                    user_profile=user_profile,
                    trace=sub_trace,
                    include_user_profile_in_prompt=include_user_profile_in_prompt,
                ).run_image_generation_query(max_images_to_generate)
            # No image needed: continue with the normal flow and ignore media in the user profile.
            # The profile dict is shared with conversation_params, so the removal is seen downstream.
            if user_profile and user_profile.get("media"):
                user_profile.pop("media", None)
            if user_profile and user_profile.get("generated_media_info"):
                user_profile.pop("generated_media_info", None)

        if query == MEDIA_CONVERSATION_START_TAG:
            logger.debug("Starting media QA conversation", log_conv_id=True)
            response = MediaQAChain(media_summaries).start_media_qa_chain(user_profile, not chat_history)
            logger.debug(f"Media QA response: {response}", log_conv_id=True)
            return response
        return self.run_with_agents(conversation_params, user or "unknown")

    @staticmethod
    @log_query_time
    def get_conversation_title(query: str, answer: str, user_profile: Optional[Dict[str, Any]] = None) -> str:
        """Generate a conversation title with the ``title_llm_id`` LLM.

        Args:
            query: The user query, injected into the system prompt.
            answer: The generated answer, injected into the user prompt.
            user_profile: Optional user profile, injected into the system prompt.

        Returns:
            The LLM text, or ``CONVERSATION_DEFAULT_NAME`` when the LLM returns no text
            or the call raises (the exception is logged, not propagated).
        """
        title_llm: DKULLM = get_alternative_llm("title_llm_id")
        system_prompt = CONVERSATION_TITLE_PROMPT.format(query=query, user_profile=user_profile)
        user_prompt = TITLE_USER_PROMPT.format(generated_content=answer)
        completion: DSSLLMCompletionQuery = get_llm_completion(title_llm)
        completion.with_message(system_prompt, role="system")
        completion.with_message(user_prompt, role="user")
        try:
            resp: DSSLLMCompletionResponse = completion.execute()
            handle_response_trace(resp)
            conversation_title = str(resp.text) if resp.text else CONVERSATION_DEFAULT_NAME
            # The response may carry an error message; log it but still return the title computed above.
            if error_message := resp._raw.get("errorMessage"):
                logger.error(error_message, log_conv_id=True)
            return conversation_title
        except Exception as e:
            logger.exception(f"Error when calling LLM API: {e}.", log_conv_id=True)
            return CONVERSATION_DEFAULT_NAME
