from __future__ import annotations

import json
from typing import Any, Dict, List, Optional, Union

import dataiku
from backend.agents.tools_using_v2 import LLMTool, ToolsUsingAgent
from backend.config import get_default_llm_id
from backend.models.events import EventKind
from backend.utils.conv_utils import normalise_stream_event
from backend.utils.logging_utils import get_logger
from dataiku.llm.python.types import UsedTool
from dataikuapi.dss.agent_tool import DSSAgentTool
from dataikuapi.dss.llm import DSSLLMStreamedCompletionFooter
from pydantic import BaseModel, Field

logger = get_logger(__name__)


class AHDSSAgentTool(DSSAgentTool):
    """A DSS agent exposed to the Agent Hub orchestrator as a callable tool.

    ``run`` forwards a query to the DSS agent identified by ``dss_agent_id``
    (``"<project_key>:<agent id>"``, split on the first ``:``) as a streamed
    completion. While streaming, it relays events, artifacts and tool validation
    requests through the progress callback ``pcb``. It returns the aggregated
    answer together with sources, artifacts, trace, pending validation requests,
    a memory fragment and any context upsert.
    """

    def __init__(
        self,
        client,
        project_key,
        tool_id,
        descriptor,
        dss_agent_id,
        agent_system_instructions,
        agent_name,
        tool_name,
        artifacts_meta,
        pcb,
    ):
        """Store the wrapped agent's identity and the Agent Hub plumbing.

        :param client, project_key, tool_id, descriptor: forwarded unchanged to ``DSSAgentTool``.
        :param dss_agent_id: ``"<project_key>:<agent id>"`` of the agent to call.
        :param agent_system_instructions: optional system prompt sent before the query.
        :param agent_name: friendly name used in emitted events (may be None).
        :param tool_name: LLM tool name, used as a fallback display name.
        :param artifacts_meta: forwarded to ``normalise_stream_event`` when artifacts are relayed.
        :param pcb: optional progress callback receiving event payload dicts.
        """
        super().__init__(client, project_key, tool_id, descriptor)
        self.dss_agent_id = dss_agent_id
        self.agent_system_instructions = agent_system_instructions
        self.agent_name = agent_name
        self.tool_name = tool_name
        self.artifacts_meta = artifacts_meta
        self.pcb = pcb

    def _restore_memory_fragment(self, memory_fragment):
        """Validate and unpack a memory fragment previously returned by ``run``.

        The fragment must be a dict whose ``"messages"`` list holds 1 or 2 entries.
        The entry carrying a ``"memoryFragmentTarget"`` is the nested sub-agent
        fragment. Any other entry is a partial assistant output.

        :returns: ``(partial_output, memory_fragment_message, stashed_sources)``.
            ``partial_output`` is None when absent. ``memory_fragment_message`` is a
            ``{"role": "memoryFragment", ...}`` completion message, or None.
            ``stashed_sources`` defaults to ``{}``.
        :raises ValueError: on a malformed fragment, a non-assistant partial output,
            or a nested fragment targeting another agent.
        """
        if (
            not isinstance(memory_fragment, dict)
            or "messages" not in memory_fragment
            or not isinstance(memory_fragment["messages"], list)
            or len(memory_fragment["messages"]) == 0
            or len(memory_fragment["messages"]) > 2
        ):
            raise ValueError("Invalid memory fragment structure")

        wrapped_memory_fragment = None
        wrapped_partial_output = None
        for wrapped in memory_fragment["messages"]:
            if "memoryFragmentTarget" in wrapped:
                wrapped_memory_fragment = wrapped
            else:
                wrapped_partial_output = wrapped

        # NOTE: a partial assistant output is restored when present. Original dev note:
        # this case was triggered with Anthropic Sonnet 4.5 by asking, in the prompt,
        # for an explanation of what the agent will do with the tool call.
        partial_output = None
        if wrapped_partial_output:
            role = wrapped_partial_output.get("role")
            if role != "assistant":
                raise ValueError(
                    f"Expected the memory fragment to contain a message with role 'assistant', got {role}"
                )
            partial_output = wrapped_partial_output.get("content", "")

        fragment_message = None
        if wrapped_memory_fragment:
            # `or {}` also covers an explicit None target, so it yields the ValueError below
            # instead of an AttributeError.
            target = wrapped_memory_fragment.get("memoryFragmentTarget") or {}
            if target.get("agentId") != self.dss_agent_id:
                raise ValueError("Incorrect nested memory fragment target")
            fragment_message = {
                "role": "memoryFragment",
                "memoryFragment": wrapped_memory_fragment["memoryFragment"],
            }

        return partial_output, fragment_message, memory_fragment.get("stashedSources", {})

    def _relay_artifacts(self, artifacts, query):
        """Forward artifacts streamed by the sub-agent through ``normalise_stream_event``."""
        display_name = self.agent_name or self.tool_name
        normalise_stream_event(
            ev={
                "chunk": {
                    "type": "content",
                    "artifacts": artifacts,
                    "eventData": {"agentName": display_name, "agentId": self.dss_agent_id},
                }
            },
            tcb=lambda x: None,
            pcb=self.pcb,
            msgs=[
                {"role": "system", "content": self.agent_system_instructions},
                {"role": "user", "content": query},
            ],
            aid=self.dss_agent_id,
            artifacts_meta=self.artifacts_meta,
            trace=None,
            aname=display_name,
            query=query,
        )

    def run(
        self,
        input,
        context=None,
        subtool_name=None,
        memory_fragment=None,
        tool_validation_responses=None,
        tool_validation_requests=None,
    ):
        """Send ``input["query"]`` to the DSS agent and collect its streamed answer.

        :param input: tool input; only ``"query"`` is read (defaults to ``""``).
        :param context: completion context passed to ``with_context``.
        :param subtool_name: accepted for interface compatibility; not used.
        :param memory_fragment: fragment previously returned by this method, used to
            resume the sub-agent (see ``_restore_memory_fragment``).
        :param tool_validation_responses: forwarded to the sub-agent, only when
            ``tool_validation_requests`` is also provided.
        :param tool_validation_requests: see ``tool_validation_responses``.
        :returns: dict with ``output`` (``{"result": text}``), ``sources``,
            ``artifacts``, ``trace``, ``toolValidationRequests``,
            ``memoryFragment`` and ``context_upsert``.
        :raises ValueError: if ``memory_fragment`` is invalid.
        """
        query = input.get("query", "")
        display_name = self.agent_name or self.tool_name

        project_key, short_id = self.dss_agent_id.split(":", 1)
        llm = dataiku.api_client().get_project(project_key).get_llm(short_id)
        completion = llm.new_completion()

        # Message order matters: system prompt, user query, restored memory fragment,
        # then tool validation requests/responses.
        messages = []
        if self.agent_system_instructions:
            messages.append({"role": "system", "content": self.agent_system_instructions})
        messages.append({"role": "user", "content": query})

        restored_partial_output = None
        restored_stashed_sources = None
        if memory_fragment:
            restored_partial_output, fragment_message, restored_stashed_sources = (
                self._restore_memory_fragment(memory_fragment)
            )
            if fragment_message is not None:
                messages.append(fragment_message)

        if tool_validation_requests and tool_validation_responses:
            messages.append({"role": "toolValidationRequests", "toolValidationRequests": tool_validation_requests})
            messages.append({"role": "toolValidationResponses", "toolValidationResponses": tool_validation_responses})

        completion.cq["messages"] = messages
        completion.with_context(context)

        logger.info(
            "Calling agent as a tool:\nAgent id=[%s]\nAgent name=[%s]\nCompletion Query=%s\nCompletion Settings=%s\n",
            self.dss_agent_id,
            display_name or "unknown",
            json.dumps(completion.cq, indent=2, sort_keys=True),
            json.dumps(completion.settings, indent=2, sort_keys=True),
        )

        if self.pcb:
            self.pcb(
                {
                    "eventKind": EventKind.AGENT_CALLING_AGENT,
                    "eventData": {
                        "agentName": "Agent Hub",
                        "agentAsToolName": display_name,
                        "query": query,
                    },
                }
            )

        buf: list[str] = []
        footer_sources = None
        artifacts = []
        context_upsert: Dict[str, Any] = {}
        trace = None
        tool_validation_requests_from_sub_agent = []
        memory_fragment_from_sub_agent = None
        for chunk in completion.execute_streamed():
            if isinstance(chunk, DSSLLMStreamedCompletionFooter):
                footer_data = getattr(chunk, "data", {}) or {}
                additional_info = footer_data.get("additionalInformation", {}) or {}
                footer_sources = additional_info.get("sources", {}) or footer_data.get("sources", {})
                trace = footer_data.get("trace")
                if "contextUpsert" in footer_data:
                    logger.info("Found a context upsert in tool call: %s", footer_data["contextUpsert"])
                    context_upsert = footer_data["contextUpsert"]

            elif chunk.type == "event":
                event_data = chunk.data.get("eventData") or {}
                if not isinstance(event_data, dict):
                    event_data = {}
                if self.pcb:
                    # Same display-name fallback as every other payload emitted here.
                    self.pcb({"eventKind": chunk.event_kind, "eventData": {**event_data, "agentName": display_name}})

            elif chunk.type == "content":
                if chunk.text:
                    buf.append(chunk.text)
                if "artifacts" in chunk.data:
                    artifacts = chunk.data["artifacts"]
                    self._relay_artifacts(artifacts, query)
                if "toolValidationRequests" in chunk.data:
                    tool_validation_requests_from_sub_agent.extend(chunk.data["toolValidationRequests"])
                    # Guarded like every other callback call: pcb is optional.
                    if self.pcb:
                        self.pcb(
                            {
                                "eventKind": EventKind.TOOL_VALIDATION_REQUESTS,
                                "eventData": {
                                    "agentName": display_name,
                                    "agentId": self.dss_agent_id,
                                    "requests": chunk.data["toolValidationRequests"],
                                },
                            },
                            store_event=False,
                        )
                if "memoryFragment" in chunk.data:
                    memory_fragment_from_sub_agent = chunk.data["memoryFragment"]
            else:
                logger.warning("Unknown chunk type: %s", chunk.type)

        final_text = (restored_partial_output or "") + "".join(buf)
        # Sources from the new footer override same-key sources stashed by a previous run.
        if restored_stashed_sources and footer_sources:
            all_sources = {**restored_stashed_sources, **footer_sources}
        else:
            all_sources = restored_stashed_sources or footer_sources

        # Wrap the sub-agent's fragment with its target so the next run can verify it.
        outgoing_memory_fragment: Dict[str, Any] = {
            "messages": [
                {
                    "role": "memoryFragment",
                    "memoryFragment": memory_fragment_from_sub_agent,
                    "memoryFragmentTarget": {"agentName": display_name, "agentId": self.dss_agent_id},
                }
            ],
        }
        if all_sources:
            outgoing_memory_fragment["stashedSources"] = all_sources
        return {
            "output": {"result": final_text},
            "sources": all_sources,
            "artifacts": artifacts,
            "trace": trace,
            "toolValidationRequests": tool_validation_requests_from_sub_agent,
            "memoryFragment": outgoing_memory_fragment,
            "context_upsert": context_upsert,
        }


class DSSToolAgent(BaseModel):
    # Configuration of one DSS agent exposed to the orchestrating LLM as a tool.
    # (Documented with comments: a class docstring would end up in the JSON schema.)

    # "<project_key>:<agent id>", split on the first ":" when the tool is built.
    dss_agent_id: str = Field(default=..., description="Identifier for the DSS LLM agent")
    # Optional system message sent ahead of the forwarded query.
    agent_system_instructions: Optional[str] = Field(
        default=None,
        description="System prompt to prepend when querying the agent",
    )
    # Becomes the tool's function description for the caller LLM.
    tool_agent_description: str = Field(
        default=...,
        description="Description shown to the caller LLM so it knows WHEN to use this tool",
    )
    # Display name; the generated tool name is used when this is missing.
    agent_name: Optional[str] = Field(default=None, description="Friendly name for the agent")


class _Query(BaseModel):
    # Argument model of the agent-as-tool function; its JSON schema is the tool's
    # "parameters" (kept docstring-free so the schema stays unchanged).
    query: str = Field(
        default=...,
        description="User query forwarded to the DSS agent",
    )


def build_dss_agent_tool(
    agent: DSSToolAgent,
    context: dict,
    artifacts_meta: dict,
    pcb=None,
) -> LLMTool:
    """Wrap a DSS agent as an ``LLMTool`` that the orchestrating LLM can call.

    The tool name is ``"ah_"`` plus ``agent.dss_agent_id`` with every ``:`` turned
    into ``_``, truncated to 64 characters. The tool takes a single ``query``
    string parameter, whose schema comes from ``_Query``.

    :param agent: the DSS agent to expose.
    :param context: accepted for interface compatibility; not used here.
    :param artifacts_meta: handed to the wrapping ``AHDSSAgentTool``.
    :param pcb: optional progress callback handed to the wrapping ``AHDSSAgentTool``.
    :returns: the configured ``LLMTool``.
    """
    agent_id = agent.dss_agent_id
    # 64-char cap, presumably the function-name limit of the LLM tool API (TODO confirm).
    name = ("ah_" + agent_id.replace(":", "_"))[:64]
    project_key, tool_id = agent_id.split(":", 1)
    display_name = agent.agent_name or name

    wrapped_agent = AHDSSAgentTool(
        client=dataiku.api_client(),
        project_key=project_key,
        tool_id=tool_id,
        descriptor="",
        dss_agent_id=agent_id,
        agent_system_instructions=agent.agent_system_instructions,
        agent_name=agent.agent_name,
        tool_name=name,
        artifacts_meta=artifacts_meta,
        pcb=pcb,
    )

    function_spec = {
        "name": name,
        "description": agent.tool_agent_description,
        "parameters": _Query.model_json_schema(),
    }
    return LLMTool(
        llm_tool_name=name,
        llm_descriptor={"type": "function", "function": function_spec},
        dku_tool=wrapped_agent,
        dku_tool_ref=agent_id,
        dku_tool_type="agent",
        dku_tool_name=display_name,
        dku_subtool_name=None,
        require_validation=False,
        allow_editing_inputs=False,
    )


class AgentConnect(ToolsUsingAgent):
    """Agent Hub orchestrator that exposes configured DSS agents as LLM tools.

    Each ``DSSToolAgent`` in ``agents`` is turned into an ``LLMTool`` by
    ``load_tools``. Other entry types (e.g. ``UsedTool``) are not supported yet.
    """

    def __init__(
        self,
        base_model: str,
        agents: List[Union[DSSToolAgent, UsedTool]] | None = None,
        pcb=None,
    ):
        """
        :param base_model: stored as ``self.base_model``.
        :param agents: sub-agents to expose as tools; defaults to an empty list.
        :param pcb: optional progress callback forwarded to the sub-agent tools.
        """
        super().__init__()
        self.agents = agents or []
        self.base_model = base_model
        self.project = dataiku.api_client().get_default_project()
        # Baseline orchestrator configuration.
        self.config = {
            "dkuProperties": [],
            "llmId": get_default_llm_id(),
            "agentId": "AgentHubID",
            "agentName": "Agent Hub",
        }
        self.pcb = pcb

    def set_config(self, config, pcb):
        """Apply ``config`` through the parent and keep ``pcb`` on this instance.

        The parent receives ``None`` as its second argument; the callback is stored
        here and forwarded to the tools built by ``load_tools``.
        """
        super().set_config(config, None)
        self.pcb = pcb

    def load_tools(self, messages, context, tools_cache: Dict[str, Any]) -> Dict[str, LLMTool]:
        """Build the LLM tools for ``self.agents``, keyed by tool name.

        ``messages`` and ``tools_cache`` are accepted for interface compatibility and
        are not used here.

        :raises NotImplementedError: for entries that are not ``DSSToolAgent``.
        """
        # artifacts_meta is normally set by process_stream(); fall back to an empty
        # mapping so calling load_tools() first does not raise AttributeError.
        artifacts_meta = getattr(self, "artifacts_meta", {})
        llm_tools: Dict[str, LLMTool] = {}
        for item in self.agents:
            if not isinstance(item, DSSToolAgent):
                raise NotImplementedError(f"Loading tools of type {type(item)} is not supported yet.")
            # Must match the tool name computed in build_dss_agent_tool.
            tool_name = f"ah_{item.dss_agent_id.replace(':', '_')}"[:64]
            llm_tools[tool_name] = build_dss_agent_tool(item, context, artifacts_meta, self.pcb)
        return llm_tools

    def process_stream(self, query, settings, artifacts_meta, trace, pcb):
        """Record the per-request callback and artifacts metadata, then stream via the parent."""
        # NOTE: the main agent's tool validation responses and memory fragments must
        # reach the sub-agent. Completion part order is important: history first, then
        # memory fragments, then tool validation requests, then tool validation
        # responses (ignore any assistant message when sending the completion back).
        # The input field should stay disabled until all validation responses are received.
        self.pcb = pcb
        self.artifacts_meta = artifacts_meta
        return super().process_stream(query, settings, trace)
