import json
import logging
import time
from typing import List

from dataiku.llm.python.tools_using_2 import  LLMTool
from dataiku.llm.python.types import AgentHierarchyEntry, ChatMessage, ToolHierarchyEntry
from dataikuapi.dss.llm_tracing import SpanBuilder
from . import StreamableLLMBlock


logger = logging.getLogger("dku.agents.blocks_graph")

class SingleToolCallHandler(StreamableLLMBlock):
    """Base block that performs exactly one tool call and stores its result.

    Subclasses implement :meth:`process_stream`. The shared machinery lives in
    :meth:`_call_tool`. It runs the tool, streams UI event/artifact chunks,
    records sources and traces, and saves the tool output according to the
    block's ``outputMode`` configuration.
    """

    def __init__(self, turn, sequence_context, block_config):
        super().__init__(turn, sequence_context, block_config)
        # Reported in the agent hierarchy entry; this block issues a single call,
        # so the iteration number stays at 1.
        self.iteration_number = 1

    def process_stream(self, trace: SpanBuilder):
        """Stream this block's output; must be implemented by subclasses."""
        raise NotImplementedError

    def _call_tool(self, tool_call, llm_tool: LLMTool, parent_trace: SpanBuilder):
        """Run ``llm_tool`` for ``tool_call`` and save its output, yielding stream chunks.

        This is a generator. It yields an ``AGENT_TOOL_START`` event chunk before
        the tool runs, then a content chunk carrying the tool's artifacts.

        Side effects:
            * appends the tool's sources to ``sequence_context.sources``;
            * appends the tool trace to ``parent_trace``;
            * stores the output according to ``outputMode``: turn state,
              scratchpad, generated messages (assistant request + tool output),
              or nowhere.

        Raises:
            Exception: if the tool requires human validation, either before or
                during the call (not supported by this block).
            ValueError: if ``outputMode`` is not a supported mode.
            KeyError: if the state/scratchpad key required by ``outputMode`` is
                missing from the block config.
        """
        logger.info("Calling 1 tool from single tool call block")

        # Hierarchy entry identifying the agent that issued the call
        agent_level_hierarchy_entry: AgentHierarchyEntry = {
            "type": "AGENT",
            "agentLoopIteration": self.iteration_number,
            "agentId": self.agent.config.get("agentId", ""),
            "agentName": self.agent.config.get("agentName", ""),
        }

        # Check for pre-tool-call HITL
        if llm_tool.require_validation:
            # TODO @structured-visual-agents Implement HITL here
            raise Exception("Tool call requires human validation. Currently not supported by tool call blocks. Please use a core loop block instead.")

        # Validate the output configuration BEFORE running the tool, so that a
        # misconfigured block fails without executing a (possibly side-effecting) tool.
        output_mode, output_key = self._resolve_output_target()

        # Add Assistant Message (Request) to history
        if output_mode == "ADD_TO_MESSAGES":
            tool_calls_msg: ChatMessage = {
                "role": "assistant",
                "toolCalls": [tool_call.dku_tool_call],
            }
            self.sequence_context.generated_messages.append(tool_calls_msg)

        # Notify UI: Tool Started
        yield {"chunk": {"type": "event", "eventKind": "AGENT_TOOL_START", "eventData": {"toolName": tool_call.name}}}

        output_dict, parts, sources, artifacts, tool_validation_requests, _memory_fragment, tool_trace = self._do_run_tool(
            llm_tool, tool_call, self.turn.current_merged_context
        )

        # Check for within-tool-call HITL
        if tool_validation_requests:
            # TODO @structured-visual-agents Implement HITL here
            raise Exception("Tool call requires human validation. Currently not supported by tool call blocks. Please use a core loop block instead.")

        tool_level_hierarchy_entry: ToolHierarchyEntry = {
            "type": "TOOL",
            "toolRef": llm_tool.dku_tool_ref,
            "toolName": llm_tool.dku_tool_name,
            "toolCallId": tool_call.id,
        }

        self._prepend_hierarchy(artifacts, agent_level_hierarchy_entry, tool_level_hierarchy_entry)
        yield {"chunk": {"type": "content", "artifacts": artifacts}}

        self._prepend_hierarchy(sources, agent_level_hierarchy_entry, tool_level_hierarchy_entry)
        # Store sources, to emit later in the streaming footer
        logger.info("Storing sources %s", sources)
        self.sequence_context.sources.extend(sources)

        if tool_trace:
            parent_trace.append_trace(tool_trace)

        self._store_output(output_mode, output_key, tool_call, output_dict, parts)

    def _resolve_output_target(self):
        """Return ``(output_mode, output_key)`` for this block, validating the config.

        ``output_key`` is the state/scratchpad key for the SAVE_* modes and
        ``None`` otherwise.

        Raises:
            KeyError: if the key required by a SAVE_* mode is missing.
            ValueError: if ``outputMode`` is missing or unsupported.
        """
        output_mode = self.block_config.get("outputMode", None)
        if output_mode == "SAVE_TO_STATE":
            return output_mode, self.block_config["outputStateKey"]
        if output_mode == "SAVE_TO_SCRATCHPAD":
            return output_mode, self.block_config["outputScratchpadKey"]
        if output_mode in ("NONE", "ADD_TO_MESSAGES"):
            return output_mode, None
        raise ValueError(f"Unsupported output mode for Manual tool call block: {output_mode}")

    @staticmethod
    def _prepend_hierarchy(items, agent_entry, tool_entry):
        """Prefix each item's ``hierarchy`` list with the agent then the tool entry.

        Each item receives its own copies of the entries so that later mutation
        of one item's hierarchy cannot leak into the others. A missing or
        ``None`` hierarchy is replaced by a fresh list.
        """
        for item in items:
            hierarchy: List = item.get("hierarchy")
            if hierarchy is None:
                hierarchy = item["hierarchy"] = []
            # Result: [agent, tool, *previous entries]
            hierarchy[0:0] = [dict(agent_entry), dict(tool_entry)]

    def _store_output(self, output_mode, output_key, tool_call, output_dict, parts):
        """Save the tool output according to an already-validated output mode."""
        # TODO at some point we should be able to handle parts, sources or artifacts
        if output_mode == "SAVE_TO_STATE":
            self.turn.state_set(output_key, output_dict)
            logger.info("Saved tool call block output to state key '%s'", output_key)
        elif output_mode == "SAVE_TO_SCRATCHPAD":
            self.sequence_context.scratchpad[output_key] = output_dict
            logger.info("Saved tool call block output to scratchpad key '%s'", output_key)
        elif output_mode == "ADD_TO_MESSAGES":
            tool_outputs_msg: ChatMessage = {
                "role": "tool",
                "toolOutputs": [{
                    "callId": tool_call.id,
                    "output": json.dumps(output_dict),
                    "parts": parts,
                }],
            }
            self.sequence_context.generated_messages.append(tool_outputs_msg)
        # "NONE": the output is deliberately not stored anywhere.

    def _do_run_tool(self, llm_tool: LLMTool, call, context):
        """Execute a single tool call, never raising for tool failures.

        Returns:
            A 7-tuple ``(output_dict, parts, sources, artifacts,
            tool_validation_requests, memory_fragment, trace)``. On failure,
            ``output_dict`` is ``{"error": <message>}`` and the other payload
            fields keep their empty defaults (``[]`` / ``None``). ``trace`` is a
            ``DKU_MANAGED_TOOL_CALL`` span that is always begun and ended.
        """
        artifacts = []
        sources = []
        output_dict = {}
        parts = []
        tool_validation_requests = None
        memory_fragment = None
        trace = SpanBuilder("DKU_MANAGED_TOOL_CALL")
        trace.begin(int(time.time() * 1000))

        try:
            logger.info("Invoking tool %s", call)

            tool_output = llm_tool.dku_tool.run(
                input=call.arguments,
                context=context,
                subtool_name=llm_tool.dku_subtool_name,
            )

            # A tool-reported error is handled exactly like a raised one (below)
            if tool_output.get("error"):
                raise Exception(tool_output["error"])

            logger.info("Tool output: %s", tool_output)

            # Unpack results
            output_dict = tool_output.get("output", {})
            sources = tool_output.get("sources", [])
            artifacts = tool_output.get("artifacts", [])
            parts = tool_output.get("parts", [])
            tool_validation_requests = tool_output.get("toolValidationRequests")
            memory_fragment = tool_output.get("memoryFragment")

            if "trace" in tool_output:
                trace.append_trace(tool_output["trace"])

        except Exception as e:
            logger.exception("Tool call failed: %s", call)
            output_dict = {"error": str(e)}

            # Best-effort: enrich the trace with error details; never let this mask the tool error
            try:
                trace.inputs["input"] = call.arguments
                trace.outputs["error"] = str(e)
                if llm_tool is not None:
                    trace.attributes["toolId"] = llm_tool.dku_tool.tool_id
                    trace.attributes["toolProjectKey"] = llm_tool.dku_tool.project_key
                    trace.attributes["toolType"] = llm_tool.dku_tool.get_settings()._settings.get("type", "unknown")
            except Exception:
                logger.warning("Failed to append error details to tool trace", exc_info=True)

        finally:
            trace.end(int(time.time() * 1000))

        return output_dict, parts, sources, artifacts, tool_validation_requests, memory_fragment, trace