import asyncio
import json
import threading
import time
import uuid
from copy import deepcopy
from typing import Any, Dict, List, Optional

import dataiku
from dataikuapi.dss.llm import DSSLLMStreamedCompletionFooter

from backend.models.events import EventKind
from backend.utils.conv_utils import get_selected_agents_as_objs, normalise_stream_event
from backend.utils.llm_utils import add_history_to_completion
from backend.utils.logging_utils import extract_error_message, get_logger

logger = get_logger(__name__)


class AgentResult:
    """Mutable record of what one sub-agent produced during a streamed run.

    A fresh instance is "successful" with empty payloads; the orchestrator fills
    it in chunk by chunk and flips ``success`` / ``error_message`` on failure.
    """

    def __init__(self, agent_id: str, agent_query: Dict[str, Any], agent_name: str):
        # Who was asked, and what they were asked.
        self.agent_id = agent_id
        self.agent_query = agent_query
        self.agent_name = agent_name
        # Accumulated streamed answer text.
        self.answer_text = ""
        # Sources reported in the stream footer.
        self.sources: List[Dict[str, Any]] = []
        # Outcome: set to False plus a message by the error handlers.
        self.success = True
        self.error_message: Optional[str] = None
        # Latest artifacts payload seen in the stream.
        self.artifacts: List[Any] = []
        # Tool validation requests raised by this agent (tagged with agentId/agentName).
        self.tool_validation_requests: List[Dict[str, Any]] = []
        # Last memory fragment emitted by the agent, if any.
        self.memory_fragment: Optional[Dict[str, Any]] = None

    def to_generated_answer(self) -> Dict[str, Any]:
        """Summarise this result as the dict consumed by the synthesis prompt.

        On failure the answer is the recorded error message, or a generic
        fallback message when none was recorded.
        """
        if self.success:
            answer = self.answer_text
        else:
            answer = self.error_message or "Error while processing your request"
        return dict(
            agent_id=self.agent_id,
            agent_name=self.agent_name,
            agent_query=self.agent_query,
            agent_answer=answer,
        )


class OrchestratorService:
    """Run several Dataiku agents concurrently and synthesize one final answer.

    Every member is a ``staticmethod``; the class only groups related helpers.
    Entry points are :meth:`stream_multiple_agents` (sync) and
    :meth:`stream_multiple_agents_async`.
    """

    # Roles of side-channel messages (memory fragments, tool validation
    # requests/responses). They are filtered per agent before being forwarded
    # to a sub-agent and removed entirely for the synthesis LLM.
    _SPECIAL_MESSAGE_ROLES = frozenset(
        {"memoryFragment", "toolValidationRequests", "toolValidationResponses"}
    )

    # ---------- Utility helpers ----------
    @staticmethod
    def _filter_messages_for_agent(agent_id: str, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return the part of ``messages`` that should be forwarded to ``agent_id``.

        The result is every regular conversation message (original order)
        followed by this agent's special messages (original relative order):

        - from each ``memoryFragment`` message, the first fragment whose
          ``memoryFragmentTarget.agentId`` equals ``agent_id``;
        - this agent's tool validation requests, with the ``agentId`` /
          ``agentName`` routing keys removed;
        - tool validation responses whose ``validationRequestId`` matches one of
          this agent's requests seen earlier in the history.
        """
        special_roles = OrchestratorService._SPECIAL_MESSAGE_ROLES
        regular_messages = [msg for msg in messages if msg.get("role") not in special_roles]

        filtered_special_messages: List[Dict[str, Any]] = []
        # IDs of this agent's validation requests; a response is kept only if its
        # request appeared earlier in the history.
        agent_request_ids = set()

        for msg in messages:
            role = msg.get("role")

            if role == "memoryFragment":
                # `or {}` / `or []` guard against explicit None payloads.
                memory_fragment_data = msg.get("memoryFragment") or {}
                for mem_msg in memory_fragment_data.get("messages") or []:
                    target = mem_msg.get("memoryFragmentTarget") or {}
                    if target.get("agentId") == agent_id:
                        filtered_special_messages.append({
                            "role": "memoryFragment",
                            "memoryFragment": mem_msg.get("memoryFragment"),
                        })
                        break

            elif role == "toolValidationRequests":
                agent_requests = [
                    req for req in msg.get("toolValidationRequests") or []
                    if req.get("agentId") == agent_id
                ]
                if agent_requests:
                    agent_request_ids.update(req["id"] for req in agent_requests if "id" in req)
                    # agentId/agentName are orchestrator routing metadata (added in
                    # _update_result_from_data); strip them before forwarding.
                    cleaned_requests = [
                        {k: v for k, v in req.items() if k not in ("agentId", "agentName")}
                        for req in agent_requests
                    ]
                    filtered_special_messages.append({
                        "role": "toolValidationRequests",
                        "toolValidationRequests": cleaned_requests,
                    })

            elif role == "toolValidationResponses" and agent_request_ids:
                agent_responses = [
                    resp for resp in msg.get("toolValidationResponses") or []
                    if resp.get("validationRequestId") in agent_request_ids
                ]
                if agent_responses:
                    filtered_special_messages.append({
                        "role": "toolValidationResponses",
                        "toolValidationResponses": agent_responses,
                    })

        return regular_messages + filtered_special_messages

    @staticmethod
    def create_agent_completion(
        agent_id,
        query,
        messages,
    ):
        """Create a Dataiku completion for one agent, pre-filled with its history.

        Args:
            agent_id: Agent identifier. The first ``:``-separated part is used as
                the project key and the third as the agent LLM id (``agent:<part>``).
            query: Currently unused; kept for interface compatibility.
            messages: Full conversation history, possibly containing special
                memoryFragment / toolValidation* messages for several agents. It is
                filtered with :meth:`_filter_messages_for_agent` first.

        Returns:
            The completion returned by ``add_completion_msgs``.
        """
        # Imported locally as in the original (possibly to avoid an import cycle).
        from backend.utils.llm_utils import add_completion_msgs

        client = dataiku.api_client()
        id_parts = agent_id.split(":")
        project = client.get_project(id_parts[0])
        llm = project.get_llm("agent:" + id_parts[2])
        completion = llm.new_completion()

        final_messages = OrchestratorService._filter_messages_for_agent(agent_id, messages)
        return add_completion_msgs(completion, final_messages)

    # ---------- Single agent streaming helpers ----------

    @staticmethod
    def _update_result_from_data(result: "AgentResult", data: Dict[str, Any], agent_id: str, agent_name: str) -> Optional[str]:
        """Fold one stream chunk into ``result``; return its text piece, if any.

        Text is appended, ``artifacts`` / ``memoryFragment`` replace the previous
        value, and tool validation requests are tagged (without overwriting) with
        ``agentId`` / ``agentName`` before being accumulated.
        """
        text_piece = data.get("text") if isinstance(data.get("text"), str) else None
        if text_piece:
            result.answer_text += text_piece

        if "artifacts" in data:
            result.artifacts = data["artifacts"]

        # `or []` tolerates an explicit None payload.
        validation_requests = data.get("toolValidationRequests") or []
        for req in validation_requests:
            req.setdefault("agentId", agent_id)
            req.setdefault("agentName", agent_name)
        result.tool_validation_requests.extend(validation_requests)

        if "memoryFragment" in data:
            result.memory_fragment = data["memoryFragment"]

        return text_piece

    @staticmethod
    def _emit_agent_finished(pcb, agent_id: str, agent_name: str, result: "AgentResult", status: str) -> None:
        """Emit an AGENT_FINISHED event with ``status``.

        For ``"ok"`` and ``"cancelled"`` the accumulated answer, sources and
        artifacts are attached; other statuses carry only the identifiers.
        """
        event_data = {
            "agentId": agent_id,
            "agentName": agent_name,
            "status": status,
        }
        if status in ("ok", "cancelled"):
            event_data.update({
                "answer": result.answer_text,
                "sources": result.sources,
                "artifacts": result.artifacts,
            })
        pcb({"eventKind": EventKind.AGENT_FINISHED, "eventData": event_data})

    @staticmethod
    def _handle_stream_error(
        result: "AgentResult",
        pcb,
        agent_id: str,
        agent_name: str,
        error_message: str,
        status: str,
        exception: Optional[Exception] = None,
    ) -> None:
        """Mark ``result`` as failed and emit the matching events.

        An AGENT_ERROR event is emitted only when ``exception`` is given; an
        AGENT_FINISHED event with ``status`` is always emitted afterwards.
        """
        result.success = False
        result.error_message = error_message

        if exception:
            pcb({
                "eventKind": EventKind.AGENT_ERROR,
                "eventData": {
                    "agentId": agent_id,
                    "agentName": agent_name,
                    "message": extract_error_message(str(exception)),
                },
            })

        OrchestratorService._emit_agent_finished(pcb, agent_id, agent_name, result, status)

    # ---------- Single agent streaming ----------
    @staticmethod
    async def _stream_single_agent(
        *,
        query: Dict[str, Any],
        agent_id: str,
        agent_name: str,
        context: Dict[str, Any],
        messages: List[Dict[str, Any]],
        pcb,  # event callback: pcb(payload: Dict[str, Any]) -> None
        artifacts_meta: Optional[Dict[str, Any]] = None,
        cancel_event: threading.Event,
        stream_timeout_s: Optional[float] = None,
    ) -> "AgentResult":
        """Stream a single agent to completion, emitting structured events for the frontend.

        The blocking Dataiku stream is opened and consumed in a worker thread.
        Events: AGENT_STARTED, then per-chunk events via ``normalise_stream_event``,
        then AGENT_FINISHED with status ``ok`` / ``cancelled`` / ``timeout`` (or the
        generic failure status, preceded by AGENT_ERROR).

        ``artifacts_meta`` defaults to a fresh dict per call (``None`` is treated
        the same way) instead of a shared mutable default.

        Returns:
            The populated :class:`AgentResult`; errors are captured in it, not raised.
        """
        if artifacts_meta is None:
            artifacts_meta = {}

        result = AgentResult(agent_id=agent_id, agent_query=query, agent_name=agent_name)

        pcb({
            "eventKind": EventKind.AGENT_STARTED,
            "eventData": {"agentId": agent_id, "agentName": agent_name, "query": query},
        })

        # Private copy whose last message carries this agent's sub-query; it is only
        # handed to normalise_stream_event. Guarded so empty history cannot raise here
        # (outside the try below, which would abort the whole orchestration).
        msgs = deepcopy(messages)
        if msgs:
            msgs[-1]["content"] = query

        # Set once this coroutine stops waiting on the stream (timeout, error or
        # task cancellation). asyncio cannot interrupt the worker thread, so the
        # thread polls this flag and stops emitting events for an agent that has
        # already been reported as finished.
        stop_event = threading.Event()

        def handle_data_event(data: Dict[str, Any]):
            """Process a single chunk-shaped dict from the agent stream."""
            normalise_stream_event(
                ev={"chunk": data},
                tcb=lambda x: None,
                pcb=pcb,
                msgs=msgs,
                aid=agent_id,
                aname=agent_name,
                query=query,
                artifacts_meta=artifacts_meta,
            )
            OrchestratorService._update_result_from_data(result, data, agent_id, agent_name)

        def handle_footer_event(chunk, data: Dict[str, Any]):
            """Process a footer chunk from the agent stream and record its sources."""
            normalise_stream_event(
                ev={"footer": {"additionalInformation": chunk.data.get("additionalInformation", {})}},
                tcb=lambda x: None,
                pcb=pcb,
                msgs=msgs,
                aid=agent_id,
                aname=agent_name,
                query=query,
            )
            # `or {}` tolerates an explicit None additionalInformation.
            maybe_sources = (data.get("additionalInformation") or {}).get("sources", [])
            if maybe_sources:
                result.sources = maybe_sources

        def _stream_sync_iter(sync_iter):
            """Consume the blocking stream; runs in a worker thread via asyncio.to_thread."""
            try:
                for chunk in sync_iter:
                    if cancel_event.is_set() or stop_event.is_set():
                        break
                    data = getattr(chunk, "data", None) or chunk
                    if isinstance(chunk, DSSLLMStreamedCompletionFooter):
                        handle_footer_event(chunk, data)
                    else:
                        handle_data_event(data)
            finally:
                # Close the iterator (e.g. a generator) right away when we stop
                # early, instead of leaving it to garbage collection.
                close = getattr(sync_iter, "close", None)
                if callable(close):
                    close()

        try:
            completion = OrchestratorService.create_agent_completion(
                agent_id=agent_id,
                query=query,
                # NOTE(review): the unmodified `messages` are forwarded here rather
                # than `msgs` (whose last message holds the sub-query); confirm the
                # agent is meant to see the original user turn.
                messages=messages,
            )
            completion.with_context(context)

            logger.info(
                "Agent Hub, direct agent call, id=[%s], name=[%s], cq=%s, settings=%s",
                agent_id,
                agent_name,
                getattr(completion, "cq", "unknown"),
                getattr(completion, "settings", "unknown"),
            )

            await OrchestratorService._execute_agent_stream(
                completion, _stream_sync_iter, stream_timeout_s
            )

            status = "cancelled" if cancel_event.is_set() else "ok"
            OrchestratorService._emit_agent_finished(pcb, agent_id, agent_name, result, status)

        except asyncio.TimeoutError:
            # Stop the still-running worker before reporting the timeout.
            stop_event.set()
            logger.warning(
                "Agent Hub, agent stream timed out after %ss, id=[%s], name=[%s]",
                stream_timeout_s,
                agent_id,
                agent_name,
            )
            OrchestratorService._handle_stream_error(
                result, pcb, agent_id, agent_name,
                "Timeout while streaming the agent response", "timeout"
            )

        except Exception as e:
            logger.exception("Agent Hub, agent call failed, id=[%s], name=[%s]", agent_id, agent_name)
            # NOTE(review): the status for generic failures is this literal sentence
            # rather than a short code; kept as-is since the frontend may match on it.
            OrchestratorService._handle_stream_error(
                result, pcb, agent_id, agent_name,
                "Error while processing your request", "Failed to get an answer", e
            )

        finally:
            # Also covers task cancellation (CancelledError is not an Exception).
            stop_event.set()

        return result

    @staticmethod
    async def _execute_agent_stream(completion, stream_handler, stream_timeout_s: Optional[float]) -> None:
        """Run ``stream_handler(completion.execute_streamed())`` in a worker thread.

        Both opening and consuming the stream happen off the event loop, so a slow
        request does not block the other agents' coroutines. When
        ``stream_timeout_s`` is truthy the whole call is bounded by
        ``asyncio.wait_for`` (raising ``asyncio.TimeoutError``); the worker thread
        itself is not interrupted by the timeout.
        """
        def _open_and_consume():
            stream_handler(completion.execute_streamed())

        work = asyncio.to_thread(_open_and_consume)
        if stream_timeout_s:
            await asyncio.wait_for(work, timeout=stream_timeout_s)
        else:
            await work

    # ---------- Multi-agent orchestration helpers ----------

    @staticmethod
    def _collect_agent_results(
        results: List["AgentResult"],
    ) -> tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Gather tool validation requests and memory fragments from agent results.

        Returns:
            ``(all_requests, combined_fragment)`` where ``combined_fragment`` is
            ``{"messages": [...]}`` (one wrapped fragment per agent that produced
            one, tagged with its target agent) or ``{}`` when there are none.
        """
        all_tool_validation_requests = []
        memory_fragments_by_agent = []

        for r in results:
            if r.tool_validation_requests:
                all_tool_validation_requests.extend(r.tool_validation_requests)

            if r.memory_fragment:
                memory_fragments_by_agent.append({
                    "role": "memoryFragment",
                    "memoryFragment": r.memory_fragment,
                    "memoryFragmentTarget": {
                        "agentName": r.agent_name,
                        "agentId": r.agent_id,
                    },
                })

        combined_memory_fragment = {"messages": memory_fragments_by_agent} if memory_fragments_by_agent else {}
        return all_tool_validation_requests, combined_memory_fragment

    @staticmethod
    def _extract_text_from_chunk(resp: Dict[str, Any]) -> Optional[str]:
        """Return ``resp["chunk"]["text"]`` when present, else None."""
        if not isinstance(resp, dict):
            return None
        chunk = resp.get("chunk")
        if isinstance(chunk, dict) and "text" in chunk:
            return chunk["text"]
        return None

    @staticmethod
    def _extract_trajectory_from_footer(resp: Dict[str, Any]) -> Optional[Any]:
        """Return the footer's ``additionalInformation.trajectory`` when present, else None."""
        if not isinstance(resp, dict) or "footer" not in resp:
            return None
        add_info = resp["footer"].get("additionalInformation", {}) or {}
        return add_info.get("trajectory")

    @staticmethod
    def _extract_trace_from_response(resp: Dict[str, Any]) -> Optional[Any]:
        """Return ``resp["trace_ready"]["trace"]`` when present, else None."""
        if isinstance(resp, dict) and "trace_ready" in resp:
            return resp["trace_ready"].get("trace")
        return None

    @staticmethod
    def _process_synthesis_response(
        resp: Dict[str, Any],
        tcb,
        pcb,
        messages: List[Dict[str, Any]],
        stream_id: str,
    ) -> tuple[str, Optional[Any], Optional[Any]]:
        """Forward one synthesis event to the callbacks and return ``(text, trajectory, trace)``."""
        normalise_stream_event(ev=resp, tcb=tcb, pcb=pcb, msgs=messages, aname="Agent Hub", stream_id=stream_id)

        text_piece = OrchestratorService._extract_text_from_chunk(resp) or ""
        trajectory = OrchestratorService._extract_trajectory_from_footer(resp)
        synthesis_trace = OrchestratorService._extract_trace_from_response(resp)

        return text_piece, trajectory, synthesis_trace

    @staticmethod
    def _finalize_synthesis_span(synthesis_span, synthesis_trace) -> None:
        """Attach ``synthesis_trace`` (if any) to the span and end it; no-op without a span."""
        if not synthesis_span:
            return
        if synthesis_trace:
            synthesis_span.append_trace(synthesis_trace)
        synthesis_span.end(int(time.time() * 1000))

    # ---------- Multi-agent orchestration ----------

    @staticmethod
    async def stream_multiple_agents_async(
        llm_id: str,
        agents_queries: Dict[str, Any],
        agents: list[dict],
        context: Dict[str, Any],
        messages: List[Dict[str, Any]],
        tcb,  # text callback for FINAL synthesis only (unchanged)
        pcb,  # event callback for UI
        artifacts_meta,  # type: Dict[str, Any]
        cancel_event: threading.Event,
        *,
        stream_timeout_s: Optional[float] = None,
        trace=None,  # Add trace parameter to accept main trace
    ) -> tuple[str, List[Dict[str, Any]], Dict[str, Any]]:
        """
        Run multiple agents concurrently via asyncio, stream their outputs as isolated event streams,
        then synthesize a final answer (also streamed).

        Only agents whose ``id`` (or ``uaid``) is a key of ``agents_queries`` are run.
        Synthesis is skipped when cancellation is requested or any agent raised
        tool validation requests; in that case the final text is ``""``.

        The messages list may contain special message types (memoryFragment, toolValidation*)
        which will be processed by create_agent_completion.

        Returns: (final_text, tool_validation_requests, memory_fragment)
        """
        tasks = [
            OrchestratorService._stream_single_agent(
                query=agents_queries.get(agent.get("id")) or agents_queries.get(agent.get("uaid")),
                agent_id=agent.get("id"),
                agent_name=agent.get("name"),
                messages=messages,
                context=context,
                pcb=pcb,
                artifacts_meta=artifacts_meta,
                cancel_event=cancel_event,
                stream_timeout_s=stream_timeout_s,
            )
            for agent in agents
            if agent.get("id") in agents_queries or agent.get("uaid", "None") in agents_queries
        ]

        # Collect results in completion order; on cancellation the remaining
        # agents are abandoned and only the finished results are kept.
        results: List[AgentResult] = []
        for coro in asyncio.as_completed(tasks):
            if cancel_event.is_set():
                break
            results.append(await coro)

        generated_answers = [r.to_generated_answer() for r in results]
        all_tool_validation_requests, combined_memory_fragment = OrchestratorService._collect_agent_results(results)

        if cancel_event.is_set() or all_tool_validation_requests:
            logger.info(
                "Orchestrator: %s agents completed. Tool validation requests count: %s",
                len(results),
                len(all_tool_validation_requests),
            )
            if all_tool_validation_requests:
                logger.info("Tool validation requests detected in sub-agents. Skipping synthesis.")
            return "", all_tool_validation_requests, combined_memory_fragment

        # --- Synthesis phase ---
        return await OrchestratorService._run_synthesis_phase(
            llm_id=llm_id,
            agents=agents,
            generated_answers=generated_answers,
            messages=messages,
            tcb=tcb,
            pcb=pcb,
            cancel_event=cancel_event,
            trace=trace,
            all_tool_validation_requests=all_tool_validation_requests,
            combined_memory_fragment=combined_memory_fragment,
        )

    @staticmethod
    async def _run_synthesis_phase(
        llm_id: str,
        agents: list[dict],
        generated_answers: List[Dict[str, Any]],
        messages: List[Dict[str, Any]],
        tcb,
        pcb,
        cancel_event: threading.Event,
        trace,
        all_tool_validation_requests: List[Dict[str, Any]],
        combined_memory_fragment: Dict[str, Any],
    ) -> tuple[str, List[Dict[str, Any]], Dict[str, Any]]:
        """Stream the synthesis LLM's answer and return it with the pass-through data.

        Emits SYNTHESIZING_STARTED, forwards every synthesis event to the
        callbacks, and records a ``synthesize_answers`` subspan on ``trace``
        when one is given. The synthesis stream is iterated on the event-loop
        thread.
        """
        pcb({"eventKind": EventKind.SYNTHESIZING_STARTED, "eventData": {}})

        final = ""
        synthesis_trace = None
        stream_id = str(uuid.uuid4())

        synthesis_span = None
        if trace:
            synthesis_span = trace.subspan("synthesize_answers")
            synthesis_span.begin(int(time.time() * 1000))

        try:
            answer_stream = OrchestratorService.generate_final_answer(
                llm_id=llm_id,
                agents=agents,
                agents_queries_answers=generated_answers,
                messages=messages,
            )
            for resp in answer_stream:
                if cancel_event.is_set():
                    break
                text_piece, _trajectory, trace_data = OrchestratorService._process_synthesis_response(
                    resp, tcb, pcb, messages, stream_id
                )
                final += text_piece
                if trace_data:
                    synthesis_trace = trace_data
        finally:
            OrchestratorService._finalize_synthesis_span(synthesis_span, synthesis_trace)

        return final, all_tool_validation_requests, combined_memory_fragment

    # ---------- Sync wrapper (safe in any context) ----------

    @staticmethod
    def stream_multiple_agents(
        llm_id,
        sel_agents,
        messages,
        context,
        tcb,
        pcb,
        artifacts_meta,
        cancel_event: threading.Event,
        store=None,
        *,
        stream_timeout_s: Optional[float] = None,
        trace=None,  # Add trace parameter
    ) -> tuple[str, List[Dict[str, Any]], Dict[str, Any]]:
        """
        Synchronous entrypoint that runs the asyncio pipeline.
        Returns: (final_text, tool_validation_requests, memory_fragment)

        If already inside an event loop (e.g., FastAPI request handler), we spin up
        a fresh loop in a background thread to avoid 'event loop is running' errors.
        The calling thread blocks until that thread finishes, and any exception
        raised by the pipeline is re-raised here (matching the asyncio.run path).
        """
        # TODO: propagate contextUpsert from multi-agent streaming footer.
        sel_ids = [a["agentId"] for a in sel_agents if "agentId" in a]
        agents_obj = get_selected_agents_as_objs(store, sel_ids)
        agents_queries = {a["agentId"]: a.get("query", {}) for a in sel_agents if "agentId" in a}

        async def _runner():
            return await OrchestratorService.stream_multiple_agents_async(
                llm_id=llm_id,
                agents_queries=agents_queries,
                agents=agents_obj,
                messages=messages,
                context=context,
                tcb=tcb,
                pcb=pcb,
                artifacts_meta=artifacts_meta,
                cancel_event=cancel_event,
                stream_timeout_s=stream_timeout_s,
                trace=trace,  # Pass trace through
            )

        # If no loop is running in this thread, use asyncio.run
        try:
            asyncio.get_running_loop()
            loop_running_here = True
        except RuntimeError:
            loop_running_here = False

        if not loop_running_here:
            return asyncio.run(_runner())

        # A loop is already running in this thread -> start a new loop in a worker thread.
        outcome: Dict[str, Any] = {"value": ("", [], {})}

        def _thread_target():
            new_loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(new_loop)
                outcome["value"] = new_loop.run_until_complete(_runner())
            except Exception as exc:
                # Threads swallow exceptions; hand it back to the caller instead.
                outcome["error"] = exc
            finally:
                new_loop.close()

        th = threading.Thread(target=_thread_target, daemon=True)
        th.start()
        th.join()
        if "error" in outcome:
            raise outcome["error"]
        return outcome["value"]

    # ---------- Final synthesis (streams sync) ----------

    @staticmethod
    def _build_synthesis_prompt(agents, agents_queries_answers) -> str:
        """Build the system prompt for the synthesis LLM.

        ``agents`` contributes an ``id``/``description`` metadata array and
        ``agents_queries_answers`` the per-agent query/answer records; both are
        embedded as indented JSON arrays.
        """
        agents_metadata = json.dumps(
            [{"id": a.get("id"), "description": a.get("tool_agent_description")} for a in agents],
            indent=2,
        )

        system_prompt = f"""
        # Role and Guidelines
        You are an assistant that synthesizes information from different agents to provide a final answer to the user query.
        Your role is to read the initial user query and any answers generated by different agents to answer parts of the query or all and provide a final answer.

        Your responsibilities:
        - Read the user query and read answers from agents (based on sub-queries).
        - Understand which agent provided which answer (via `agent_id`).
        - Each agent has its own scope and expertise. Use the answers provided by them as they are and rely on their answers to provide user with full answer. Do not make up your own.
        - Synthesize the information from the answers provided by the different agents to provide a final answer to the user query.
        - Do NOT alter the answers provided by external agents.
        - If you need to provide a different answer, make an additional one and clearly mention it's your own.

        Use the agent metadata to help you understand each agents's scope and its answers.
        Given the initial user query, any possible generated answers, and the context of the conversation, provide a final answer to the initial user query.
        Don't change the answer generated by the agents. If you need to provide a different answer, provide it in a new answer and mention that you are providing a different answer.
        Here is additional metadata that might help you about the available external agents:
        {agents_metadata}
        """

        # str.format() inserts substituted values verbatim (it does not re-parse
        # them), so no brace escaping is needed here; doubling braces would leak
        # literal "{{"/"}}" into the prompt.
        answers_json = json.dumps(
            [
                {
                    "agent_query": qa.get("agent_query", "no query provided"),
                    "agent_answer": qa.get("agent_answer", "no answer provided"),
                    "agent_id": qa.get("agent_id", ""),
                }
                for qa in agents_queries_answers or []
            ],
            indent=2,
        )

        return r"""
         {system_prompt}
         Generated sub queries and answers by agents based on user query:
         {agents_queries_answers}
         """.format(
            system_prompt=system_prompt,
            agents_queries_answers=answers_json,
        )

    @staticmethod
    def generate_final_answer(llm_id, agents, agents_queries_answers, messages):
        """Stream the synthesis LLM's answer as event dicts.

        This is a generator: nothing (including client/LLM lookup) runs until it
        is iterated.

        Yields:
            ``{"chunk": <chunk data>}`` for each regular chunk; for the footer,
            ``{"footer": {"additionalInformation": ...}}`` followed by
            ``{"trace_ready": {"trace": <trace or {}>}}``.
        """
        client = dataiku.api_client()
        project = client.get_default_project()
        comp = project.get_llm(llm_id).new_completion()

        final_prompt = OrchestratorService._build_synthesis_prompt(agents, agents_queries_answers)
        comp.with_message(final_prompt, role="system")

        # Synthesis only needs the regular conversation, not the per-agent
        # memory fragment / tool validation side channel.
        special_roles = OrchestratorService._SPECIAL_MESSAGE_ROLES
        regular_messages_only = [msg for msg in messages if msg.get("role") not in special_roles]

        comp = add_history_to_completion(completion=comp, messages=regular_messages_only)
        user_prompt = (
            "Based on the system instructions and user query, provide the final answer to the initial user query."
        )
        comp.with_message(user_prompt, role="user")
        logger.info("Synthesis: Starting to execute streamed completion")

        # Dataiku's execute_streamed() is a synchronous iterator
        chunk_count = 0
        for chunk in comp.execute_streamed():
            chunk_count += 1
            if isinstance(chunk, DSSLLMStreamedCompletionFooter):
                yield {"footer": {"additionalInformation": chunk.data.get("additionalInformation", {})}}
                trace_data = getattr(chunk, "trace", None)
                yield {"trace_ready": {"trace": trace_data if trace_data else {}}}
            else:
                yield {"chunk": chunk.data}

        logger.info("Synthesis: Completed streaming %s chunks", chunk_count)
