##############################################################################
# 1 · Imports
##############################################################################
from __future__ import annotations

import json
from functools import partial
from typing import Any, Dict, List, Optional, Union

import dataiku
from backend.config import get_config
from backend.models.events import EventKind
from backend.utils.conv_utils import normalise_stream_event
from backend.utils.logging_utils import get_logger
from dataiku.langchain.dku_tracer import LangchainToDKUTracer
from dataiku.llm.python import BaseLLM
from dataiku.llm.python.tools_using import DKUToolNode, format_multipart_messages
from dataikuapi.dss.llm import (
    DSSLLMStreamedCompletionFooter,
)
from langchain_core.tools import StructuredTool
from langgraph.prebuilt import create_react_agent
from pydantic import BaseModel, Field

# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)


##############################################################################
# 2 · Schema models
##############################################################################
class DSSToolAgent(BaseModel):
    """Metadata that defines a DSS-hosted LLM agent to wrap as a tool.

    Instances are turned into LangChain ``StructuredTool`` objects by
    ``build_dss_agent_tool``.
    """

    # Split on its first ":" by build_dss_agent_tool into (project key, LLM id);
    # also used to derive the generated tool name.
    dss_agent_id: str = Field(..., description="Identifier for the DSS LLM agent")
    # When non-empty, sent as a "system" message before the user query.
    agent_system_instructions: Optional[str] = Field(
        None, description="System prompt to prepend when querying the agent"
    )
    # Becomes the StructuredTool description seen by the controller LLM.
    tool_agent_description: str = Field(
        ..., description="Description shown to the caller LLM so it knows WHEN to use this tool"
    )
    # Display name for trace spans and emitted events; build_dss_agent_tool falls
    # back to "Agent" (span name) or dss_agent_id (event name) when unset.
    agent_name: Optional[str] = Field(default=None, description="Friendly name for the agent")


class _Query(BaseModel):
    """Argument schema for DSS-agent tools.

    Declares the single ``query`` string that the controller LLM must supply
    when it calls one of the tools built by ``build_dss_agent_tool``.
    """

    query: str = Field(..., description="User query forwarded to the DSS agent")


##############################################################################
# 3 · Helper that returns a **StructuredTool** for ONE DSS agent
##############################################################################
def _rename_active_tool_span(tracer, tool_name: Optional[str], friendly_name: Optional[str]) -> Optional[str]:
    """
    Find the still-open trace run of tool *tool_name* and relabel its span.

    Scans ``tracer.run_map`` for the first run with ``run_type == "tool"``, a
    matching ``name`` and no ``end_time``. When one is found and *friendly_name*
    is truthy, the matching ``tracer.run_id_to_span_map`` entry gets its
    ``span["name"]`` set to *friendly_name*. The DSS trace then shows the
    agent's display name instead of the generated tool name.

    Returns:
        The matched run id as ``str``. Returns ``None`` when there is no usable
        tracer or no matching run, or when the lookup failed before a run was
        found. Errors are logged as warnings and never raised.
    """
    if not (tracer and hasattr(tracer, "run_map")):
        return None

    run_id: Optional[str] = None
    try:
        # NOTE(review): the first open run with this name wins. If the same tool
        # is invoked concurrently, this may pick another invocation's run.
        active_run = next(
            (
                run
                for run in tracer.run_map.values()
                if getattr(run, "run_type", None) == "tool"
                and getattr(run, "name", None) == tool_name
                and getattr(run, "end_time", None) is None
            ),
            None,
        )
        if active_run is None:
            logger.warning("Could not find active run for tool named '%s'", tool_name)
            return None

        run_id = str(active_run.id)
        if friendly_name and getattr(tracer, "run_id_to_span_map", None):
            span_map = tracer.run_id_to_span_map
            if run_id in span_map:
                try:
                    span_map[run_id].span["name"] = friendly_name
                except Exception as e:
                    logger.warning("Could not update trace span name for tool %s: %s", tool_name, e)
    except Exception as e:
        logger.warning("Tracer inspection failed for tool %s: %s", tool_name, e)
    return run_id


def _attach_sub_trace(tracer, tool_run_id: Optional[str], footer_data: Dict[str, Any], dss_agent_id: str) -> None:
    """
    Append the DSS agent's own trace, if its footer carried one, under the tool span.

    The sub-trace is read from ``footer_data["trace"]``, falling back to
    ``footer_data["agentTrace"]``. It is appended to ``span["children"]`` of the
    span registered for *tool_run_id* in ``tracer.run_id_to_span_map``.
    This is best-effort: any error is logged at debug level and swallowed.
    """
    try:
        if not (tracer and tool_run_id and footer_data):
            return
        maybe_trace = footer_data.get("trace") or footer_data.get("agentTrace")
        if not (maybe_trace and getattr(tracer, "run_id_to_span_map", None)):
            return
        span_map = tracer.run_id_to_span_map
        if tool_run_id in span_map:
            span_map[tool_run_id].span.setdefault("children", []).append(maybe_trace)
        else:
            logger.debug("No span found for run_id %s to attach sub-trace.", tool_run_id)
    except Exception as e:
        logger.debug("Failed to attach sub-trace for agent %s: %s", dss_agent_id, e)


def build_dss_agent_tool(
    agent: DSSToolAgent, context: dict, artifacts_meta: dict, tracer=None, pcb=None
) -> StructuredTool:
    """
    Build a StructuredTool that forwards a query to one DSS-hosted LLM agent.

    The agent id, system prompt, tracer, progress callback and display names are
    bound into the tool function with ``functools.partial``. The caller LLM
    only supplies ``query`` (see ``_Query``).

    Args:
        agent: Metadata of the DSS agent to wrap.
        context: Completion context set via ``with_context`` on every call.
        artifacts_meta: Passed through to ``normalise_stream_event`` when the
            agent streams artifacts.
        tracer: Optional tracer. The active tool span is renamed on it, and the
            agent's sub-trace is attached when the footer carries one.
        pcb: Optional progress callback receiving event payload dicts.

    Returns:
        A StructuredTool named ``ask_<agent id with ':' -> '_'>``, truncated to
        64 characters, with ``response_format="content_and_artifact"``. Calling
        it returns the agent's concatenated text plus a dict with ``output``,
        ``sources``, ``footer`` and ``artifacts``.
    """
    from langchain_core.callbacks.manager import CallbackManagerForToolRun

    def _call_agent(
        query: str,
        *,
        dss_agent_id: str,
        system_instructions: Optional[str],
        tracer=None,
        friendly_name: Optional[str] = None,
        tool_name: Optional[str] = None,
        artifacts_meta: Optional[Dict[str, Any]] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        pcb=None,
        agent_name: Optional[str] = None,
    ) -> tuple[str, dict]:
        """
        Send *query* to the DSS agent, stream its answer, and return (text, artifact).

        Raises:
            ValueError: if *dss_agent_id* contains no ``":"`` separator.
        """
        # Avoid a shared mutable default; None means "no artifact metadata".
        if artifacts_meta is None:
            artifacts_meta = {}
        # Name used in every event forwarded to the UI (agent_name is normally bound).
        display_name = agent_name or tool_name

        # --- 1) Build the completion request ---
        # The id is "<project key>:<LLM id>"; only the first ":" separates them.
        project_key, sep, short = dss_agent_id.partition(":")
        if not sep:
            raise ValueError(
                f"Invalid DSS agent id {dss_agent_id!r}: expected '<PROJECT_KEY>:<llm id>'"
            )
        client = dataiku.api_client()
        llm = client.get_project(project_key).get_llm(short)
        comp = llm.new_completion()
        if system_instructions:
            comp = comp.with_message(system_instructions, role="system")
        comp = comp.with_message(query, role="user")
        comp.with_context(context)

        logger.info(
            "Calling agent as a tool:\nAgent id=[%s]\nAgent name=[%s]\nCompletion Query=%s\nCompletion Settings=%s\n",
            dss_agent_id,
            agent_name or tool_name or "unknown",
            # default=str: a context value that is not JSON-serialisable must not
            # let a diagnostic log line abort the agent call.
            json.dumps(comp.cq, indent=2, sort_keys=True, default=str),
            json.dumps(comp.settings, indent=2, sort_keys=True, default=str),
        )

        if pcb:
            pcb(
                {
                    "eventKind": EventKind.AGENT_CALLING_AGENT,
                    "eventData": {"agentName": "Agent Hub", "agentAsToolName": display_name, "query": query},
                }
            )

        # Relabel this tool's span before executing. Its run id is kept so the
        # agent's sub-trace can be attached to it afterwards.
        current_tool_run_id = _rename_active_tool_span(tracer, tool_name, friendly_name)

        # --- 2) Execute the agent call with streaming ---
        buf: list[str] = []
        sources: Dict[str, Any] = {}
        footer_data: Dict[str, Any] = {}
        artifacts = []

        def _request_messages() -> list:
            """Return a fresh list of the messages actually sent to the agent."""
            msgs = [{"role": "system", "content": system_instructions}] if system_instructions else []
            msgs.append({"role": "user", "content": query})
            return msgs

        # NOTE(review): recent langchain_core StructuredTool versions do not pass
        # ``run_manager`` to the wrapped function. It is then always None here and
        # the run_manager forwards below are no-ops; confirm for the pinned version.
        def _emit_event(event_type: str, payload: Dict[str, Any]) -> None:
            """Forward a custom event through run_manager, if one was provided."""
            if run_manager:
                try:
                    run_manager.on_custom_event({"type": event_type, **payload})
                except Exception as e:
                    logger.debug("Failed to forward custom event %s: %s", event_type, e)

        for chunk in comp.execute_streamed():
            # Case A: footer with metadata (e.g., sources, trace info)
            if isinstance(chunk, DSSLLMStreamedCompletionFooter):
                footer_data = getattr(chunk, "data", {}) or {}
                ai = footer_data.get("additionalInformation", {}) or {}
                sources = ai.get("sources", {}) or footer_data.get("sources", {}) or {}

            # Case B: agent-emitted event chunks, relayed under this agent's name
            elif chunk.type == "event":
                event_data = chunk.data.get("eventData") or {}
                if not isinstance(event_data, dict):
                    event_data = {}
                payload = {
                    "eventKind": chunk.event_kind,
                    # Same fallback as the other events here, so agentName is never None.
                    "eventData": {**event_data, "agentName": display_name},
                }
                if pcb:
                    pcb(payload)
                _emit_event("agent_event", payload)

            # Case C: token/text chunk, possibly carrying artifacts
            elif chunk.type == "content":
                text = chunk.text
                if text:
                    buf.append(text)
                    if run_manager:
                        try:
                            run_manager.on_text(text)
                        except Exception as e:
                            logger.debug("on_text failed: %s", e)
                if "artifacts" in chunk.data:
                    artifacts = chunk.data["artifacts"]
                    payload = {
                        "chunk": {
                            "type": "content",
                            "artifacts": artifacts,
                            "eventData": {"agentName": display_name, "agentId": dss_agent_id},
                        }
                    }
                    normalise_stream_event(
                        ev=payload,
                        tcb=lambda x: None,
                        pcb=pcb,
                        # Mirror the request: no empty system message when none was sent.
                        msgs=_request_messages(),
                        aid=dss_agent_id,
                        artifacts_meta=artifacts_meta,
                        trace=None,
                        aname=display_name,
                        query=query,
                    )

            # Case D: unknown chunk shape -> forward as diagnostic
            else:
                _emit_event("agent_unknown_chunk", {"repr": repr(chunk)})

        final_text = "".join(buf)

        # --- 3) attach sub-trace to the current tool span (if provided) ---
        _attach_sub_trace(tracer, current_tool_run_id, footer_data, dss_agent_id)

        # --- 4) Final tool return (content, artifact) ---
        return final_text, {
            "output": final_text,
            "sources": sources,
            "footer": footer_data,
            "artifacts": artifacts,
        }

    # Derived from the agent id with ':' replaced and capped at 64 chars
    # (presumably to satisfy the controller provider's tool-name rules; confirm).
    tool_name = f"ask_{agent.dss_agent_id.replace(':', '_')}"[:64]
    bound_fn = partial(
        _call_agent,
        dss_agent_id=agent.dss_agent_id,
        system_instructions=agent.agent_system_instructions,
        tracer=tracer,
        friendly_name=agent.agent_name or "Agent",
        tool_name=tool_name,
        artifacts_meta=artifacts_meta,
        pcb=pcb,
        agent_name=agent.agent_name or agent.dss_agent_id,
    )

    return StructuredTool.from_function(
        func=bound_fn,
        name=tool_name,
        description=agent.tool_agent_description,
        args_schema=_Query,  # the controller LLM only supplies `query`
        return_direct=False,
        response_format="content_and_artifact",
    )


##############################################################################
# 4 · LLM wrapper exposed to Dataiku (ReAct controller over agent tools)
##############################################################################
class AgentConnect(BaseLLM):
    """
    Dataiku custom LLM that runs a LangGraph ReAct controller over agent tools.

    self.agents may contain EITHER:
      • DSSToolAgent            (enterprise agent or user agent w/o tools)
      • StructuredTool          (user agent WITH its own sub-tools)

    DSSToolAgent items are wrapped into StructuredTools on every request by
    ``build_dss_agent_tool``. StructuredTool items are used as-is.
    """

    def __init__(
        self,
        base_model: str,
        agents: List[Union[DSSToolAgent, StructuredTool]] | None = None,
    ):
        """
        Args:
            base_model: LLM id, in the default project, used as the ReAct controller.
            agents: Agents/tools exposed to the controller (none when omitted).
        """
        self.agents = agents or []
        self.base_model = base_model
        self.project = dataiku.api_client().get_default_project()
        self.config = get_config()

    # --------------------------------------------------------------------- #
    # Build the tool list for one request (called on every aprocess_stream)
    # --------------------------------------------------------------------- #
    def _build_tool_node(self, context, artifacts_meta, tracer=None, pcb=None) -> DKUToolNode:
        """
        Wrap ``self.agents`` into a DKUToolNode for the current request.

        Items that are neither DSSToolAgent nor StructuredTool are skipped.
        """
        structured_tools: list[StructuredTool] = []
        for item in self.agents:
            if isinstance(item, DSSToolAgent):
                structured_tools.append(build_dss_agent_tool(item, context, artifacts_meta, tracer, pcb))

            elif isinstance(item, StructuredTool):
                if hasattr(item, "func") and isinstance(item.func, partial):
                    # The 'func' is a partial where the tracer was pre-bound (likely as None).
                    # We can now overwrite it with the real tracer.
                    # NOTE(review): this mutates a tool stored on self.agents in place.
                    # If one AgentConnect instance serves overlapping requests, the
                    # most recent request's tracer wins for all of them.
                    item.func.keywords["tracer"] = tracer
                    logger.debug("Injected tracer into UserAgent tool: %s", item.name)

                structured_tools.append(item)

        tool_node = DKUToolNode(tools=structured_tools)
        # Reset per request; read back as the footer's sources in aprocess_stream
        # (presumably filled by DKUToolNode; confirm).
        tool_node._sources = []
        return tool_node

    # --------------------------------------------------------------------- #
    # Async streaming entry-point used by Dataiku
    # --------------------------------------------------------------------- #
    async def aprocess_stream(self, query, settings, artifacts_meta, trace, pcb):
        """
        Stream the controller agent's answer for *query* as Dataiku chunk dicts.

        Yields the following, in order:
        - ``{"chunk": {"text": ...}}`` for each non-empty controller token.
        - An ``AGENT_THINKING`` event chunk whenever the ``agent`` node starts.
        - Finally, one ``{"footer": ...}`` carrying ``tool_node._sources``.

        Args:
            query: Completion query. ``query["messages"]`` is required, and
                ``query.get("context")`` is forwarded to DSS agent tools.
            settings: Not used by this method.
            artifacts_meta: Forwarded to DSS agent tools.
            trace: DSS trace, wrapped in a LangchainToDKUTracer.
            pcb: Progress callback forwarded to DSS agent tools.
        """
        tracer = LangchainToDKUTracer(dku_trace=trace)
        tool_node = self._build_tool_node(query.get("context"), artifacts_meta, tracer, pcb)

        # Build the *controller* agent (ReAct)
        completion_settings = self.config.get("completionSettings", {})
        controller_llm = self.project.get_llm(self.base_model).as_langchain_chat_model(
            completion_settings=completion_settings
        )

        # Keep only well-formed role/content pairs; other keys are dropped.
        content_messages = [
            {"role": msg["role"], "content": msg["content"]}
            for msg in format_multipart_messages(query["messages"])
            if isinstance(msg, dict) and "role" in msg and "content" in msg
        ]

        # Note: this logs the full conversation at INFO level.
        logger.info(
            "Agent Hub - create_react_agent to run messages: %s",
            content_messages,
        )

        graph = create_react_agent(
            controller_llm,
            tool_node,
            state_modifier=self.config.get("systemPromptAppend"),
            debug=False,
        )

        async for event in graph.astream_events(
            {"messages": content_messages},
            {"configurable": {"thread_id": "thread-1"}, "callbacks": [tracer]},
            stream_mode="messages",
            version="v2",
        ):
            try:
                kind = event["event"]
                if kind == "on_chat_model_stream":
                    content = event["data"]["chunk"].content
                    if content:
                        yield {"chunk": {"text": content}}
                elif kind == "on_chain_start" and event.get("name") == "agent":
                    yield {
                        "chunk": {
                            "type": "event",
                            "eventKind": "AGENT_THINKING",
                            "eventData": {"agentName": "Agent Hub"},
                        }
                    }
            except Exception:
                # Best-effort: one malformed event must not abort the whole stream.
                # Log it with its traceback rather than printing to stdout.
                logger.exception("Agent Hub - failed to process stream event")

        all_sources = tool_node._sources or []
        yield {"footer": {"additionalInformation": {"sources": all_sources}}}
