import copy
import hashlib
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import List, Dict, Collection, Optional, Union

from dataiku.generated_sources.com.dataiku.dip.llm.online.llm_client.tool_output import ToolOutput
from dataiku.generated_sources.com.dataiku.dip.llm.online.llm_client.function_tool_call import FunctionToolCall
from dataiku.llm.python.tools_using_2 import _tool_calls_from_chunks, BaseLLMTool, LLMTool
from dataiku.llm.python.types import AgentHierarchyEntry, ChatMessage, ToolHierarchyEntry, ToolValidationRequest
from dataikuapi.dss.agent_tool import DSSAgentTool
from dataikuapi.dss.llm import DSSLLMStreamedCompletionFooter
from dataikuapi.dss.llm_tracing import SpanBuilder
from . import StreamableLLMBlock
from .. import NextBlock
from ..prompts import DEFAULT_REACT_SYSTEM_PROMPT
from ..types import ToolCallWithPotentialValidation, ToolCallValidationInfo
from ..utils import _validate_and_parse_tool_call, interpolate_cel, default_if_blank, tool_has_been_called
from ...utils import get_completion_query_safe_for_logging
from ..virtual_tools import GET_STATE_NAME, GET_STATE_DESCRIPTOR, SET_STATE_NAME, SET_STATE_DESCRIPTOR, GET_SCRATCHPAD_DESCRIPTOR, SET_SCRATCHPAD_DESCRIPTOR, \
    GET_SCRATCHPAD_NAME, SET_SCRATCHPAD_NAME

logger = logging.getLogger("dku.agents.blocks_graph")


@dataclass
class LLMSubAgent(BaseLLMTool):
    """An LLM-callable tool that forwards a question to another agent.

    Instances are built by ``ReactBlockHandler._load_sub_agent_tool`` from a ``SUB_AGENT``
    entry of the block's ``tools`` configuration. The inherited ``llm_tool_name`` and
    ``llm_descriptor`` describe the function exposed to the LLM.
    """

    # Raw SUB_AGENT tool configuration (e.g. agentId, additionalDescription,
    # systemPromptPrepend, forwardContext), kept to route the call to the right agent
    sub_agent_tool_ref: Dict


# Type of the values returned by ReactBlockHandler.load_tools
LLMToolOrSubAgent = Union[LLMTool, LLMSubAgent]


class ReactBlockHandler(StreamableLLMBlock):

    def __init__(self, turn, sequence_context, block_config):
        """Create the handler for one ReAct block within a turn.

        Args:
            turn: the current agent turn (gives access to the agent, context and state).
            sequence_context: shared context of the running block sequence.
            block_config: this block's configuration dict.
        """
        super().__init__(turn, sequence_context, block_config)
        # Messages generated by this block during the current turn; merged into the
        # sequence context's generated messages when the loop exits
        self.block_turn_generated_messages: List[ChatMessage] = list()
        # ReAct loop counter, incremented at the start of every iteration
        self.iteration_number: int = 0
        # Evaluated forced tool arguments, lazily filled by _get_all_evaluated_forced_args
        self.forced_args_cache: Optional[Dict] = None

    def load_tools(self) -> Dict[str, LLMToolOrSubAgent]:
        """Resolve the block's configured tools into the tools offered to the LLM.

        ``EXPLICIT_TOOL`` entries are loaded through the agent (each may expose several
        LLM tools), ``SUB_AGENT`` entries are wrapped into ``LLMSubAgent`` objects, and any
        other type is logged and ignored.

        Returns:
            A dict keyed by LLM-facing tool name. Explicit tools are inserted first, then
            sub-agents; on a name clash the entry inserted last wins.
        """
        explicit_tools = []
        sub_agents = []
        for tool_config in self.block_config["tools"]:
            tool_type = tool_config.get("type")
            if tool_type == "EXPLICIT_TOOL":
                explicit_tools.append(self.turn.agent.load_or_get_tool(tool_config))
            elif tool_type == "SUB_AGENT":
                sub_agents.append(self._load_sub_agent_tool(tool_config))
            else:
                logger.warning(f"Unknown tool type: {tool_type}")

        tools_by_name: Dict[str, LLMToolOrSubAgent] = {
            llm_tool.llm_tool_name: llm_tool
            for prepared_tool in explicit_tools
            for llm_tool in prepared_tool.llm_tools
        }
        tools_by_name.update((sub_agent.llm_tool_name, sub_agent) for sub_agent in sub_agents)

        # TODO @power-agents Maybe add state-aware and scratchpad-aware tools here?

        return tools_by_name

    def _load_sub_agent_tool(self, sub_agent_tool_ref: Dict):
        """Wrap a SUB_AGENT tool configuration into an ``LLMSubAgent``.

        The exposed function takes a single required string argument, ``question``.

        Args:
            sub_agent_tool_ref: the SUB_AGENT entry of the block's ``tools`` configuration.

        Returns:
            The ``LLMSubAgent`` describing the function offered to the LLM.
        """
        # TODO @power-agents Maybe validate all params first, before making the tool object

        # A short digest of the whole (key-sorted) config keeps tool names distinct when
        # several copies of the same agent are used in the same block
        serialized_ref = json.dumps(sub_agent_tool_ref, sort_keys=True)
        config_digest = hashlib.sha256(serialized_ref.encode("utf-8")).hexdigest()
        tool_name = f"call_sub_agent_{config_digest[:6]}"  # TODO @power-agents Maybe add the name of the sub-agent here, to aid debugging?

        description_parts = ["Asks a question to an agent."]
        if extra_description := sub_agent_tool_ref.get("additionalDescription"):
            description_parts.append(extra_description)

        question_schema = {
            "type": "string",
            "description": "the question to ask",
        }
        function_spec = {
            "name": tool_name,
            "description": "\n\n".join(description_parts),
            "parameters": {
                "type": "object",
                "properties": {"question": question_schema},
                "required": ["question"],
            },
        }
        return LLMSubAgent(
            llm_tool_name=tool_name,
            llm_descriptor={"type": "function", "function": function_spec},
            sub_agent_tool_ref=sub_agent_tool_ref,
        )

    def _call_sub_agent(self, sub_agent_tool_ref: Dict, call_arguments: Dict):
        """Ask a question to a sub-agent and collect its streamed answer.

        Args:
            sub_agent_tool_ref: the SUB_AGENT tool configuration (``agentId`` plus optional
                ``systemPromptPrepend`` and ``forwardContext``, which defaults to True).
            call_arguments: parsed arguments of the LLM's tool call; must contain a non-empty
                ``question``.

        Returns:
            The 7-tuple ``(output_dict, parts, sources, artifacts, tool_validation_requests,
            memory_fragment, trace)``, the same shape ``_call_tools`` unpacks from tool results.
            ``output_dict`` is ``{"response": text}`` on success and ``{"error": message}`` on
            failure. ``parts`` is always empty. ``tool_validation_requests`` and
            ``memory_fragment`` are always None, because human validations are not supported
            for sub-agents. Sources and artifacts received before a failure are still returned.

        Raises:
            KeyError: if ``sub_agent_tool_ref`` has no ``agentId``. Every other failure is
                reported through ``output_dict["error"]``.
        """
        subagent_trace = SpanBuilder("DKU_AGENT_SUBAGENT_CALL")
        subagent_trace.begin(int(time.time() * 1000))
        output_dict = {}
        parts = []
        all_sources = []
        all_artifacts = []

        subagent_trace.attributes["agentId"] = sub_agent_tool_ref["agentId"]
        subagent_trace.inputs["input"] = call_arguments

        try:
            llm = self.agent.project.get_llm("agent:" + sub_agent_tool_ref["agentId"])
            completion = llm.new_completion()

            # TODO @power-agents If restoring from HITL, then restore messages here instead of adding the below messages

            if system_prompt_prepend := sub_agent_tool_ref.get("systemPromptPrepend"):
                completion.cq["messages"].append({
                    "role": "system",
                    "content": system_prompt_prepend
                })

            message_from_agent = call_arguments.get("question")
            if not message_from_agent:
                raise ValueError("No question received: please provide a question for the agent to answer.")
            completion.cq["messages"].append({
                "role": "user",
                "content": message_from_agent
            })

            if sub_agent_tool_ref.get("forwardContext", True):
                completion.cq["context"] = self.turn.current_merged_context

            logger.debug("About to run completion: %s", get_completion_query_safe_for_logging(completion.cq))
            logger.debug("With settings: %s", completion.settings)

            accumulated_text_output = ""
            # Kept for future HITL support (see TODO below); not returned for now
            last_memory_fragment = None
            # Kept separate from the returned value: if the stream fails after such a chunk,
            # the except branch must not surface these as pending validations
            received_validation_requests = None
            for chunk in completion.execute_streamed():
                # `or` also covers keys that are present with a null value
                all_artifacts.extend(chunk.data.get("artifacts") or [])
                additional_information = chunk.data.get("additionalInformation") or {}
                all_sources.extend(additional_information.get("sources") or [])

                if isinstance(chunk, DSSLLMStreamedCompletionFooter):
                    subagent_trace.append_trace(chunk.trace)
                    # Context upserts from the sub-agent are ignored: they are unrelated to the top-level conversation
                elif chunk.text is not None:
                    accumulated_text_output += chunk.text
                elif chunk.type == "event":
                    pass  # Ignoring events coming from the sub-agent  # TODO @power-agents we could potentially send these? Maybe with a hierarchy?
                elif chunk.type == "content":
                    if "memoryFragment" in chunk.data:
                        # Memory fragments are only useful for conversations, and a sub-agent call is not one;
                        # only the one accompanying a HITL request would matter
                        last_memory_fragment = chunk.data["memoryFragment"]
                    elif "toolValidationRequests" in chunk.data:
                        received_validation_requests = chunk.data["toolValidationRequests"]
                else:
                    logger.warning("Unknown chunk type from delegated agent: %s, ignoring it", chunk)

            output_dict["response"] = accumulated_text_output

            if received_validation_requests:
                # TODO @power-agents Add support for HITL for sub-agent calls (wrap last_memory_fragment here, with the accumulated text and sources)
                raise Exception("Sub agent returned a tool validation request. Human tool validations are not currently supported for sub-agents.")

        except Exception as e:
            logger.exception("Sub agent tool call failed: %s", sub_agent_tool_ref)
            output_dict = {"error": str(e)}

        subagent_trace.outputs["output"] = output_dict
        subagent_trace.end(int(time.time() * 1000))

        # No tool validation requests / memory fragment: sub-agent HITL is not supported
        return output_dict, parts, all_sources, all_artifacts, None, None, subagent_trace

    def process_stream(self, trace: SpanBuilder):
        """Run this block's ReAct loop, streaming its output and finally yielding the next block.

        Each iteration builds a completion and streams the LLM response. The completion contains:
        - the system prompts;
        - optionally the conversation history;
        - the messages generated so far by this block;
        - the block's tools, plus the optional state/scratchpad virtual tools.
        The iteration then stores the text according to ``outputMode`` and runs any requested
        tool calls. The loop stops when:
        - the LLM answers without tool calls;
        - an exit condition matches after tool calls;
        - a tool call needs human validation.

        ``self.iteration_number`` is not reset here, so ``maxLoopIterations`` counts every
        iteration run by this handler instance.

        Args:
            trace: parent span; one ``DKU_AGENT_REACT_ITERATION`` sub-span is opened per iteration.

        Yields:
            ``{"chunk": {...}}`` dicts to stream (events, text, artifacts), then a ``NextBlock`` when
            the loop exits (its ``id`` may be None when there is no next block). No ``NextBlock`` is
            yielded when tools require validation.

        Raises:
            Exception: if more than ``maxLoopIterations`` (default 25) iterations are run.
            ValueError: if ``outputMode`` is not NONE, ADD_TO_MESSAGES, SAVE_TO_STATE or SAVE_TO_SCRATCHPAD.
        """
        logger.info("Standard React block starting with config %s", self.block_config)

        max_loop_iterations = self.block_config.get("maxLoopIterations", 25)
        # "streamOutput" is optional: read it once with a default. The per-chunk check used to index
        # block_config directly and raised KeyError when the key was absent.
        stream_output = self.block_config.get("streamOutput", False)

        self.block_turn_generated_messages = []

        while True:
            self.iteration_number += 1
            if self.iteration_number > max_loop_iterations:
                raise Exception(f"React block exceeded max number of loop iterations ({max_loop_iterations})")

            with trace.subspan("DKU_AGENT_REACT_ITERATION") as iteration_trace:
                logger.info("Starting react block iteration: %s", self.iteration_number)
                iteration_trace.attributes["iterationNumber"] = self.iteration_number

                yield {"chunk": {"type": "event", "eventKind": "AGENT_THINKING", "eventData": {}}}

                completion = self.new_completion()
                cel_engine = self.standard_cel_engine()
                tools_by_name = self.load_tools()

                completion.with_context(self.turn.current_merged_context)

                # Deep copy so that adding "tools" below cannot modify block_config
                # (presumably completion.settings exposes _settings — confirm)
                completion._settings = copy.deepcopy(self.block_config.get("completionSettings", {}))
                completion.settings["tools"] = [llm_tool.llm_descriptor for llm_tool in tools_by_name.values()]
                if self.block_config.get("stateAware", False):
                    completion.settings["tools"].append(GET_STATE_DESCRIPTOR)
                    completion.settings["tools"].append(SET_STATE_DESCRIPTOR)
                if self.block_config.get("scratchpadAware", False):
                    completion.settings["tools"].append(GET_SCRATCHPAD_DESCRIPTOR)
                    completion.settings["tools"].append(SET_SCRATCHPAD_DESCRIPTOR)

                # Prompts and history
                self._interpolate_and_add_message(completion, "system", cel_engine,
                                                  self.block_config.get("systemPromptBeforeHistory"))

                if self.block_config.get("passConversationHistory", True):
                    completion.cq["messages"].extend(self.turn.initial_messages)
                    completion.cq["messages"].extend(self.sequence_context.generated_messages)

                # Even without history, the messages generated during this turn of the block must go in
                completion.cq["messages"].extend(self.block_turn_generated_messages)

                self._interpolate_and_add_message(completion, "system", cel_engine,
                                                  self.block_config.get("systemPromptAfterHistory", DEFAULT_REACT_SYSTEM_PROMPT))
                # Forced tool args are both announced to the LLM here and enforced before calling the tool
                if forced_args_msg := self._generate_forced_args_instructions(tools_by_name):
                    completion.with_message(forced_args_msg, "system")

                self._interpolate_and_add_message(completion, "user", cel_engine,
                                                  self.block_config.get("additionalUserMessage"))

                # Let's go
                with iteration_trace.subspan("DKU_AGENT_LLM_CALL") as llm_trace:
                    accumulated_tool_call_chunks = []
                    accumulated_text_output = ""

                    # TODO @lavish-agents: Add relevant context
                    # TODO: Emit events (AGENT_THINKING / AGENT_TOOL_CALL / ... )

                    for ichunk in self._run_completion(completion, llm_trace):
                        if ichunk.text is not None:
                            if stream_output:
                                yield {"chunk":  {"text": ichunk.text}}
                            accumulated_text_output += ichunk.text

                        if ichunk.memory_fragment:
                            memory_fragment_msg: ChatMessage = {
                                "role": "memoryFragment",
                                "memoryFragment": ichunk.memory_fragment
                            }
                            self.block_turn_generated_messages.append(memory_fragment_msg)

                        if ichunk.artifacts:
                            artifacts = ichunk.artifacts
                            for artifact in artifacts:
                                hierarchy: List = artifact.setdefault("hierarchy", [])
                                hierarchy.insert(0, {"type": "AGENT", "agentLoopIteration": self.iteration_number}) # TODO @lavish-agents ?, "agentId": agent_id, "agentName": agent_name})
                            yield {"chunk": {"artifacts": artifacts}}

                        if ichunk.sources:
                            sources = ichunk.sources
                            for source in sources:
                                hierarchy: List = source.setdefault("hierarchy", [])
                                hierarchy.insert(0, {"type": "AGENT", "agentLoopIteration": self.iteration_number}) # TODO @lavish-agents ? , "agentId": agent_id, "agentName": agent_name})
                            self.sequence_context.sources.extend(sources)

                        if ichunk.tool_call_chunks:
                            accumulated_tool_call_chunks.extend(ichunk.tool_call_chunks)

                    if stream_output and self.agent.config.get("newLineAfterBlockOutput", True):
                        yield {"chunk":  {"text": "\n"}}

                # Reassemble tool calls, then enforce forced arguments on them
                tool_calls = _tool_calls_from_chunks(accumulated_tool_call_chunks)
                self._upsert_forced_tool_calls_args(tool_calls, tools_by_name)
                logger.info("Gathered tool calls: %s", tool_calls)

                self.sequence_context.last_text_output = accumulated_text_output
                logger.info("Response stored as last text output")

                # Save the output as requested.
                # Note: although it looks very close to StreamableLLMBlock._handle_output_mode_at_end,
                # it differs because, inside the loop, we don't directly append to the sequence context
                output_mode = self.block_config.get("outputMode", None)
                if output_mode == 'NONE':
                    pass  # Do not store the output anywhere, just keep it in last_text_output
                elif output_mode == "ADD_TO_MESSAGES":
                    # Don't add empty assistant messages for iterations in which the LLM outputs no text,
                    # nor a separate one when the text will be part of the tool calls message below
                    if accumulated_text_output != "" and not tool_calls:
                        assistant_message: ChatMessage = {
                            "role": "assistant",
                            "content": accumulated_text_output
                        }
                        self.block_turn_generated_messages.append(assistant_message)
                        logger.info("Response added to generated messages")
                elif output_mode == "SAVE_TO_STATE":
                    output_key = self.block_config["outputStateKey"]
                    self.turn.state_set(output_key, accumulated_text_output)
                    logger.info("Saved React block output to state key '%s'", output_key)
                elif output_mode == "SAVE_TO_SCRATCHPAD":
                    output_key = self.block_config["outputScratchpadKey"]
                    self.sequence_context.scratchpad[output_key] = accumulated_text_output
                    logger.info("Saved React block output to scratchpad key '%s'", output_key)
                else:
                    raise ValueError("Unsupported output mode for React block: %s" % output_mode)

                if len(tool_calls) > 0:
                    with iteration_trace.subspan("DKU_AGENT_TOOL_CALLS") as tools_trace:
                        tool_calls = [_validate_and_parse_tool_call(tc) for tc in tool_calls]

                        # Assistant message with the tool calls, added to the history for the next iteration
                        tool_calls_msg: ChatMessage = {
                            "role": "assistant",
                            "toolCalls": [tc.dku_tool_call for tc in tool_calls]
                        }
                        if accumulated_text_output:
                            tool_calls_msg["content"] = accumulated_text_output
                        self.block_turn_generated_messages.append(tool_calls_msg)

                        tools_require_validation = yield from self._call_tools(tool_calls, tools_by_name, tools_trace, None, None)
                        if tools_require_validation:
                            return

                    # The tools may have updated the state, so the exit conditions are evaluated now
                    with iteration_trace.subspan("DKU_EVALUATE_BLOCK_CONDITIONS") as conditions_trace:
                        logger.info("Evaluating exit conditions after tool calls")
                        # _get_next_block accounts for both sequence_context.generated_messages and
                        # block_turn_generated_messages when evaluating "tool called" exit conditions
                        next_block = self._get_next_block()
                        logger.info("After tool calls, next block is %s", next_block)
                        conditions_trace.attributes["evaluatedNextBlock"] = next_block

                        if next_block is None:
                            conditions_trace.attributes["evaluationOutcome"] = "STAY"
                            logger.info("Continuing the React loop")
                        else:
                            logger.info("Exiting the React loop to go to block %s", next_block)
                            self.sequence_context.generated_messages.extend(self.block_turn_generated_messages)
                            yield NextBlock(id=next_block)
                            return

                else:
                    logger.info("No tool to call, React loop is done")

                    logger.info("React block is over, going to default next block if any")

                    # Exit conditions must still be evaluated here, because this last iteration
                    # had no tool calls and therefore no evaluation above
                    with iteration_trace.subspan("DKU_EVALUATE_BLOCK_CONDITIONS") as conditions_trace:
                        logger.info("Evaluating exit conditions at end of React loop")
                        # _get_next_block accounts for both sequence_context.generated_messages and
                        # block_turn_generated_messages when evaluating "tool called" exit conditions
                        next_block = self._get_next_block()
                        logger.info("At end, next block is %s", next_block)

                        if next_block is None:
                            logger.info("No next block from conditions, looking at default")
                            next_block = self.block_config.get("defaultNextBlock")

                    logger.info("React block next block is %s", next_block)
                    if next_block is None:
                        logger.info("No next block, seems like we are done for the turn")

                    self.sequence_context.generated_messages.extend(self.block_turn_generated_messages)
                    yield NextBlock(id=next_block)
                    return


    # Lazy load evaluated forced args
    def _get_all_evaluated_forced_args(self):
        """Return the block's forced ("set") tool arguments, evaluated once and then cached.

        Forced arguments come from the ``setArgs`` entries of tool configs that have
        ``enableSetArgs`` set and carry a ``toolRef``. Each entry's ``value`` is a CEL
        expression evaluated with the block's standard CEL engine; a missing value defaults
        to the expression ``''``. Entries without a key are ignored.

        The cache is only stored once every expression has been evaluated successfully, so a
        failure is raised again on the next call instead of returning a partial result.

        Returns:
            A dict mapping ``"<toolRef>_<subtoolName>"`` (subtool name empty when absent) to
            ``{argument name: evaluated value}``.

        Raises:
            Exception: if a forced argument's CEL expression fails to evaluate.
        """
        if self.forced_args_cache is None:
            evaluated_by_tool = {}
            cel_engine = None
            for tool_config in self.block_config.get("tools", []):
                if not tool_config.get("enableSetArgs"):
                    continue
                set_args = tool_config.get("setArgs")
                if not set_args:
                    continue
                tool_ref = tool_config.get("toolRef")
                if not tool_ref:
                    # Only explicit tools carry a toolRef; forced args are never applied to other tool kinds
                    continue

                # Lazily build the CEL engine, only when there is something to evaluate
                if cel_engine is None:
                    cel_engine = self.standard_cel_engine()

                # Must match the key built by _get_evaluated_forced_args: a None subtool name maps to ''
                tool_overrides_key = tool_ref + '_' + (tool_config.get('subtoolName') or '')
                evaluated_args = {}
                for tool_arg in set_args:
                    arg_key = tool_arg.get("key")
                    if not arg_key:
                        continue
                    try:
                        evaluated_args[arg_key] = cel_engine.evaluate(tool_arg.get("value", "''"))
                    except Exception as e:
                        raise Exception("Error evaluating forced argument with key '%s': %s" % (arg_key, e)) from e
                evaluated_by_tool[tool_overrides_key] = evaluated_args

            self.forced_args_cache = evaluated_by_tool

        return self.forced_args_cache

    def _get_evaluated_forced_args(self, tool_ref, subtool_name = ''):
        """Return the evaluated forced arguments of one tool/subtool, or ``{}`` if it has none.

        Args:
            tool_ref: the tool's reference, as used in the block's ``toolRef`` config.
            subtool_name: the subtool name; None or empty selects the tool itself.
        """
        all_forced_args = self._get_all_evaluated_forced_args()
        lookup_key = tool_ref + '_' + (subtool_name or '')
        return all_forced_args.get(lookup_key, {})

    # Upsert forced tool calls args
    def _upsert_forced_tool_calls_args(self, tool_calls: list[FunctionToolCall], tools_by_name) -> list[FunctionToolCall]:
        """Overwrite, in place, tool call arguments with the block's forced arguments.

        Only function calls that target an ``LLMTool`` of ``tools_by_name`` are touched.
        Sub-agents and names missing from ``tools_by_name`` (e.g. virtual tools) are skipped.
        Forced values take precedence over the values generated by the LLM.

        Args:
            tool_calls: tool calls reassembled from the LLM stream; mutated in place.
            tools_by_name: the block's tools keyed by LLM-facing name (see ``load_tools``).

        Returns:
            The same ``tool_calls`` list.

        Raises:
            json.JSONDecodeError: if a call that must receive forced args has malformed JSON
                arguments. The call is not silently run without its forced values.
            ValueError: if such a call's arguments are valid JSON but not an object.
        """
        for tc in tool_calls:
            if tc.get("type") != 'function':
                continue

            function = tc.get('function') or {}
            function_name = function.get('name')
            if not function_name:
                continue

            tool = tools_by_name.get(function_name)
            if not isinstance(tool, LLMTool):
                continue  # sub-agents and virtual tools can't have input overrides

            forced_args = self._get_evaluated_forced_args(tool.dku_tool_ref, tool.dku_subtool_name)
            if not forced_args:
                continue

            original_arguments = function.get('arguments')
            # A call to a tool without parameters may come with empty arguments: forced args still apply
            updated_args = json.loads(original_arguments) if original_arguments else {}
            if not isinstance(updated_args, dict):
                raise ValueError("Cannot apply forced arguments to tool '%s': call arguments are not a JSON object: %s"
                                 % (function_name, original_arguments))
            updated_args.update(forced_args)
            function['arguments'] = json.dumps(updated_args)
            logger.info("Arguments override for tool: %s subtool: %s. Original args: %s. Updated args: %s",
                        tool.dku_tool_ref, tool.dku_subtool_name, original_arguments, function['arguments'])

        return tool_calls

    def _generate_forced_args_instructions(self, tools_by_name):
        """Build a system message listing the argument values some tools are forced to use.

        The message lets the LLM know about the values that ``_upsert_forced_tool_calls_args``
        will enforce anyway.

        Args:
            tools_by_name: the block's tools keyed by LLM-facing name (see ``load_tools``).

        Returns:
            The instructions text (a header line, then one ``- <tool>: <json>`` line per tool),
            or None when no tool of the block has forced arguments.
        """
        tools_instructions = []
        for tool in tools_by_name.values():
            if not isinstance(tool, LLMTool):
                continue  # sub-agents can't have input overrides
            if forced_args := self._get_evaluated_forced_args(tool.dku_tool_ref, tool.dku_subtool_name):
                tools_instructions.append("- " + tool.llm_tool_name + ": " + json.dumps(forced_args))

        if not tools_instructions:
            return None

        header = "Some available tools require specific input values if used. Here are the requirements by tool:"
        return '\n'.join([header, *tools_instructions])

    # Called by the turn if we enter the turn with pending tool calls for this block
    def play_pending_tool_calls(self, tool_calls: Collection[ToolCallWithPotentialValidation], parent_trace: SpanBuilder, validation_infos_map: Optional[dict[str, ToolCallValidationInfo]], tool_outputs_map: Optional[dict[str, ToolOutput]]):
        """Resume the tool calls left pending for this block when the turn was interrupted.

        The block's tools are reloaded, then everything is delegated to ``_call_tools``, whose
        chunks are re-yielded to the caller.

        Returns:
            The value returned by ``_call_tools``: truthy when tool calls still require validation.
        """
        logger.info("Playing pending tool calls: %s" % tool_calls)
        return (yield from self._call_tools(tool_calls, self.load_tools(), parent_trace, validation_infos_map, tool_outputs_map))

    def _call_tools(self, tool_calls: Collection[ToolCallWithPotentialValidation], current_tools: Dict[str, LLMToolOrSubAgent], parent_trace: SpanBuilder, validation_infos_map: Optional[dict[str, ToolCallValidationInfo]], tool_outputs_map: Optional[dict[str, ToolOutput]]):
        logger.info(f"I have {len(tool_calls)} tool(s) to call")
        parent_trace.attributes["nbToolCalls"] = len(tool_calls)

        agent_id = self.agent.config.get("agentId", "")
        agent_name = self.agent.config.get("agentName", "")
        agent_level_hierarchy_entry: AgentHierarchyEntry = {"type": "AGENT", "agentLoopIteration": self.iteration_number, "agentId": agent_id, "agentName": agent_name}

        own_validation_requests = []
        tool_calls_to_run = []

        if validation_infos_map is None:
            # 1) On a "normal" iteration of the agentic loop, we need to check tools that require validation and:
            # - run the tools that don't require it
            # - send validation requests for those that require it
            for tool_call in tool_calls:

                if tool_call.name not in current_tools:
                    # Tool not found -> probably a virtual tool. No concept of validation
                    tool_calls_to_run.append(tool_call)
                    continue

                llm_tool = current_tools[tool_call.name]
                if isinstance(llm_tool, LLMTool) and llm_tool.require_validation:
                    display_tool_name = llm_tool.dku_tool_name
                    if llm_tool.dku_subtool_name:
                        display_tool_name += f" / {llm_tool.dku_subtool_name}"
                    tool_call_description = "Do you want to execute this tool call?"
                    if (
                        llm_tool.dku_tool is not None
                        and (generated_tool_call_description := llm_tool.dku_tool.describe_tool_call(tool_call.arguments, llm_tool.llm_descriptor, context=self.turn.current_merged_context, subtool_name=llm_tool.dku_subtool_name))
                    ):
                        tool_call_description = generated_tool_call_description
                    validation_request: ToolValidationRequest = {
                        "id": f"validation-{tool_call.id}",
                        "blockId": self.block_config["id"],
                        "hierarchy": [agent_level_hierarchy_entry],
                        "message": tool_call_description,
                        "toolRef": llm_tool.dku_tool_ref,
                        "toolName": display_tool_name,
                        "toolType": llm_tool.dku_tool_type,
                        "toolDescription": llm_tool.llm_descriptor["function"].get("description") or "",
                        "toolInputSchema": llm_tool.llm_descriptor["function"].get("parameters") or {},
                        "allowEditingInputs": llm_tool.allow_editing_inputs,
                        "toolCall": tool_call.dku_tool_call,
                    }
                    own_validation_requests.append(validation_request)
                else:
                    tool_calls_to_run.append(tool_call)

            # tool_calls_to_run only contains tool calls that don't require validation
            validation_infos_map = {}
        else:
            # 2) On the first iteration of the agentic loop right after resuming after an interruption:
            # The tool calls that did not require validation have already been run before the interruption (and the results of these runs should be in tool_outputs_map)
            # => The list of tool_calls have already been curated when restoring the agentic loop state, and here it contains only those still pending after the interruption.
            #
            # These tool calls either:
            # - did not start running yet because they required validation before the interruption, and the validation request was already sent
            # - started running before the interruption but were interrupted mid-execution by a nested tool validation request: in this case the execution already started so it was either already approved or did not require validation
            # => In both cases, validation_infos_map is expected to already contain validation responses for all of them and there is no need to request validation for them at this point.
            tool_calls_to_run = tool_calls

        # Call the tools in parallel
        if tool_outputs_map is None:
            tool_outputs_map = {}

        num_parallel_threads = self.block_config.get("maxParallelToolExecutions", 2)
        if num_parallel_threads < 1:
            raise Exception(f"Parameter maxParallelToolExecutions must be > 0 (received: {num_parallel_threads})")
        with ThreadPoolExecutor(num_parallel_threads) as executor:
            futures = []

            for tool_call in tool_calls_to_run:
                yield {"chunk": {"type": "event", "eventKind": "AGENT_TOOL_START", "eventData": {"toolName": tool_call.name}}}
                futures.append(executor.submit(self._call_one_tool, tool_call, current_tools, validation_infos_map, self.turn.current_merged_context))

        inner_validation_requests = []
        inner_memory_fragments_wrapped = []
        for tool_call, future in zip(tool_calls_to_run, futures):
            output_dict, parts, sources, artifacts, tool_validation_requests, memory_fragment, tool_trace = future.result()

            # Update hierarchies
            if tool := current_tools.get(tool_call.name):
                if isinstance(tool, LLMSubAgent):
                    tool_ref = ""
                    tool_name = "Call sub-agent"
                else:
                    tool_ref = tool.dku_tool_ref
                    tool_name = tool.dku_tool_name
            else:
                # If the LLM made a bad tool call, just add "Unknown" into the hierarchy
                tool_ref = "Unknown"
                tool_name = "Unknown"
            tool_level_hierarchy_entry: ToolHierarchyEntry = {"type": "TOOL", "toolRef": tool_ref, "toolName": tool_name, "toolCallId": tool_call.id}

            for artifact in artifacts:
                hierarchy: List = artifact.setdefault("hierarchy", [])
                hierarchy.insert(0, tool_level_hierarchy_entry)
                hierarchy.insert(0, agent_level_hierarchy_entry)

            # Stream artifacts inline immediately
            yield {"chunk": {"type": "content", "artifacts": artifacts}}

            if tool_validation_requests:
                # tool returning validation requests may also return (partial) output_dict and sources but we must explicitly ignore them in this block
                # it is the responsibility of the tool to store partial outputs and sources in its memory fragment, to aggregate them when resuming, and to return the fully aggregated output and sources at the end
                # if we accounted for partial sources at the agent level, we'd be aggregating the same sources several times on each interruption of the tool run
                # if we accounted for the partial output_dict here, we wouldn't even know how to aggregate it on the next turn as this logic is known by the tool only
                #
                # artifacts are handled differently on purpose: they are streamed immediately, regardless of whether the tool is done or simply interrupted
                for tvr in tool_validation_requests:
                    if not tvr.get("hierarchy"):
                        tvr["hierarchy"] = []
                    tvr["hierarchy"].insert(0, tool_level_hierarchy_entry)
                    tvr["hierarchy"].insert(0, agent_level_hierarchy_entry)

                    tvr["blockId"] = self.block_config["id"]

                inner_validation_requests.extend(tool_validation_requests)

                if memory_fragment:
                    wrapped_memory_fragment: ChatMessage = {"role": "memoryFragment", "memoryFragment": memory_fragment, "memoryFragmentTarget": tool_level_hierarchy_entry}
                    inner_memory_fragments_wrapped.append(wrapped_memory_fragment)
            else:
                # tool run is truly over, we can now process its (full) output and sources
                # artifacts from the tool have already been streamed a few lines above

                # Add tool output as a message in the messages chain
                tool_outputs_map[tool_call.id] = {
                    "callId": tool_call.id,
                    "output": json.dumps(output_dict),
                    "parts": parts,
                }

                for source in sources:
                    hierarchy: List = source.setdefault("hierarchy", [])
                    hierarchy.insert(0, tool_level_hierarchy_entry)
                    hierarchy.insert(0, agent_level_hierarchy_entry)

                # Store sources, to emit later in the streaming footer
                logger.info("Storing sources %s" % sources)
                self.sequence_context.sources.extend(sources)

            # Append the trace from the tool to the agent's trace
            if tool_trace:
                parent_trace.append_trace(tool_trace)

        # Generate the tool message that we'll add in the history for the next run of the loop
        ordered_tool_outputs = []
        tool_call_message = (self.turn.initial_messages + self.sequence_context.generated_messages + self.block_turn_generated_messages)[-1]
        for tc in tool_call_message.get("toolCalls") or []:
            tool_call_id = tc.get("id")
            if tool_call_id in tool_outputs_map:
                ordered_tool_outputs.append(tool_outputs_map[tool_call_id])

        tool_outputs_msg: ChatMessage = {
            "role": "tool",
            "toolOutputs": ordered_tool_outputs
        }

        self.block_turn_generated_messages.append(tool_outputs_msg)

        # Interrupt the agentic loop to send the validation requests if there are any
        all_validation_requests = own_validation_requests + inner_validation_requests
        if all_validation_requests:
            yield from self._send_validation_requests(all_validation_requests, inner_memory_fragments_wrapped, parent_trace)
            return True

        return False

    def _call_one_tool(self, call: ToolCallWithPotentialValidation, current_tools: Dict[str, LLMToolOrSubAgent], validation_infos_map: dict[str, ToolCallValidationInfo], context):
        """
        Run a single tool call requested by the LLM and collect what it produced.

        Virtual tools (state / scratchpad getters and setters) are served inline, sub-agents are delegated to
        ``self._call_sub_agent``, and any other name is looked up in ``current_tools`` and run as a managed tool.

        :param call: the tool call (of the form returned by ``_tool_calls_from_chunks``), possibly carrying a memory
                     fragment and tool validation requests/responses from a previously interrupted run
        :param current_tools: tools available to this block, keyed by the name exposed to the LLM
        :param validation_infos_map: validation infos keyed by tool call id, consulted for tools that require validation
        :param context: passed through unchanged to the managed tool's ``run``
        :return: a 7-tuple ``(output_dict, parts, sources, artifacts, tool_validation_requests, memory_fragment, trace)``.
                 For managed tools, any exception (including an unknown tool name, a rejected/missing validation, or
                 an ``error`` reported by the tool) is caught and returned as ``{"error": "..."}`` in ``output_dict``
                 so that the LLM can see it and react.
        """

        def start_virtual_tool_trace(tool_name):
            # State/scratchpad accessors run locally: record them as a single event rather than a managed-tool span
            virtual_trace = SpanBuilder.create_event("DKU_VIRTUAL_TOOL_CALL")
            virtual_trace.attributes["toolId"] = tool_name
            virtual_trace.inputs["input"] = call.arguments
            return virtual_trace

        if call.name == GET_STATE_NAME:
            trace = start_virtual_tool_trace(GET_STATE_NAME)
            value = self.turn.state_get(call.arguments.get("key"), None)
            trace.outputs["output"] = value
            return {"value": value}, [], [], [], None, None, trace

        if call.name == SET_STATE_NAME:
            trace = start_virtual_tool_trace(SET_STATE_NAME)
            key = call.arguments.get("key")
            value = call.arguments.get("value")
            self.turn.state_set(key, value)
            return {"status": "ok"}, [], [], [], None, None, trace

        if call.name == GET_SCRATCHPAD_NAME:
            trace = start_virtual_tool_trace(GET_SCRATCHPAD_NAME)
            value = self.sequence_context.scratchpad.get(call.arguments.get("key"), None)
            trace.outputs["output"] = value
            return {"value": value}, [], [], [], None, None, trace

        if call.name == SET_SCRATCHPAD_NAME:
            trace = start_virtual_tool_trace(SET_SCRATCHPAD_NAME)
            key = call.arguments.get("key")
            value = call.arguments.get("value")
            self.sequence_context.scratchpad[key] = value
            return {"status": "ok"}, [], [], [], None, None, trace

        # Single lookup: None means the LLM asked for a tool this block does not expose
        llm_tool = current_tools.get(call.name)
        if isinstance(llm_tool, LLMSubAgent):
            return self._call_sub_agent(llm_tool.sub_agent_tool_ref, call.arguments)

        # Values returned as-is if the managed tool fails before producing them
        parts = []
        sources = []
        artifacts = []
        tool_validation_requests = None
        memory_fragment = None

        # Placeholder span, only returned when the call fails; on success the tool's own trace is returned instead
        trace = SpanBuilder("DKU_MANAGED_TOOL_CALL")
        trace.begin(int(time.time() * 1000))
        try:
            logger.info("Invoking tool %s", call)

            if llm_tool is None:
                # Give the LLM a readable error so it can correct itself, rather than the bare KeyError repr "'name'"
                raise Exception("Unknown tool: %s" % call.name)
            assert isinstance(llm_tool, LLMTool)  # sub-agents have already been dispatched above

            if llm_tool.require_validation:
                if call.id not in validation_infos_map:
                    raise Exception("Tool call needs a validation, but no validation was provided for this tool call")
                if not validation_infos_map[call.id].validated:
                    raise Exception("Tool call was rejected by user")

            tool_output = llm_tool.dku_tool.run(
                call.arguments,
                context=context,
                subtool_name=llm_tool.dku_subtool_name,
                memory_fragment=call.memory_fragment,
                tool_validation_responses=call.tool_validation_responses,
                tool_validation_requests=call.tool_validation_requests
            )

            if tool_output.get("error"):
                raise Exception(tool_output["error"])

            logger.info("Tool output: %s", tool_output)

            output_dict = tool_output["output"]
            parts = tool_output.get("parts", [])
            sources = tool_output.get("sources", [])
            artifacts = tool_output.get("artifacts", [])
            tool_validation_requests = tool_output.get("toolValidationRequests")
            memory_fragment = tool_output.get("memoryFragment")
            # Must stay the LAST statement of the try: the except branch relies on `trace` still being
            # the placeholder SpanBuilder (it calls .inputs/.outputs/.end on it)
            trace = tool_output.get("trace", None)

        except Exception as e:
            logger.exception(f"Tool call failed: {call}")

            # TODO @new-agent-loop Add a flag to the tool, to say whether errors are fatal or not.
            #                      Add UI to set the flag when adding the tool to the agent.
            #                      If the flag is set, then throw here (or return the error here, and throw later, to keep the trace).

            output_dict = {"error": str(e)}

            # Best effort: enrich the placeholder span, but never let this mask the original error
            try:
                trace.inputs["input"] = call.arguments
                trace.outputs["error"] = str(e)
                if isinstance(llm_tool, LLMTool):
                    trace.attributes["toolId"] = llm_tool.dku_tool.tool_id
                    trace.attributes["toolProjectKey"] = llm_tool.dku_tool.project_key
                    trace.attributes["toolType"] = llm_tool.dku_tool.get_settings()._settings["type"]
            except Exception:
                logger.exception("Error getting more info about tool call failure")

            trace.end(int(time.time() * 1000))

        return output_dict, parts, sources, artifacts, tool_validation_requests, memory_fragment, trace

    def _send_validation_requests(self, validation_requests: list[ToolValidationRequest], nested_memory_fragments_messages: list[ChatMessage], parent_trace: SpanBuilder):
        """
        Interrupt the React loop so that pending tool calls can be validated by a human.

        Records a ``DKU_AGENT_PENDING_TOOL_CALLS_VALIDATIONS`` event on ``parent_trace``, then yields, in order:

        1. a content chunk carrying the memory fragment needed to resume later: the current loop iteration, the
           messages generated so far followed by the nested memory-fragment messages, and the stashed sources;
        2. a content chunk carrying the validation requests themselves;
        3. a streaming footer whose finish reason is ``tool_validation_requests``.
        """
        pending_event = parent_trace.event("DKU_AGENT_PENDING_TOOL_CALLS_VALIDATIONS")
        pending_event.attributes["nbPendingValidationRequests"] = len(validation_requests)

        resume_state = {
            "agentLoopIteration": self.iteration_number,
            "messages": self.sequence_context.generated_messages + self.block_turn_generated_messages + nested_memory_fragments_messages,
            "stashedSources": self.sequence_context.sources,
        }
        yield {"chunk": {"type": "content", "memoryFragment": resume_state}}

        requests_chunk = {"type": "content", "toolValidationRequests": validation_requests}
        yield {"chunk": requests_chunk}

        logger.info("Some tool calls require human approval, React loop exits here")
        footer_payload = {"finishReason": "tool_validation_requests"}
        yield DSSLLMStreamedCompletionFooter(footer_payload)

    def _get_next_block(self):
        """
        Evaluate this block's exit conditions, in configuration order, and pick the block to go to next.

        Supported condition types:

        - ``STATE_HAS_KEYS``: every key in ``stateKeys`` has a non-None value in the turn state
          (an empty key list is trivially satisfied);
        - ``SCRATCHPAD_HAS_KEYS``: every key in ``scratchpadKeys`` has a non-None value in the sequence scratchpad
          (an empty key list is trivially satisfied);
        - ``EXPRESSION``: the CEL expression in ``expression.expression`` evaluates to a truthy value;
        - ``TOOLS_CALLED``: at least one of the tools referenced in ``toolRefs`` (plain refs or ``{"toolRef": ...}``
          dicts) has been called, either before this React loop started or during it.

        :return: the ``nextBlock`` of the first satisfied condition (None if that condition has no ``nextBlock``),
                 or None when no condition is satisfied
        :raises Exception: when an unknown condition type is reached (i.e. no earlier condition was satisfied)
        """
        # Tools are loaded lazily and at most once, however many TOOLS_CALLED conditions are configured
        tools_by_ref = None

        for condition in self.block_config["exitConditions"]:
            condition_type = condition["type"]

            if condition_type == "STATE_HAS_KEYS":
                logger.info("STATE_HAS_KEYS condition")

                has_all_keys = all(self.turn.state_get(k, None) is not None for k in condition["stateKeys"])
                if has_all_keys:
                    logger.info("State has requested keys, going to block %s" % (condition.get("nextBlock")))
                    return condition.get("nextBlock")

            elif condition_type == "SCRATCHPAD_HAS_KEYS":
                logger.info("SCRATCHPAD_HAS_KEYS condition")

                has_all_keys = all(self.sequence_context.scratchpad.get(k, None) is not None for k in condition["scratchpadKeys"])
                if has_all_keys:
                    logger.info("Scratchpad has requested keys, going to block %s" % (condition.get("nextBlock")))
                    return condition.get("nextBlock")

            elif condition_type == "EXPRESSION":
                logger.info("EXPRESSION-based exit condition")

                # TODO: handle non-CEL
                cel_expression = condition["expression"]["expression"]

                cel_engine = self.standard_cel_engine()
                eval_result = cel_engine.evaluate(cel_expression)

                logger.info("CEL expression %s evaluated to %s" % (cel_expression, eval_result))

                if eval_result:
                    return condition.get("nextBlock")

            elif condition_type == "TOOLS_CALLED":
                if tools_by_ref is None:
                    # Only managed tools carry a dku_tool_ref; sub-agents are not matched by this condition
                    tools_by_ref = {
                        tool.dku_tool_ref: tool
                        for tool in self.load_tools().values()
                        if isinstance(tool, LLMTool)
                    }
                tool_ref_ids = {
                    tool_ref.get("toolRef") if isinstance(tool_ref, dict) else tool_ref
                    for tool_ref in condition["toolRefs"]
                }
                # Keep only the part of the LLM-facing name before the first "__"
                tool_names = {
                    tools_by_ref[tool_ref_id].llm_tool_name.split("__", 1)[0]
                    for tool_ref_id in tool_ref_ids
                    if tool_ref_id in tools_by_ref
                }
                # we want the exit condition to account for tools called before the beginning of the react loop
                # and also for the tools called during the loop
                # when the exit condition is evaluated, block turn generated messages haven't yet been added to the sequence context
                all_generated_messages = self.sequence_context.generated_messages + self.block_turn_generated_messages
                were_any_tools_called = any(
                    tool_has_been_called(all_generated_messages, tool_name)
                    for tool_name in tool_names
                )
                if were_any_tools_called:
                    logger.info("A tool was called, going to block %s" % (condition.get("nextBlock")))
                    return condition.get("nextBlock")

            else:
                raise Exception("Unknown exit condition type: %s" % condition_type)

        return None
