from concurrent.futures import ThreadPoolExecutor, as_completed
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple, Union

from common.backend.constants import KEYS_TO_REMOVE_FROM_LOGS
from common.backend.models.base import LlmHistory, LLMStep, MediaSummary, RetrievalSummaryJson
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.json_utils import mask_keys
from common.backend.utils.llm_utils import append_summaries_to_completion_msg, handle_response_trace
from common.backend.utils.prompt_utils import append_user_profile_to_prompt
from common.llm_assist.logging import logger
from common.solutions.chains.generic_answers_chain import GenericAnswersChain
from common.solutions.chains.no_retrieval_chain import NoRetrievalChain
from dataiku.langchain.dku_llm import DKULLM
from dataikuapi.dss.llm import (
    DSSLLMCompletionQuery,
    DSSLLMCompletionQueryMultipartMessage,
    DSSLLMStreamedCompletionChunk,
    DSSLLMStreamedCompletionFooter,
)
from portal.backend.constants import AGENT_ID, AUG_LLM_ID
from portal.backend.models import (
    AgentDetails,
    AgentGeneratedAnswer,
    AgentQuery,
    AgentsSelection,
    PortalConversationParams,
    PortalLLMContext,
    PortalSource,
    ToolCall,
)
from portal.backend.utils.agents_utils import add_agent_uploads
from portal.backend.utils.llm_utils import get_llm_completion
from portal.backend.utils.sources_utils import process_agents_sources, process_answers_sources
from portal.solutions.chains.agent_query_builder import AgentQueryBuilder
from portal.solutions.chains.query_resolver_chain import QueryResolverChain


def is_agent(agent_id):
    """Return True when ``agent_id`` starts with the ``AGENT_ID`` or ``AUG_LLM_ID`` prefix."""
    # str.startswith accepts a tuple and succeeds if any of the prefixes matches.
    recognised_prefixes = (AGENT_ID, AUG_LLM_ID)
    return agent_id.startswith(recognised_prefixes)

class AgentsOrchestrator:
    def __init__(
        self, llm: DKULLM, decision_llm: DKULLM, available_agents: List[AgentDetails], user: str, chat_has_media=False
    ):
        """Store the collaborators the orchestrator needs for a conversation turn.

        Args:
            llm: LLM handed to the default (no retrieval) chain and to the
                chain that synthesizes the agents' answers.
            decision_llm: LLM handed to the agent query builder.
            available_agents: Agents that may be selected for a query.
            user: Identifier forwarded to every agent completion.
            chat_has_media: Flag forwarded to the answers synthesis chain.
        """
        # LLM handles
        self.decision_llm = decision_llm
        self.llm = llm
        # Conversation context
        self.user = user
        self.chat_has_media = chat_has_media
        self.available_agents = available_agents

    def prepare_agents_queries(
        self,
        user_query: str,
        chat_history: Optional[List[LlmHistory]] = None,
        user_profile: Optional[Dict[str, Any]] = None,
    ) -> Tuple[List[AgentQuery], str, bool]:
        """Ask the decision LLM which agents to call and with which queries.

        Args:
            user_query: The raw question typed by the user.
            chat_history: Previous exchanges of the conversation. ``None``
                (the default) is treated as an empty history.
            user_profile: Optional user profile forwarded to the query builder.

        Returns:
            A ``(queries, justification, invalid_agent_id_generated)`` tuple read
            from the query builder decision. ``queries`` is an empty list when the
            decision carries none; ``justification`` is returned as found in the
            decision; the flag defaults to ``False`` when absent.
        """
        # A fresh list per call: a ``[]`` default in the signature would be one
        # object shared by every call of this method.
        history: List[LlmHistory] = [] if chat_history is None else chat_history
        queries_builder_chain = AgentQueryBuilder(
            llm=self.decision_llm,
            chat_history=history,
            agents=self.available_agents,
            user_profile=user_profile,
            # NOTE(review): hard-coded to False instead of self.chat_has_media —
            # confirm this is intentional.
            chat_has_media=False,
        )
        decision = queries_builder_chain.get_decision_as_json(user_query=user_query, chat_history=history)
        invalid_agent_id_generated = decision.get("invalid_agent_id_generated", False)

        # Honour the declared return type even if the decision has no "queries".
        queries: List[AgentQuery] = decision.get("queries") or []
        justification = decision.get("justification")
        logger.debug(f"queries: {queries}", log_conv_id=True)
        return queries, justification, invalid_agent_id_generated

    def append_history_for_agents_completion(
        self,
        completion: DSSLLMCompletionQuery,
        chat_history: Optional[List[LlmHistory]] = None,
        summaries: Optional[List[MediaSummary]] = None,
    ) -> None:
        """Replay the conversation history (and media summaries) onto ``completion``.

        Args:
            completion: Completion query that receives the messages.
            chat_history: Past exchanges; each one adds a ``user`` message built
                from its ``"input"`` and an ``assistant`` message built from its
                ``"output"``. ``None`` or empty adds nothing.
            summaries: Media summaries. When non-empty, a multipart ``user``
                message is created and handed to
                ``append_summaries_to_completion_msg`` together with them.
        """
        logger.info("Appending history for agents completion", log_conv_id=True)
        # ``None`` defaults replace the former shared ``[]`` default objects.
        for item in chat_history or []:
            completion.with_message(message=item["input"], role="user")
            completion.with_message(message=item["output"], role="assistant")
        if summaries:
            msg: DSSLLMCompletionQueryMultipartMessage = completion.new_multipart_message(role="user")
            append_summaries_to_completion_msg(summaries, msg)

    def run_default_chain(self, conversation_params: PortalConversationParams):
        """Stream an answer from the no-retrieval chain.

        Yields a ``CALLING_DEFAULT_CHAIN`` step marker first, then every item
        produced by ``NoRetrievalChain.run_completion_query``.
        """
        yield {"step": LLMStep.CALLING_DEFAULT_CHAIN}
        profile_setting = dataiku_api.webapp_config.get("include_user_profile_in_prompt", False)
        default_chain = NoRetrievalChain(self.llm, include_user_profile_in_prompt=bool(profile_setting))
        yield from default_chain.run_completion_query(conversation_params)

    def finalize_agent_completion(
        self,
        query: AgentQuery,
        data: Any,
        agent: AgentDetails,
        query_builder_justification: str,
        media_summaries: Optional[List[MediaSummary]],
    ):
        """Build the final summary emitted after a single agent finished streaming.

        Args:
            query: The query sent to the agent; its ``"query"`` value is recorded
                in the agents selection.
            data: Mapping read from the last streamed chunk. The keys
                ``additionalInformation``, ``usedRetrieval``, ``media_summaries``,
                ``answer`` and ``generatedMedia`` are looked up in it.
            agent: The agent that produced ``data``.
            query_builder_justification: Justification stored in the agents selection.
            media_summaries: Files uploaded with the current request, if any.

        Returns:
            A ``RetrievalSummaryJson`` carrying the answer, a single source entry
            for the agent, the generated media and the LLM context.
        """
        items: List[ToolCall] = []
        additional_information = data.get("additionalInformation")
        # ``or {}`` also covers a key that is present but explicitly None, which
        # a plain ``.get(key, {})`` default would not.
        used_retrieval = data.get("usedRetrieval") or {}
        if additional_information and additional_information.get("sources"):
            items = process_agents_sources(additional_information["sources"])  # type: ignore
        elif used_retrieval.get("sources"):
            items = process_answers_sources(used_retrieval["sources"], used_retrieval)  # type: ignore

        # TODO should we handle aug llm differently
        is_dss_agent = is_agent(agent["agent_id"])
        sources: List[PortalSource] = [
            {
                "name": agent["name"],
                "id": agent["agent_id"],
                "type": "agent" if is_dss_agent else "answers webapp",
                "items": items,
                "answer": data,
            }
        ]
        logger.debug(f"Sources: {mask_keys(sources, KEYS_TO_REMOVE_FROM_LOGS)}")

        llm_context: PortalLLMContext = {}
        llm_context["agents_selection"] = AgentsSelection(
            calls=[{"agent_id": agent["agent_id"], "query": query["query"]}],
            justification=query_builder_justification,
        )

        # One empty slot per file uploaded with this request; add_agent_uploads
        # receives it together with the uploads reported by the agent.
        new_agents_uploads: Dict[str, Dict[str, MediaSummary]] = {
            media_summary["original_file_name"]: {} for media_summary in media_summaries or []
        }
        uploads = data.get("media_summaries")
        if uploads:
            llm_context["agents_files_uploads"] = add_agent_uploads(
                new_agents_uploads,
                uploads,
                agent["agent_id"],
            )
        if media_summaries:
            llm_context["uploaded_docs"] = media_summaries

        return RetrievalSummaryJson(
            answer=data.get("answer", ""),
            sources=sources,
            generated_images=data.get("generatedMedia", []),
            filters=[],  # type: ignore
            knowledge_bank_selection=[],
            llm_context=llm_context,
            user_profile={},
        )

    def run_agent_completion(
        self,
        sub_query: AgentQuery,
        query: str,
        conversation_params: PortalConversationParams,
        media_summaries: List[MediaSummary],
        summaries: List[MediaSummary],
        query_builder_justification: str,
        agents_files_uploads: Optional[Dict[str, Dict[str, MediaSummary]]],
        chat_history: List[LlmHistory],
        user_profile: Optional[Dict[str, Any]],
    ):
        """Stream the answer of the single selected agent.

        Yields, in order: a step marker naming the selected agent, every chunk
        produced by the streamed completion, and finally the summary built by
        ``finalize_agent_completion`` from the last chunk. If anything fails
        while creating or running the completion, an error text chunk followed
        by an error footer is yielded instead of raising.

        Args:
            sub_query: Query selected for the agent. Its ``"query"`` entry is
                overwritten in place with ``query``.
            query: The user query, sent as-is to the agent.
            conversation_params: Parameters of the current conversation turn.
            media_summaries: Files uploaded with the current request.
            summaries: Media summaries forwarded to the agent completion.
            query_builder_justification: Justification recorded in the final summary.
            agents_files_uploads: Agent uploads forwarded to the agent completion.
            chat_history: Previous exchanges of the conversation.
            user_profile: Optional user profile forwarded to the agent completion.
        """
        logger.info(f"Running agent completion for {sub_query}", log_conv_id=True)

        agent_id = sub_query["agent_id"]
        is_dss_agent = is_agent(agent_id=agent_id)
        # With a single agent the user query is sent unchanged; the sub-query is
        # updated so the final summary records what was actually asked.
        sub_query["query"] = query
        agent = next(
            (agent for agent in self.available_agents if agent.get("agent_id") == agent_id), None
        ) or AgentDetails(
            {"agent_id": agent_id, "name": "No agent name provided", "description": "", "project_key": ""}
        )
        yield {"step": SimpleNamespace(name=f"Selected agent: [{agent.get('name')}]")}
        chunk = None
        try:
            # Created inside the try so that a failure here (e.g. the ValueError
            # raised for an unknown agent) is reported through the error
            # chunk/footer below instead of escaping the generator.
            completion = self._create_agent_completion(
                agent_id=agent_id,
                agent_query=query,
                user_profile=user_profile,
                chat_history=chat_history,
                conversation_params=conversation_params,
                summaries=summaries,
                agents_files_uploads=agents_files_uploads,
                is_dss_agent=is_dss_agent,
            )
            for chunk in completion.execute_streamed():
                yield chunk
                if isinstance(chunk, DSSLLMStreamedCompletionFooter):
                    handle_response_trace(chunk)
            # ``chunk`` now holds the last streamed item (None if nothing was streamed).
            if chunk is not None:
                yield self.finalize_agent_completion(
                    query=sub_query,
                    data=chunk.data,
                    agent=agent,
                    query_builder_justification=query_builder_justification,
                    media_summaries=media_summaries,
                )
        except Exception as e:
            logger.exception(f"Error in agent completion: {e}", log_conv_id=True)
            message = "Error while processing your request"
            yield DSSLLMStreamedCompletionChunk({"text": message})
            yield DSSLLMStreamedCompletionFooter({"finishReason": "error", "errorMessage": message})

    def run_agents_completion(
        self,
        queries: List[AgentQuery],
        query: str,
        conversation_params: PortalConversationParams,
        media_summaries: List[MediaSummary],
        summaries: List[MediaSummary],
        query_builder_justification: str,
    ):
        """Call every selected agent in parallel, then stream one synthesized answer.

        Yields a step marker listing the selected agents, then (once all agent
        calls are done) a ``SYNTHESIZING_AGENTS_ANSWERS`` step marker followed by
        everything produced by ``QueryResolverChain.run_completion_query``.

        A failing agent call does not abort the others: it is recorded with the
        answer "Error while processing your request".

        Side effects: ``conversation_params["sources"]`` and
        ``conversation_params["media_summaries"]`` are overwritten.

        Args:
            queries: One entry per agent to call.
            query: The user query, used for entries that carry no query of their own.
            conversation_params: Parameters of the current conversation turn.
            media_summaries: Files uploaded with the current request.
            summaries: Media summaries forwarded to each agent completion.
            query_builder_justification: Justification forwarded to the synthesis chain.
        """
        error_answer = "Error while processing your request"
        unknown_agent_name = "No agent name provided"

        chat_history = conversation_params.get("chat_history", [])
        user_profile = conversation_params.get("user_profile", {})
        agents_files_uploads = conversation_params.get("agents_files_uploads")
        generated_answers: List[AgentGeneratedAnswer] = []
        # One empty slot per file uploaded with this request.
        new_agents_uploads: Dict[str, Dict[str, MediaSummary]] = {
            media_summary["original_file_name"]: {} for media_summary in media_summaries or []
        }

        # Index the agents once instead of scanning the list for every query.
        # setdefault keeps the first entry if an id appears twice, like a
        # first-match scan would.
        agents_by_id: Dict[str, AgentDetails] = {}
        for available_agent in self.available_agents:
            agents_by_id.setdefault(available_agent.get("agent_id"), available_agent)

        def call_agent(sub_query: AgentQuery):
            """
            This function will be executed in a separate thread.
            Returns (sub_query, answer_object) so we can handle the result later.
            """
            agent_id = sub_query["agent_id"]
            completion = self._create_agent_completion(
                agent_id=agent_id,
                agent_query=sub_query.get("query") or query,
                user_profile=user_profile,
                chat_history=chat_history,
                conversation_params=conversation_params,
                summaries=summaries,
                agents_files_uploads=agents_files_uploads,
                is_dss_agent=is_agent(agent_id=agent_id),
            )
            # Actually execute the agent call
            answer = completion.execute()
            return (sub_query, answer)

        selected_agent_names = []
        for selected_query in queries:
            selected_agent = agents_by_id.get(selected_query["agent_id"])
            if selected_agent:
                selected_agent_names.append(selected_agent.get("name"))
        yield {"step": SimpleNamespace(name=f"Selected agents: {selected_agent_names}")}

        futures = {}
        # ThreadPoolExecutor raises ValueError for max_workers <= 0, so keep at
        # least one worker even if ``queries`` is empty.
        with ThreadPoolExecutor(max_workers=max(1, len(queries))) as executor:
            for sub_query in queries:
                futures[executor.submit(call_agent, sub_query)] = sub_query
        # Leaving the ``with`` block waits for all workers, so every future
        # below is already finished.
        logger.debug("Waiting for agents workers", log_conv_id=True)
        for future in as_completed(futures):
            sub_query = futures[future]
            agent_id = sub_query["agent_id"]
            is_dss_agent = is_agent(agent_id=agent_id)
            is_answers = agent_id.startswith("answer")
            agent_query = sub_query.get("query") or query
            agent = agents_by_id.get(agent_id)
            agent_name = agent.get("name", unknown_agent_name) if agent else unknown_agent_name

            try:
                _, answer = future.result()  # (sub_query, answer)
            except Exception as e:
                # Handle errors from the agent call
                logger.exception(f"Error calling agent {agent_id}: {e}", log_conv_id=True, exc_info=True)
                generated_answers.append(
                    {
                        "sub_query": agent_query,
                        "agent_answer": error_answer,
                        "agent_id": agent_id,
                        "agent_name": agent_name,
                        "sources": [],
                        "used_tool": None,
                    }
                )
                continue

            # Default to the error answer; only a successful, recognised
            # response shape replaces it.
            generated_answer = error_answer
            sources = []
            used_tool = None
            if answer and answer.success:
                if is_answers and answer.json:
                    answer_json = answer.json
                    logger.debug(
                        f"answer: {mask_keys(answer_json, KEYS_TO_REMOVE_FROM_LOGS)}", log_conv_id=True
                    )  # TODO remove?
                    # ``or {}`` also covers keys that are present but None.
                    answer_data = answer_json.get("data") or {}
                    add_agent_uploads(
                        new_agents_uploads,
                        answer_data.get("media_summaries", []),
                        agent_id,
                    )
                    generated_answer = answer_data.get("answer")
                    used_retrieval = answer_data.get("usedRetrieval") or {}
                    sources = used_retrieval.get("sources") or []
                    handle_response_trace(answer_data)
                    if sources:
                        used_tool = used_retrieval
                elif is_dss_agent and answer._raw:
                    generated_answer = answer.text
                    additional_information = answer._raw.get("additionalInformation")
                    if additional_information and additional_information.get("sources"):
                        sources = additional_information["sources"]
                    handle_response_trace(answer)

            generated_answers.append(
                {
                    "agent_id": agent_id,
                    "agent_name": agent_name,
                    "sub_query": agent_query,
                    "agent_answer": generated_answer,
                    "sources": sources,
                    "used_tool": used_tool,
                }
            )
        logger.debug(f"generated_answers: {mask_keys(generated_answers, KEYS_TO_REMOVE_FROM_LOGS)}", log_conv_id=True)

        conversation_params["sources"] = generated_answers  # type: ignore
        conversation_params["media_summaries"] = media_summaries
        logger.debug(f"agents_files_uploads: {mask_keys(new_agents_uploads, KEYS_TO_REMOVE_FROM_LOGS)}", log_conv_id=True)
        yield {"step": LLMStep.SYNTHESIZING_AGENTS_ANSWERS}
        qa_chain: GenericAnswersChain = QueryResolverChain(
            self.llm,
            generated_queries_answers=generated_answers,
            agents=self.available_agents,
            agents_queries=queries,
            chat_has_media=self.chat_has_media,
            agents_files_uploads=new_agents_uploads,
            uploaded_files=media_summaries,
            query_builder_justification=query_builder_justification,
        )
        yield from qa_chain.run_completion_query(conversation_params)

    def get_agent_customization(self, agent_id: str) -> Optional[Dict[str, str]]:
        """Look up the webapp customization entry whose ``agent_id`` matches.

        Reads ``agents_customization`` from the webapp configuration and returns
        the first entry with the requested ``agent_id``, or ``None`` when the
        setting is empty or no entry matches.
        """
        configured_entries = dataiku_api.webapp_config.get("agents_customization", {})
        if not configured_entries:
            return None

        logger.debug(f"agents_customization: {configured_entries}")
        matching_entry = None
        for entry in configured_entries:
            if entry.get("agent_id") == agent_id:
                matching_entry = entry
                break
        logger.debug(f"agent_customization: {matching_entry}")
        return matching_entry

    def _create_agent_completion(
        self,
        agent_id: str,
        agent_query: str,
        user_profile: Union[Dict[str, Any], None],
        chat_history: List[LlmHistory],
        conversation_params: PortalConversationParams,
        summaries: List[MediaSummary],
        agents_files_uploads: Optional[Dict[str, Dict[str, MediaSummary]]],
        is_dss_agent: bool,
    ) -> DSSLLMCompletionQuery:
        """Build the completion query to send to one agent.

        Common settings (user, user profile, chat history, app id, user agent)
        are always set. For a DSS agent the optional customization prompt and
        user profile are sent as a system message and the history and media
        summaries are replayed as messages; for any other agent the summaries
        and agent uploads are passed through the completion settings instead.
        The agent query is appended last.

        Args:
            agent_id: Identifier of the agent; must be one of ``available_agents``.
            agent_query: Query appended as the last message.
            user_profile: Optional user profile.
            chat_history: Previous exchanges of the conversation.
            conversation_params: Source of ``conversation_id``, ``app_id`` and ``user_agent``.
            summaries: Media summaries to forward to the agent.
            agents_files_uploads: Agent uploads, forwarded to non-DSS agents only.
            is_dss_agent: Selects which of the two branches above applies.

        Returns:
            The configured completion query, not yet executed.

        Raises:
            ValueError: If ``agent_id`` is not among the available agents.
        """
        agent = next((agent for agent in self.available_agents if agent.get("agent_id") == agent_id), None)
        if not agent:
            raise ValueError(f"Agent {agent_id} not found")
        project_key = agent.get("project_key", "")
        conversation_id = conversation_params.get("conversation_id", "")
        completion: DSSLLMCompletionQuery = get_llm_completion(agent_id, self.user, project_key, conversation_id)
        completion.settings["user"] = self.user
        completion.settings["user_profile"] = user_profile
        completion.settings["chat_history"] = chat_history
        completion.settings["app_id"] = conversation_params.get("app_id")
        completion.settings["user_agent"] = conversation_params.get("user_agent")
        if is_dss_agent:
            # DSS agent: customizations are keyed by "<project_key>:<agent_id>".
            agent_customization = self.get_agent_customization(f"{project_key}:{agent_id}")
            additional_prompt = ""
            disable_upload = False
            if agent_customization:
                # ``or ""`` tolerates an ``agent_prompt`` explicitly set to None.
                additional_prompt = (agent_customization.get("agent_prompt") or "") + " "
                disable_upload = agent_customization.get("disable_upload", False)
            logger.debug(f"Agent [{agent_id}] additional prompt: {additional_prompt}", log_conv_id=True)
            system_prompt = append_user_profile_to_prompt(
                system_prompt=additional_prompt,
                user_profile=user_profile,
                include_full_user_profile=dataiku_api.webapp_config.get("include_user_profile_in_prompt", False),
            )
            completion.with_message(message=system_prompt, role="system")
            summaries_to_forward = summaries
            if disable_upload:
                logger.info(f"File upload is disabled for agent {project_key}:{agent_id}", log_conv_id=True)
                summaries_to_forward = []
            self.append_history_for_agents_completion(completion, chat_history, summaries_to_forward)
        else:
            completion.settings["media_summaries"] = summaries
            completion.settings["agents_files_uploads"] = agents_files_uploads
        completion.with_message(agent_query)
        return completion

    def build_agents_queries(
        self,
        query: str,
        max_retries: int = 3,
        user_profile: Optional[Dict[str, Any]] = None,
        chat_history: Optional[List[LlmHistory]] = None,
    ) -> Tuple[List[AgentQuery], str, bool]:
        """Decide which agents to call, retrying while invalid agent ids are generated.

        When exactly one agent is available and the ``skip_decision_chain``
        webapp setting is enabled, the decision chain is bypassed and the user
        query is sent to that agent. Otherwise ``prepare_agents_queries`` is
        called up to ``max_retries`` times, stopping as soon as it reports no
        invalid agent id.

        Args:
            query: The user query.
            max_retries: Maximum number of decision attempts.
            user_profile: Optional user profile forwarded to the decision chain.
            chat_history: Previous exchanges; ``None`` is treated as an empty history.

        Returns:
            A ``(queries, justification, invalid_agent_id_generated)`` tuple. The
            flag is still ``True`` when every attempt produced invalid agent ids
            (or when ``max_retries`` allows no attempt at all).
        """
        invalid_agent_id_generated: bool = True
        queries: List[AgentQuery] = []
        justification: str = ""
        if len(self.available_agents) == 1 and dataiku_api.webapp_config.get("skip_decision_chain", False):
            # Skip the decision chain and use the only configured agent
            queries = [
                {
                    "agent_id": self.available_agents[0]["agent_id"],
                    "query": query,
                }
            ]
            justification = "Skipping decision chain, using the one configured agent"
            invalid_agent_id_generated = False
            return queries, justification, invalid_agent_id_generated

        # Never hand ``None`` to the decision chain; an absent history is empty.
        history: List[LlmHistory] = [] if chat_history is None else chat_history
        retry_attempt = 0
        while retry_attempt < max_retries and invalid_agent_id_generated:
            (queries, justification, invalid_agent_id_generated) = self.prepare_agents_queries(
                user_query=query,
                chat_history=history,
                user_profile=user_profile,
            )
            if invalid_agent_id_generated:
                retry_attempt += 1
                logger.warn(
                    f"Invalid agent IDs were generated. Retrying... (Attempt {retry_attempt}/{max_retries})", log_conv_id=True
                )
        return queries, justification, invalid_agent_id_generated

    def run(self, conversation_params: PortalConversationParams, max_retries=3):
        """Route a conversation turn to the default chain, one agent, or several agents.

        Returns an iterable of response items:
        - the default chain when there is no available agent or no query was built;
        - the parallel multi-agent flow when more than one query was built;
        - the streamed single-agent flow otherwise.

        The three flows are generator functions, so their bodies only run once
        the returned object is iterated. The ``except`` below therefore only
        covers failures that happen before one of them is returned (e.g. while
        building the agents queries); in that case a one-item generator holding
        a fallback answer is returned.
        """
        chat_history = conversation_params.get("chat_history", [])
        user_profile = conversation_params.get("user_profile", {})
        try:
            query = conversation_params.get("user_query", "")
            media_summaries: List[MediaSummary] = conversation_params.get("media_summaries") or []
            previous_media_summaries: List[MediaSummary] = conversation_params.get("previous_media_summaries") or []
            # Both operands were already normalised with ``or []`` above.
            summaries = media_summaries + previous_media_summaries
            agents_files_uploads = conversation_params.get("agents_files_uploads")

            if not self.available_agents:
                logger.info("No available agents, we use the default chain", log_conv_id=True)
                return self.run_default_chain(conversation_params)

            queries, justification, invalid_agent_id_generated = self.build_agents_queries(
                query=query,
                max_retries=max_retries,
                user_profile=user_profile,
                chat_history=chat_history,
            )
            if not queries:
                logger.info("No agent was called, we use the default chain", log_conv_id=True)
                return self.run_default_chain(conversation_params)

            if invalid_agent_id_generated:
                logger.warn("Max retries reached. Proceeding with possibly filtered queries", log_conv_id=True)

            if len(queries) > 1:
                # Several agents: call them in parallel, then synthesize their answers.
                return self.run_agents_completion(
                    queries=queries,
                    query=query,
                    conversation_params=conversation_params,
                    media_summaries=media_summaries,
                    summaries=summaries,
                    query_builder_justification=justification,
                )

            # In case there is a single agent, we can stream directly the response back to portal.
            return self.run_agent_completion(
                sub_query=queries[0],
                query=query,
                conversation_params=conversation_params,
                media_summaries=media_summaries,
                summaries=summaries,
                agents_files_uploads=agents_files_uploads,
                query_builder_justification=justification,
                user_profile=user_profile,
                chat_history=chat_history,
            )
        except Exception as e:
            logger.exception(f"Error in AgentsOrchestrator: {e}", log_conv_id=True)
            profile_setting = dataiku_api.webapp_config.get("include_user_profile_in_prompt", False)
            fallback_chain = NoRetrievalChain(self.llm, include_user_profile_in_prompt=bool(profile_setting))
            response = fallback_chain.get_as_json("Error in generating answer", user_profile=user_profile)
            return (item for item in [response])
