# New LangGraph-free implementation of the agentic loop
#
# By switching to explicit imperative code that we control, rather than having it mediated through a LangGraph
# graph, we can better control the execution flow, notably to perform various kinds of dynamic tool selection,
# early abort, etc.
#

import copy
import hashlib
import json
import logging
import time
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Collection, Dict, List, Optional

import dataiku
from dataiku.llm.python import BaseLLM, utils
from dataiku.llm.python.types import (
    AgentHierarchyEntry,
    ChatMessage,
    CompletionSettings,
    FunctionTool,
    FunctionToolCall,
    MemoryFragment,
    SingleCompletionQuery,
    ToolHierarchyEntry,
    ToolOutput,
    ToolsUsingAgentSettings,
    ToolValidationRequest,
    ToolValidationResponse,
    UsedTool,
)
from dataiku.llm.python.utils import process_tool_call_chunk
from dataikuapi.dss.agent_tool import DSSAgentTool
from dataikuapi.dss.llm import DSSLLMStreamedCompletionFooter
from dataikuapi.dss.llm_tracing import SpanBuilder

# Module-level logger used by the agent code in this file
logger = logging.getLogger("dku.agents.tools_using_2")


class InternalCompletionChunk(object):
    """
    Simple value holder grouping the five pieces of data attached to a completion chunk.
    The values are stored exactly as they are passed, without any validation or copy.
    """

    def __init__(self, text, tool_call_chunks, artifacts, sources, memory_fragment):
        # Attributes are set in the same order as the constructor parameters
        self.text, self.tool_call_chunks = text, tool_call_chunks
        self.artifacts, self.sources = artifacts, sources
        self.memory_fragment = memory_fragment


class PreparedTool(object):
    """
    A tool as defined on the Dataiku side: it maps to a single Dataiku object and may expose several subtools.
    self.llm_tools holds the matching tool definitions as they are presented to the LLM: one per enabled subtool
    when the descriptor is flagged "multiple", a single one otherwise.
    """

    def __init__(self, tool: DSSAgentTool, tool_config: UsedTool):
        self.descriptor = tool.get_descriptor()
        tool_params = tool.get_settings().get_raw()
        tool_ref = tool_config["toolRef"]
        additional_description = tool_config.get("additionalDescription", "")
        # Short suffix derived from the whole tool configuration; it is part of every LLM-facing tool name
        config_digest = hashlib.sha256(PreparedTool.get_tool_key(tool_config).encode("utf-8")).hexdigest()
        tool_config_hash = config_digest[:6]

        # The descriptor of the tool(s), in a form that completion.settings["tools"] takes
        self.llm_tools: List[LLMTool] = []

        def register(llm_tool_name, source_descriptor, subtool_name):
            # Build one LLM-facing tool out of the given (sub)tool descriptor and append it to self.llm_tools.
            # The tool settings are only read here, i.e. only when a tool is actually registered.
            full_description = source_descriptor.get("description", "") + "\n\n" + additional_description
            function_spec = {
                "name": llm_tool_name,
                "description": full_description.strip(),
                "parameters": source_descriptor["inputSchema"],
            }
            self.llm_tools.append(LLMTool(
                llm_tool_name=llm_tool_name,
                llm_descriptor={"type": "function", "function": function_spec},
                dku_tool=tool,
                dku_tool_ref=tool_ref,
                dku_tool_name=tool_params["name"],
                dku_tool_type=tool_params["type"],
                dku_subtool_name=subtool_name,
                require_validation=tool_params["requireHumanApproval"],
                allow_editing_inputs=tool_params["allowEditingInputs"],
            ))

        if not self.descriptor.get("multiple", False):
            # Single tool: the descriptor itself describes the function exposed to the LLM
            register(self.descriptor["name"] + '__' + tool_config_hash, self.descriptor, None)
            return

        # Multiple tool: every enabled subtool becomes its own LLM tool
        for subtool_descriptor in self.descriptor.get("subtools", []):
            if not subtool_descriptor.get("enabled", False):
                continue
            llm_tool_name = self.descriptor["name"] + '__' + tool_config_hash + "__" + subtool_descriptor["name"]
            register(llm_tool_name, subtool_descriptor, subtool_descriptor["name"])

    @staticmethod
    def get_tool_key(tool_config: UsedTool):
        """
        Return a deterministic string key for a tool configuration: its JSON form with sorted keys, ASCII only.
        """
        return json.dumps(tool_config, ensure_ascii=True, sort_keys=True)


@dataclass
class BaseLLMTool:
    """
    Minimal description of a tool as presented to an LLM: a name and a function-style descriptor.
    """
    # Name under which the tool is exposed to the LLM
    llm_tool_name: str
    # Descriptor of the tool, in a form that completion.settings["tools"] takes
    llm_descriptor: FunctionTool


@dataclass
class LLMTool(BaseLLMTool):
    """
    This class represents a tool as defined by an LLM, together with the Dataiku tool that backs it
    """
    # Handle on the Dataiku agent tool backing this LLM tool
    dku_tool: DSSAgentTool
    # "toolRef" of the tool configuration this LLM tool was built from
    dku_tool_ref: str
    # "name" and "type" taken from the raw settings of the Dataiku tool
    dku_tool_name: str
    dku_tool_type: str
    # Name of the subtool when the Dataiku tool exposes multiple subtools, None otherwise
    dku_subtool_name: Optional[str]
    # Value of "requireHumanApproval" in the raw settings of the Dataiku tool
    require_validation: bool
    # Value of "allowEditingInputs" in the raw settings of the Dataiku tool
    allow_editing_inputs: bool


@dataclass
class ToolCallValidationInfo:
    """
    Outcome of the validation of a single tool call, aggregated from a validation request and its response
    """
    # Whether the tool call was validated
    validated: bool
    # Whether editing the inputs was allowed when the validation was requested
    allow_editing_inputs: bool
    # Arguments provided along with the validation response, if any (a JSON string that may differ from the original arguments)
    edited_arguments: Optional[str] = None


@dataclass
class ToolCall:
    """
    A parsed tool call, as produced by _validate_tool_call
    """
    # The raw tool call this object was parsed from
    dku_tool_call: FunctionToolCall
    # Id of the tool call and name of the called function, both taken from the raw tool call
    id: str
    name: str
    # Arguments of the call, parsed from the JSON string of the raw tool call
    arguments: dict
    # Validation responses / requests and memory fragment targeting calls nested inside this tool call;
    # they are only set when resuming after an interruption for tool call validation
    tool_validation_responses: Optional[list[ToolValidationResponse]] = None
    tool_validation_requests: Optional[list[ToolValidationRequest]] = None
    memory_fragment: Optional[MemoryFragment] = None


def _validate_tool_call(ftc: FunctionToolCall) -> ToolCall:
    """
    Check that a raw tool call is well-formed and parse it into a ToolCall.

    The tool call must carry a non-empty "id" and a non-empty function "name".
    Missing or empty function "arguments" are treated as an empty JSON object.

    Returns a ToolCall wrapping the raw tool call, with the arguments parsed from JSON.
    Raises an Exception if the id or the name is missing, or if the arguments cannot be parsed as JSON
    (including when they are not a string at all). The parsing error is chained as the cause.
    """
    call_id = ftc.get("id")
    if not call_id:
        raise Exception(f"Tool call missing id: {ftc}")

    function = ftc.get("function") or {}
    name = function.get("name")
    if not name:
        raise Exception(f"Tool call missing name: {ftc}")

    raw_args = function.get("arguments") or "{}"
    try:
        # strict=False tolerates control characters (e.g. raw newlines) inside JSON strings
        args = json.loads(raw_args, strict=False)
    except (json.JSONDecodeError, TypeError) as err:
        # TypeError: the arguments are not a str/bytes value; report it like any other malformed arguments
        raise Exception(f"Invalid tool call arguments format: {ftc}") from err

    return ToolCall(ftc, call_id, name, args)


def _tool_calls_from_chunks(chunks) -> list[FunctionToolCall]:
    """
    Helper function to reassemble several tool calls from a LLM Mesh list of tool call chunks.

    Every chunk is handed to process_tool_call_chunk together with two accumulators (a map and a list).
    The result is the content of the list, followed by the values left in the map.
    """
    if not len(chunks):
        return []

    assembled_calls = []
    calls_map = {}
    for chunk in chunks:
        process_tool_call_chunk(chunk, calls_map, assembled_calls)

    # Whatever remains in the map goes after the calls already present in the list
    assembled_calls.extend(calls_map.values())
    return assembled_calls


class ToolsUsingAgent(BaseLLM):
    """
    Entry point of the agent: holds the agent configuration and the default project, resolves the configured
    tools, and hands the processing of each query over to a dedicated ToolsUsingAgentTurn.
    """

    def __init__(self):
        super().__init__()
        api_client = dataiku.api_client()
        self.project = api_client.get_default_project()

    def set_config(self, config: ToolsUsingAgentSettings, unused):
        # Only the agent settings are kept, the second argument is ignored
        self.config = config

    def load_tools(self, messages, context, tools_cache: Dict[str, PreparedTool]) -> Dict[str, LLMTool]:
        """
        Build the map of the LLM tools, indexed by LLM tool name, for the tools configured on this agent.

        Tools without a toolRef and disabled tools are ignored. Prepared tools are looked up in tools_cache
        by their configuration key, and added to it when missing.
        messages and context are not used by this implementation.
        """
        tools_by_llm_name = {}

        for tool_config in self.config["tools"]:
            tool_ref = tool_config.get("toolRef", False)
            if not tool_ref:
                logger.warning(f"Ignoring tool with empty toolRef: {tool_config}")
                continue
            if tool_config.get("disabled", False):
                logger.debug(f"Ignoring disabled tool: {tool_config}")
                continue

            logger.info(f"Will use tool {tool_ref}")

            # Prepare the tool only once per distinct configuration
            cache_key = PreparedTool.get_tool_key(tool_config)
            if cache_key not in tools_cache:
                tools_cache[cache_key] = PreparedTool(self.project.get_agent_tool(tool_ref), tool_config)
            prepared = tools_cache[cache_key]

            tools_by_llm_name.update((llm_tool.llm_tool_name, llm_tool) for llm_tool in prepared.llm_tools)

        return tools_by_llm_name

    def process_stream(self, query: SingleCompletionQuery, settings: CompletionSettings, trace: SpanBuilder):
        # A single ToolsUsingAgent instance serves several requests at the same time, so no per-request state
        # can live on it. Everything relative to one request/turn (which may span several tool calling
        # iterations) is held by a dedicated ToolsUsingAgentTurn object instead.

        # NB: the query settings are intentionally ignored here, the settings from the agent are used instead

        turn_chunks = ToolsUsingAgentTurn(self, query).process_stream(trace)
        for chunk in turn_chunks:
            yield chunk


class ToolsUsingAgentTurn(object):
    """
    This class holds everything that happens during a single turn of the agent (i.e. a single call to the agent through the LLM Mesh API).

    The turn may itself represent several iterations of the main react loop, as represented by the big "while True" in the main method
    """

    def __init__(self, agent: ToolsUsingAgent, query: SingleCompletionQuery):
        self.agent = agent
        self.initial_messages: list[ChatMessage] = query["messages"]
        self.initial_context = query.get("context", {})

        self.context = copy.deepcopy(self.initial_context)

        self.dku_properties = {dku_property["name"]: dku_property["value"] for dku_property in agent.config["dkuProperties"]}

        self.max_loop_iterations = int(self.dku_properties.get("dku.agents.maxLoopIterations", 25))
        self.iteration_number = 0

        self.tools_cache = {}

        # Intermediate messages that were created during processing of this turn (mostly, these are tool_call / tool_output messages)
        self.generated_messages: list[ChatMessage] = []
        self.all_sources = []
        self.all_context_upserts = []

    def _collect_tool_validation_responses(self, validation_responses_map=None) -> dict[str, ToolValidationResponse]:
        """
        Recursively pop and process consecutive messages with role = "toolValidationResponses" to collect all validation responses.
        Returns the collected validation responses, in a map indexed by validationRequestId.
        Optionally takes a pre-existing mapping to update, for recursive calling.
        """
        if validation_responses_map is None:
            validation_responses_map = {}

        if not (self.initial_messages and self.initial_messages[-1]["role"] == "toolValidationResponses"):
            raise ValueError("Tool validation response not found in the chat history")

        validation_responses_message = self.initial_messages.pop()

        validation_responses = validation_responses_message.get("toolValidationResponses")
        if not validation_responses:
            raise ValueError("Invalid tool validation response was received")
        for tc in validation_responses:
            validation_responses_map[tc["validationRequestId"]] = tc

        if self.initial_messages and self.initial_messages[-1]["role"] == "toolValidationResponses":
            return self._collect_tool_validation_responses(validation_responses_map)
        else:
            return validation_responses_map

    def _collect_validation_requests(self) -> list[ToolValidationRequest]:
        """
        Pop and process a single, required, role = "toolValidationRequests" message to collect all validation requests.
        Returns the list of all collected validation requests.
        """
        if not (self.initial_messages and self.initial_messages[-1]["role"] == "toolValidationRequests"):
            raise ValueError("Tool validation response was received, but the original tool validation request wasn't provided in the chat history")

        validation_requests_message = self.initial_messages.pop()
        validation_requests = validation_requests_message.get("toolValidationRequests")
        if not validation_requests:
            raise ValueError("Tool validation response was received, but the tool validation requests in the chat history are empty")

        return validation_requests

    def _triage_validation_requests(self, validation_requests: list[ToolValidationRequest]) -> tuple[list[ToolValidationRequest], dict[str, list[ToolValidationRequest]]]:
        """
        Process the list of all validation requests previously collected, to separate those relative to this agent from those relative to nested tool calls.
        Returns the collected validation requests, in two separate structures:
        - own_validation_requests: a list of the validation requests that originated in this agent
        - deferred_validation_requests: a map of the validation requests that originated in nested tool calls, indexed by the tool call id.
        """
        own_validation_requests = []
        deferred_validation_requests = defaultdict(lambda: [])
        for tvr in validation_requests:
            hierarchy = tvr["hierarchy"]
            if not hierarchy:
                raise ValueError("No hierarchy found in tool validation request")

            agent_level = hierarchy.pop(0)
            if agent_level["type"] != "AGENT":
                raise ValueError("Invalid hierarchy in tool validation request")

            # sanity check: agentLoopIteration must match the restored iteration number from the memory fragment
            if "agentLoopIteration" not in agent_level or agent_level["agentLoopIteration"] is None:
                raise ValueError("Missing iteration number in tool validation request")
            if self.iteration_number == 0:
                raise ValueError("Agentic loop iteration number wasn't properly restored")
            if self.iteration_number != agent_level["agentLoopIteration"]:
                raise ValueError("Invalid iteration number in tool validation request")

            if len(hierarchy) > 0:
                tool_level = hierarchy.pop(0)
                if tool_level["type"] != "TOOL":
                    raise ValueError("Invalid hierarchy in tool validation request")
                deferred_validation_requests[tool_level["toolCallId"]].append(tvr)
            else:
                own_validation_requests.append(tvr)

        return own_validation_requests, deferred_validation_requests

    @staticmethod
    def _aggregate_own_validation_infos(raw_validation_responses_map: dict[str, ToolValidationResponse], validation_requests: list[ToolValidationRequest]) -> dict[str, ToolCallValidationInfo]:
        """
        Reconcile the validation responses and validation requests to return an aggregated map of tool call validation infos indexed by tool call id.
        """
        validation_infos_map: dict[str, ToolCallValidationInfo] = {}

        validation_request_ids = {vr["id"] for vr in validation_requests}
        for response_validation_request_id in raw_validation_responses_map.keys():
            if response_validation_request_id not in validation_request_ids:
                # validation response without a matching validation request
                raise ValueError(f"Tool validation response was received, but the corresponding tool validation request {response_validation_request_id} is missing")

        for validation_request in validation_requests:
            validation_request_id = validation_request["id"]
            tool_call_id = validation_request["toolCall"]["id"]
            if validation_request_id not in raw_validation_responses_map:
                # validation request without a matching validation response
                raise ValueError(f"No response provided for tool validation request {validation_request_id}")
            else:
                validation_response = raw_validation_responses_map[validation_request_id]
                validation_infos_map[tool_call_id] = ToolCallValidationInfo(
                    validated=validation_response["validated"],
                    allow_editing_inputs=validation_request["allowEditingInputs"],
                    edited_arguments=validation_response.get("arguments")
                )

        return validation_infos_map

    def _collect_and_triage_memory_fragments(self) -> tuple[MemoryFragment, dict[str, MemoryFragment]]:
        """
        Pop a single, required, role = "memoryFragment" message, and extract the memory fragments from its nested structure.
        Returns the collected memory fragments, in two separate structures:
        - own_memory_fragment: the (single) memory fragment encoding the past state of this agentic loop
        - deferred_memory_fragments: a map of the memory fragments that originated in nested tool calls, indexed by toolCallId
        """
        if not (self.initial_messages and self.initial_messages[-1]["role"] == "memoryFragment"):
            raise ValueError("Tool validation response was received but the memory fragment wasn't provided")

        own_memory_fragment_message = self.initial_messages.pop()
        own_memory_fragment = own_memory_fragment_message.get("memoryFragment")
        if not own_memory_fragment:
            raise ValueError("Tool validation response was received but the memory fragment wasn't provided")

        # extract nested memory fragments wrapped in messages
        own_memory_fragment_messages = own_memory_fragment.get("messages")
        if not own_memory_fragment_messages:
            raise ValueError("Tool validation response was received but the memory fragment was empty")

        deferred_memory_fragments: dict[str, MemoryFragment] = {}
        while own_memory_fragment_messages:
            if own_memory_fragment_messages[-1]["role"] != "memoryFragment":
                break
            memory_fragment_message = own_memory_fragment_messages.pop()
            memory_fragment = memory_fragment_message.get("memoryFragment")
            memory_fragment_target = memory_fragment_message.get("memoryFragmentTarget")
            if not memory_fragment:
                raise ValueError("Nested memory fragment message was empty")
            if not memory_fragment_target:
                raise ValueError("Missing target for nested memory fragment")
            if memory_fragment_target["type"] != "TOOL":
                raise ValueError("Invalid target type for nested memory fragment")
            if memory_fragment_target["toolCallId"] in deferred_memory_fragments:
                raise ValueError("There should only be one memory fragment per nested tool call")
            deferred_memory_fragments[memory_fragment_target["toolCallId"]] = memory_fragment

        return own_memory_fragment, deferred_memory_fragments

    def _restore_agentic_loop_state(self, memory_fragment: MemoryFragment) -> tuple[list[FunctionToolCall], dict[str, ToolOutput]]:
        """
        Restore the agentic loop state from a memory fragment, namely:
        - iteration_number
        - generated_messages
        - all_sources

        Also extract the curated list of tool calls still pending from the memory fragment, and returns this list.
        """

        # restore iteration number
        if "agentLoopIteration" not in memory_fragment or memory_fragment["agentLoopIteration"] is None:
            raise ValueError("Missing iteration number in memory fragment")
        self.iteration_number = memory_fragment["agentLoopIteration"]

        # process stashed messages
        memory_fragment_messages = memory_fragment.get("messages")
        if not memory_fragment_messages:
            raise ValueError("Tool validation response was received but the memory fragment was empty")

        # the last message should always be a partial tool outputs message
        partial_tool_outputs_message = memory_fragment_messages.pop()
        if partial_tool_outputs_message["role"] != "tool":
            raise ValueError("Tool validation response was received but the memory fragment is missing the partial tool outputs")

        # extract pending tool calls, which should be in the last message
        if not memory_fragment_messages:
            raise ValueError("Tool validation response was received but the memory fragment was incomplete")

        raw_pending_tool_calls: list[FunctionToolCall] = memory_fragment_messages[-1].get("toolCalls")
        if not raw_pending_tool_calls:
            raise ValueError("Tool validation response was received but the memory fragment did not contain any tool calls in its last message")

        # restore state of the agentic loop before the interruption
        self.generated_messages.extend(memory_fragment_messages)

        memory_fragment_sources = memory_fragment.get("stashedSources") or []
        self.all_sources.extend(memory_fragment_sources)

        # tool calls that already have an output are not pending anymore
        partial_tool_outputs = partial_tool_outputs_message.get("toolOutputs") or []
        partial_tool_outputs_map = {tool_output["callId"]: tool_output for tool_output in partial_tool_outputs}
        if partial_tool_outputs_map:
            still_pending = []
            for ptc in raw_pending_tool_calls:
                if ptc.get("id") not in partial_tool_outputs_map.keys():
                    still_pending.append(ptc)
            raw_pending_tool_calls = still_pending

        return raw_pending_tool_calls, partial_tool_outputs_map

    def _initialise_from_tool_validations(self, current_tools: dict[str, LLMTool], parent_trace: SpanBuilder) -> tuple[Collection[ToolCall], dict[str, ToolCallValidationInfo], dict[str, ToolOutput]]:
        """
        Initialise the agentic loop so it can resume execution after an interruption for validating tool calls.
        - restore the agentic loop state to what it was before the interruption
        - collect the pending tool calls of the interrupted loop iteration so the agentic loop can start by running them
        - gather all the validation data relative to the pending tool calls

        The control messages are consumed from the end of self.initial_messages, in this order: validation responses,
        then validation requests, then the memory fragment. The whole processing happens in a dedicated subspan of parent_trace.

        Parameters:
        - current_tools: the LLM tools currently available on the agent, indexed by LLM tool name
        - parent_trace: the span under which the validations check subspan is created

        Returns:
        - the list of pending tool calls
        - a map of the validation data relative to each pending tool call, indexed by tool call id
        - a map of the tool outputs already obtained before the interruption, indexed by tool call id

        Raises ValueError whenever the validation responses, validation requests, memory fragment and current tools
        are not consistent with each other.
        """
        with parent_trace.subspan("DKU_AGENT_TOOL_CALLS_VALIDATIONS_CHECK") as validations_check_trace:
            # -------------------------------------------------
            # 1) collect validation responses
            raw_validation_responses_map = self._collect_tool_validation_responses()

            # a response whose "validated" is missing or falsy counts as rejected
            n_validation_responses = len(raw_validation_responses_map)
            n_accepted = sum(vr.get("validated") or 0 for vr in raw_validation_responses_map.values())
            n_rejected = n_validation_responses - n_accepted
            validations_check_trace.attributes["nbAccepted"] = n_accepted
            validations_check_trace.attributes["nbRejected"] = n_rejected
            # -------------------------------------------------

            # -------------------------------------------------
            # 2) collect validation requests
            all_validation_requests = self._collect_validation_requests()

            validations_check_trace.attributes["nbReceivedValidationRequests"] = len(all_validation_requests)
            # -------------------------------------------------

            # -------------------------------------------------
            # 3) collect memory fragments to
            # - extract the pending tool calls
            # - restore the state of the agentic loop before the interruption
            # (this also restores self.iteration_number, which step 5 relies on)
            own_memory_fragment, deferred_memory_fragments = self._collect_and_triage_memory_fragments()

            raw_pending_tool_calls, partial_tool_outputs_map = self._restore_agentic_loop_state(own_memory_fragment)
            # -------------------------------------------------

            # -------------------------------------------------
            # 4) validate and parse pending tool calls
            # and check that the pending tool calls target tools that are still supported by this agent
            pending_tool_calls_map: dict[str, ToolCall] = {}
            for tool_call in raw_pending_tool_calls:
                parsed_tool_call = _validate_tool_call(tool_call)
                if parsed_tool_call.name not in current_tools:
                    raise ValueError(f"Entering agentic loop with tool calls that are not supported by this agent: {parsed_tool_call.name}")

                pending_tool_calls_map[parsed_tool_call.id] = parsed_tool_call
            # -------------------------------------------------

            # -------------------------------------------------
            # 5) split self-owned from nested validation requests
            own_validation_requests, deferred_validation_requests = self._triage_validation_requests(all_validation_requests)
            # -------------------------------------------------

            # -------------------------------------------------
            # 6) restore memory fragments and validation requests/responses on nested tool calls
            for tool_call in pending_tool_calls_map.values():
                if memory_fragment := deferred_memory_fragments.get(tool_call.id):
                    tool_call.memory_fragment = memory_fragment
                if tool_validation_requests := deferred_validation_requests.get(tool_call.id):
                    tool_call.tool_validation_requests = tool_validation_requests
                    tool_call.tool_validation_responses = []
                    for tvr in tool_validation_requests:
                        # we pop the validation response here to be sure to leave only own validation responses in the map at the end of this block
                        if tool_validation_response := raw_validation_responses_map.pop(tvr["id"], None):
                            tool_call.tool_validation_responses.append(tool_validation_response)
            # -------------------------------------------------

            # -------------------------------------------------
            # 7) aggregate validation infos
            validation_infos_map: dict[str, ToolCallValidationInfo] = {}

            # interrupted tool calls (nested validations)
            for tool_call_id in deferred_validation_requests.keys():
                # if this tool call has a deferred validation request, then we must validate the tool call at this agent level
                # these tool calls have already been started on the previous turn, so they either did not require validation or they have already been validated on the previous turn
                validation_infos_map[tool_call_id] = ToolCallValidationInfo(validated=True, allow_editing_inputs=False)

            # not yet started tool calls (own validations)
            # at this point raw_validation_responses_map only holds the responses that were not attached to a nested tool call in step 6
            own_validation_infos_map = self._aggregate_own_validation_infos(raw_validation_responses_map, own_validation_requests)
            for tool_call_id, validation_info in own_validation_infos_map.items():
                if tool_call_id in validation_infos_map:
                    raise ValueError(f"Incompatible hierarchies of the tool validation requests, tool call {tool_call_id} received a deferred validation but also received a validation for itself at the same time")
                validation_infos_map[tool_call_id] = validation_info

            # all tool calls still pending at this point should have been covered by the above
            for tool_call_id in pending_tool_calls_map.keys():
                if tool_call_id not in validation_infos_map:
                    raise ValueError(f"Pending tool call {tool_call_id} is missing validation request data")
            # -------------------------------------------------

            # -------------------------------------------------
            # 8) fix tool calls with edited inputs
            for tool_call_id, validation_info in validation_infos_map.items():
                # check that validated tool calls are all about pending tool calls
                if tool_call_id not in pending_tool_calls_map:
                    # validated tool call (request + response) but no matching pending tool call
                    raise ValueError(f"Tool validation response for tool call {tool_call_id} does not match any of the pending tool calls in the memory fragment")

                # if allowEditingInputs, update the function arguments with those from the validation response
                # we should do a proper update here, not a full replacement
                if validation_info.edited_arguments is not None:
                    pending_tool_call = pending_tool_calls_map[tool_call_id]
                    dku_tool_call = pending_tool_call.dku_tool_call
                    # editing is only honoured when it was allowed at request time AND is still allowed by the current tool
                    allow_editing_inputs = False
                    if validation_info.allow_editing_inputs:
                        tool_currently_allow_editing_inputs = current_tools[pending_tool_call.name].allow_editing_inputs
                        if tool_currently_allow_editing_inputs:
                            allow_editing_inputs = True
                            # replace the arguments in the raw tool call, then re-parse it and carry over the nested data set in step 6
                            dku_tool_call["function"]["arguments"] = validation_info.edited_arguments
                            new_parsed_tool_call = _validate_tool_call(dku_tool_call)
                            new_parsed_tool_call.memory_fragment = pending_tool_call.memory_fragment
                            new_parsed_tool_call.tool_validation_requests = pending_tool_call.tool_validation_requests
                            new_parsed_tool_call.tool_validation_responses = pending_tool_call.tool_validation_responses
                            pending_tool_calls_map[tool_call_id] = new_parsed_tool_call
                    if not allow_editing_inputs:
                        # arguments sent back unchanged are accepted, only an actual modification is an error
                        # NOTE(review): unlike _validate_tool_call, "arguments" is indexed directly here, so a raw tool call
                        # without that key would raise a KeyError instead of being compared to "{}" - confirm whether this can happen
                        original_arguments = dku_tool_call["function"]["arguments"] or "{}"
                        input_was_edited = json.loads(original_arguments, strict=False) != json.loads(validation_info.edited_arguments, strict=False)
                        if input_was_edited:
                            raise ValueError(f"Editing this tool's inputs is not allowed{' anymore' if validation_info.allow_editing_inputs else ''}")
            # -------------------------------------------------

            return pending_tool_calls_map.values(), validation_infos_map, partial_tool_outputs_map

    def expand_and_filter_initial_messages(self):
        """
        Rewrite ``self.initial_messages`` into the flat history that is sent to the LLM.

        Walking the history from the newest message to the oldest:
        - tool call validation messages (requests and responses) from past turns are dropped
        - a short-term memory fragment immediately followed by a tool validation request is dropped
        - short-term memory fragments beyond the configured memory horizon are dropped
        - the remaining short-term memory fragments are replaced by the messages they contain
        - any other message is kept untouched

        When short-term memory is disabled on the agent, the horizon is 0 and no fragment is expanded.
        A missing horizon (None) means that there is no limit.

        :raises ValueError: if a fragment that must be expanded carries no "memoryFragment" payload
        """
        config = self.agent.config
        horizon = config.get("shortTermMemoryHorizon") if config.get("shortTermMemoryEnabled") else 0

        validation_roles = ("toolValidationResponses", "toolValidationRequests")

        # Built newest-first while walking the history backwards, flipped once at the end
        kept_newest_first = []
        fragments_expanded = 0
        following_role = None  # role of the message that comes right after the current one in the history

        for message in reversed(self.initial_messages):
            role = message["role"]

            if role == "memoryFragment":
                # Partial fragments saved for a HITL interruption are irrelevant now, and so are the
                # fragments that are older than the memory horizon
                skipped = following_role == "toolValidationRequests" or (
                    horizon is not None and fragments_expanded >= horizon
                )
                if not skipped:
                    fragment = message.get("memoryFragment")
                    if not fragment:
                        raise ValueError("Memory fragment message received but the memory fragment is missing")
                    # Reversed here so that the final flip restores the fragment's own chronological order
                    kept_newest_first.extend(reversed(fragment.get("messages") or []))
                    fragments_expanded += 1

            elif role not in validation_roles:
                # Regular message, part of the user-visible chat history
                kept_newest_first.append(message)

            # else: validation requests/responses of past turns are not processable by the underlying LLM

            following_role = role

        kept_newest_first.reverse()
        self.initial_messages = kept_newest_first

    def process_stream(self, trace: SpanBuilder):
        yield {"chunk": {"type": "event", "eventKind": "AGENT_GETTING_READY"}}

        llm = self.agent.project.get_llm(self.agent.config["llmId"])

        agent_id = self.agent.config.get("agentId", "")
        agent_name = self.agent.config.get("agentName", "")

        if self.initial_messages and self.initial_messages[-1]["role"] == "toolValidationResponses":
            logger.info("Starting agent turn with pending tool calls")

            with trace.subspan("DKU_AGENT_ITERATION") as iteration_trace:
                with iteration_trace.subspan("DKU_AGENT_TOOL_CALLS") as tools_trace:
                    current_tools = self.agent.load_tools(self.initial_messages, self.context, self.tools_cache)

                    # restore agentic loop state and retrieve pending tool calls and tool call validation infos
                    pending_tool_calls, validation_infos_map, partial_tool_outputs_map = self._initialise_from_tool_validations(current_tools, tools_trace)

                    logger.info(f"Resuming interrupted agent iteration: {self.iteration_number}")
                    iteration_trace.attributes["iterationNumber"] = self.iteration_number

                    # execute pending tool calls
                    tools_require_validation = yield from self._call_tools(pending_tool_calls, current_tools, tools_trace, validation_infos_map, partial_tool_outputs_map)
                    if tools_require_validation:
                        return

        self.expand_and_filter_initial_messages()

        while True:
            self.iteration_number += 1
            if self.iteration_number > self.max_loop_iterations:
                raise Exception(f"Agent exceeded max number of loop iterations ({self.max_loop_iterations})")

            with trace.subspan("DKU_AGENT_ITERATION") as iteration_trace:
                logger.info(f"Starting agent iteration: {self.iteration_number}")
                iteration_trace.attributes["iterationNumber"] = self.iteration_number

                with iteration_trace.subspan("DKU_AGENT_LLM_CALL") as llm_trace:
                    yield {"chunk": {"type": "event", "eventKind": "AGENT_THINKING", "eventData": {}}}

                    # Loading the tools here every iteration, in preparation for conditional blocks (when the tools can change each iteration)
                    current_tools = self.agent.load_tools(self.initial_messages, self.context, self.tools_cache)

                    completion = llm.new_completion()

                    used_completion_settings = self.agent.config.get("completionSettings", {})  # ignore query settings, use config settings instead
                    completion._settings = used_completion_settings

                    completion.with_context(self.context)

                    completion.settings["tools"] = [llm_tool.llm_descriptor for llm_tool in current_tools.values()]

                    # Build the messages for the LLM call

                    system_prompt = None
                    if "systemPromptAppend" in self.agent.config and len(self.agent.config["systemPromptAppend"]) > 0:
                        system_prompt = self.agent.config["systemPromptAppend"]
                    else:
                        # TODO What do we do? The config is called "systemPromptAppend" but shouldn't it be systemPrompt ?
                        system_prompt = """You are a helpful and versatile AI assistant.

Your goal is to assist the user as accurately as possible, by determining the best course of action. You can leverage the tools at your disposal when necessary to answer the user's request.

Instructions:
- Reasoning: Analyze the request and determine a clear, step-by-step plan.
- Tools: Use the tools at your disposal to gather necessary information or perform actions.
- Answer: Provide a direct, complete, and final answer once you have sufficient information."""

                    completion.with_message(system_prompt, "system")

                    # Add history
                    completion.cq["messages"].extend(self.initial_messages)

                    # Add messages already generated during the turn
                    completion.cq["messages"].extend(self.generated_messages)

                    accumulated_tool_call_chunks = []
                    aggregated_text = ""

                    for ichunk in self._run_completion(completion, llm_trace):
                        if ichunk.text is not None:
                            aggregated_text += ichunk.text
                            yield {"chunk": {"text": ichunk.text}}

                        if ichunk.memory_fragment:
                            memory_fragment_msg: ChatMessage = {
                                "role": "memoryFragment",
                                "memoryFragment": ichunk.memory_fragment
                            }
                            self.generated_messages.append(memory_fragment_msg)
                        if ichunk.artifacts:
                            artifacts = ichunk.artifacts
                            for artifact in artifacts:
                                hierarchy: List = artifact.setdefault("hierarchy", [])
                                hierarchy.insert(0, {"type": "AGENT", "agentLoopIteration": self.iteration_number, "agentId": agent_id, "agentName": agent_name})
                            yield {"chunk": {"artifacts": artifacts}}

                        if ichunk.sources:
                            sources = ichunk.sources
                            for source in sources:
                                hierarchy: List = source.setdefault("hierarchy", [])
                                hierarchy.insert(0, {"type": "AGENT", "agentLoopIteration": self.iteration_number, "agentId": agent_id, "agentName": agent_name})
                            self.all_sources.extend(sources)

                        if ichunk.tool_call_chunks:
                            accumulated_tool_call_chunks.extend(ichunk.tool_call_chunks)

                # Reassemble tool calls
                dku_tool_calls = _tool_calls_from_chunks(accumulated_tool_call_chunks)

                if len(dku_tool_calls) > 0:
                    with iteration_trace.subspan("DKU_AGENT_TOOL_CALLS") as tools_trace:
                        tool_calls = [_validate_tool_call(tc) for tc in dku_tool_calls]

                        # Generate the assistant message with tool calls that we'll add in the history for the next run of the loop
                        tool_calls_msg: ChatMessage = {
                            "role": "assistant",
                            "toolCalls": [tc.dku_tool_call for tc in tool_calls]
                        }
                        if aggregated_text:
                            tool_calls_msg["content"] = aggregated_text
                        self.generated_messages.append(tool_calls_msg)

                        tools_require_validation = yield from self._call_tools(tool_calls, current_tools, tools_trace, None, None)
                        if tools_require_validation:
                            return

                else:
                    logger.info("No tools to call, agentic loop is done")

                    if self.agent.config.get("shortTermMemoryEnabled") and self.generated_messages:
                        yield {
                            "chunk": {
                                "type": "content",
                                "memoryFragment": {
                                    "messages": self.generated_messages,
                                }
                            }
                        }

                    # Merged context upsert
                    resulting_upsert = {}
                    logger.info("Building merged context upsert for the turn from %d context upserts collected during the turn" % len(self.all_context_upserts))
                    for upsert in self.all_context_upserts:
                        resulting_upsert.update(upsert)

                    logger.info(f"Final merged context upsert for the turn: {resulting_upsert}")

                    # Emit sources
                    yield {
                        "footer": {
                            "contextUpsert": resulting_upsert,
                            "additionalInformation": {
                                "sources": self.all_sources
                            }
                        }
                    }

                    return

    def _run_completion(self, completion, trace: SpanBuilder):
        """
        Execute the completion in streaming mode and re-emit each raw chunk as an InternalCompletionChunk.

        :param completion: the completion query to run, already filled with its messages and settings
        :param trace: span to which the trace carried by the streamed footer is appended
        :return: a generator of InternalCompletionChunk, one per raw chunk received (footer included)
        """
        logger.info("Calling LLM")

        logger.debug(f"About to run completion: {utils.get_completion_query_safe_for_logging(completion.cq)}")
        logger.debug(f"With settings: {completion.settings}")

        for streamed in completion.execute_streamed():
            payload = streamed.data

            # Sources are nested under "additionalInformation"; every other field sits at the top level of the payload
            additional_information = payload.get("additionalInformation", {})
            internal_chunk = InternalCompletionChunk(
                payload.get("text", ""),
                payload.get("toolCalls", []),
                payload.get("artifacts", []),
                additional_information.get("sources", []),
                payload.get("memoryFragment"),
            )

            # The footer carries the trace of the LLM call: attach it to the caller's span
            if isinstance(streamed, DSSLLMStreamedCompletionFooter):
                trace.append_trace(streamed.trace)

            yield internal_chunk

        logger.info("LLM call done")

    def _call_tools(self, tool_calls: Collection[ToolCall], current_tools: Dict[str, LLMTool], parent_trace: SpanBuilder, validation_infos_map: Optional[dict[str, ToolCallValidationInfo]], tool_outputs_map: Optional[dict[str, ToolOutput]]):
        """
        Run a batch of tool calls in parallel, record their outputs in the turn's generated messages and,
        if needed, interrupt the agentic loop to request tool call validations.

        :param tool_calls: the tool calls to process
        :param current_tools: the tools currently available to the agent, keyed by LLM tool name
        :param parent_trace: span receiving the number of tool calls and the trace of each tool run
        :param validation_infos_map: None on a regular iteration (validation requirements are then checked here);
                                     when resuming after an interruption, the validation infos by tool call id
        :param tool_outputs_map: tool outputs by tool call id that were already obtained before an interruption,
                                 or None to start from scratch; completed in place with the new outputs
        :return: a generator yielding streamed chunks ("AGENT_TOOL_START" events, artifacts, and the
                 validation-related chunks when interrupting). Its return value (retrieved with ``yield from``)
                 is True when validation requests were sent and the loop must stop, False otherwise
        :raises Exception: when the "dku.agents.maxParallelToolExecutions" property is lower than 1
        """
        logger.info(f"I have {len(tool_calls)} tool(s) to call")
        parent_trace.attributes["nbToolCalls"] = len(tool_calls)

        agent_id = self.agent.config.get("agentId", "")
        agent_name = self.agent.config.get("agentName", "")
        agent_level_hierarchy_entry: AgentHierarchyEntry = {"type": "AGENT", "agentLoopIteration": self.iteration_number, "agentId": agent_id, "agentName": agent_name}

        own_validation_requests = []
        tool_calls_to_run = []

        if validation_infos_map is None:
            # 1) On a "normal" iteration of the agentic loop, we need to check tools that require validation and:
            # - run the tools that don't require it
            # - send validation requests for those that require it
            for tool_call in tool_calls:
                # The LLM may have requested a tool that does not exist: don't fail here on the lookup.
                # Such a call is still submitted for execution, so that _call_one_tool turns the failed lookup
                # into an error output that is sent back to the LLM.
                llm_tool = current_tools.get(tool_call.name)
                if llm_tool is None or not llm_tool.require_validation:
                    tool_calls_to_run.append(tool_call)
                    continue

                display_tool_name = llm_tool.dku_tool_name
                if llm_tool.dku_subtool_name:
                    display_tool_name += f" / {llm_tool.dku_subtool_name}"
                tool_call_description = llm_tool.dku_tool.describe_tool_call(tool_call.arguments, llm_tool.llm_descriptor, context=self.context, subtool_name=llm_tool.dku_subtool_name)
                if not tool_call_description:
                    tool_call_description = "Do you want to execute this tool call?"
                validation_request: ToolValidationRequest = {
                    "id": f"validation-{tool_call.id}",
                    "hierarchy": [agent_level_hierarchy_entry],
                    "message": tool_call_description,
                    "toolRef": llm_tool.dku_tool_ref,
                    "toolName": display_tool_name,
                    "toolType": llm_tool.dku_tool_type,
                    "toolDescription": llm_tool.llm_descriptor["function"].get("description") or "",
                    "toolInputSchema": llm_tool.llm_descriptor["function"].get("parameters") or {},
                    "allowEditingInputs": llm_tool.allow_editing_inputs,
                    "toolCall": tool_call.dku_tool_call,
                }
                own_validation_requests.append(validation_request)

            # tool_calls_to_run only contains tool calls that don't require validation
            validation_infos_map = {}
        else:
            # 2) On the first iteration of the agentic loop right after resuming after an interruption:
            # The tool calls that did not require validation have already been run before the interruption (and the results of these runs should be in tool_outputs_map)
            # => The list of tool_calls have already been curated when restoring the agentic loop state, and here it contains only those still pending after the interruption.
            #
            # These tool calls either:
            # - did not start running yet because they required validation before the interruption, and the validation request was already sent
            # - started running before the interruption but were interrupted mid-execution by a nested tool validation request: in this case the execution already started so it was either already approved or did not require validation
            # => In both cases, validation_infos_map is expected to already contain validation responses for all of them and there is no need to request validation for them at this point.
            tool_calls_to_run = tool_calls

        # Call the tools in parallel
        if tool_outputs_map is None:
            tool_outputs_map = {}

        num_parallel_threads = int(self.dku_properties.get("dku.agents.maxParallelToolExecutions", 8))
        if num_parallel_threads < 1:
            raise Exception(f"Property dku.agents.maxParallelToolExecutions must be > 0 (received: {num_parallel_threads})")
        # Leaving the "with" block waits for all the submitted tool calls to complete
        with ThreadPoolExecutor(num_parallel_threads) as executor:
            futures = []

            for tool_call in tool_calls_to_run:
                yield {"chunk": {"type": "event", "eventKind": "AGENT_TOOL_START", "eventData": {"toolName": tool_call.name}}}
                futures.append(executor.submit(self._call_one_tool, tool_call, current_tools, validation_infos_map))

        inner_validation_requests = []
        inner_memory_fragments_wrapped = []
        for tool_call, future in zip(tool_calls_to_run, futures):
            output_dict, parts, sources, artifacts, tool_validation_requests, memory_fragment, tool_trace, context_upsert = future.result()
            logger.info("zipped, cu:%s" % context_upsert)

            # Update hierarchies
            if tool := current_tools.get(tool_call.name):
                tool_ref = tool.dku_tool_ref
                tool_name = tool.dku_tool_name
            else:
                # If the LLM made a bad tool call, just add "Unknown" into the hierarchy
                tool_ref = "Unknown"
                tool_name = "Unknown"
            tool_level_hierarchy_entry: ToolHierarchyEntry = {"type": "TOOL", "toolRef": tool_ref, "toolName": tool_name, "toolCallId": tool_call.id}

            for artifact in artifacts:
                hierarchy: List = artifact.setdefault("hierarchy", [])
                hierarchy.insert(0, tool_level_hierarchy_entry)
                hierarchy.insert(0, agent_level_hierarchy_entry)

            # Stream artifacts inline immediately
            yield {"chunk": {"type": "content", "artifacts": artifacts}}

            if tool_validation_requests:
                # tool returning validation requests may also return (partial) output_dict and sources but we must explicitly ignore them in this block
                # it is the responsibility of the tool to store partial outputs and sources in its memory fragment, to aggregate them when resuming, and to return the fully aggregated output and sources at the end
                # if we accounted for partial sources at the agent level, we'd be aggregating the same sources several times on each interruption of the tool run
                # if we accounted for the partial output_dict here, we wouldn't even know how to aggregate it on the next turn as this logic is known by the tool only
                #
                # artifacts are handled differently on purpose: they are streamed immediately, regardless of whether the tool is done or simply interrupted
                for tvr in tool_validation_requests:
                    if not tvr.get("hierarchy"):
                        tvr["hierarchy"] = []
                    tvr["hierarchy"].insert(0, tool_level_hierarchy_entry)
                    tvr["hierarchy"].insert(0, agent_level_hierarchy_entry)

                inner_validation_requests.extend(tool_validation_requests)

                if memory_fragment:
                    wrapped_memory_fragment: ChatMessage = {"role": "memoryFragment", "memoryFragment": memory_fragment, "memoryFragmentTarget": tool_level_hierarchy_entry}
                    inner_memory_fragments_wrapped.append(wrapped_memory_fragment)
            else:
                # tool run is truly over, we can now process its (full) output and sources
                # artifacts from the tool have already been streamed a few lines above

                # Add tool output as a message in the messages chain
                tool_outputs_map[tool_call.id] = {
                    "callId": tool_call.id,
                    "output": json.dumps(output_dict),
                    "parts": parts,
                }

                for source in sources:
                    hierarchy: List = source.setdefault("hierarchy", [])
                    hierarchy.insert(0, tool_level_hierarchy_entry)
                    hierarchy.insert(0, agent_level_hierarchy_entry)

                # Store sources, to emit in the streaming footer
                self.all_sources.extend(sources)

                # A tool call that failed or that returned no upsert yields None here: don't store it,
                # the collected upserts are meant to be merged with dict.update(), which rejects None
                if context_upsert:
                    logger.info("Saving a context upsert")
                    self.all_context_upserts.append(context_upsert)

            # Append the trace from the tool to the agent's trace
            if tool_trace:
                parent_trace.append_trace(tool_trace)

        # Generate the tool message that we'll add in the history for the next run of the loop
        # Outputs follow the order of the tool calls in the last generated message, and tool calls
        # without an output yet (pending validation) are left out
        ordered_tool_outputs = []
        tool_call_message = self.generated_messages[-1]
        for tc in tool_call_message.get("toolCalls") or []:
            tool_call_id = tc.get("id")
            if tool_call_id in tool_outputs_map:
                ordered_tool_outputs.append(tool_outputs_map[tool_call_id])

        tool_outputs_msg: ChatMessage = {
            "role": "tool",
            "toolOutputs": ordered_tool_outputs
        }

        self.generated_messages.append(tool_outputs_msg)

        # Interrupt the agentic loop to send the validation requests if there are any
        all_validation_requests = own_validation_requests + inner_validation_requests
        if all_validation_requests:
            yield from self._send_validation_requests(all_validation_requests, inner_memory_fragments_wrapped, parent_trace)
            return True

        return False

    def _call_one_tool(self, call: ToolCall, current_tools: Dict[str, LLMTool], validation_infos_map: dict[str, ToolCallValidationInfo]):
        """
        Run a single tool call. Exceptions raised while running it are caught and reported in the output.

        Any failure (unknown tool, missing or rejected validation, error reported by the tool, exception
        raised by the tool) is logged and turned into an ``{"error": ...}`` output, together with a
        locally built "DKU_MANAGED_TOOL_CALL" trace describing the failure.

        :param call: the tool call to run
        :param current_tools: the tools currently available to the agent, keyed by LLM tool name
        :param validation_infos_map: validation infos by tool call id, checked when the tool requires a validation
        :return: a tuple (output_dict, parts, sources, artifacts, tool_validation_requests, memory_fragment,
                 trace, context_upsert). On failure, all but output_dict and trace keep their empty defaults
        """
        output_dict = {}
        parts = []
        sources = []
        artifacts = []
        context_upsert = None
        tool_validation_requests = None
        memory_fragment = None
        trace = SpanBuilder("DKU_MANAGED_TOOL_CALL")  # Create a fake trace, in case tool errors
        trace.begin(int(time.time() * 1000))

        llm_tool = None
        try:
            logger.info(f"Calling tool: {call}")

            llm_tool = current_tools[call.name]

            if llm_tool.require_validation:
                if call.id not in validation_infos_map:
                    raise Exception("Tool call needs a validation, but no validation was provided for this tool call")
                if not validation_infos_map[call.id].validated:
                    raise Exception("Tool call was rejected by user")

            tool_output = llm_tool.dku_tool.run(
                input=call.arguments,
                context=self.context,
                subtool_name=llm_tool.dku_subtool_name,
                memory_fragment=call.memory_fragment,
                tool_validation_responses=call.tool_validation_responses,
                tool_validation_requests=call.tool_validation_requests
            )

            if tool_output.get("error"):
                raise Exception(tool_output["error"])

            logger.info(f"Finished tool call: {call}")

            output_dict = tool_output.get("output", "")
            parts = tool_output.get("parts", [])
            # "or []" rather than a .get() default: the default does not apply when the key is present with
            # a null value, and the caller iterates over both collections
            sources = tool_output.get("sources") or []
            artifacts = tool_output.get("artifacts") or []
            tool_validation_requests = tool_output.get("toolValidationRequests")
            memory_fragment = tool_output.get("memoryFragment")
            # From here on, the trace returned by the tool replaces the fake one (can be None)
            trace = tool_output.get("trace", None)
            context_upsert = tool_output.get("contextUpsert") or tool_output.get("context_upsert")

        except Exception as e:
            logger.exception(f"Tool call failed: {call}")

            # TODO @new-agent-loop Add a flag to the tool, to say whether errors are fatal or not.
            #                      Add UI to set the flag when adding the tool to the agent.
            #                      If the flag is set, then throw here (or return the error here, and throw later, to keep the trace).

            output_dict = {"error": str(e)}

            # Try to add more info to the trace
            try:
                trace.inputs["input"] = call.arguments
                trace.outputs["error"] = str(e)
                if llm_tool is not None:
                    trace.attributes["toolId"] = llm_tool.dku_tool.tool_id
                    trace.attributes["toolProjectKey"] = llm_tool.dku_tool.project_key
                    trace.attributes["toolType"] = llm_tool.dku_tool.get_settings()._settings["type"]
            except Exception:
                # Best effort only: enriching the trace must never hide the original tool failure
                logger.exception("Error getting more info about tool call failure")

            trace.end(int(time.time() * 1000))

        return output_dict, parts, sources, artifacts, tool_validation_requests, memory_fragment, trace, context_upsert

    def _send_validation_requests(self, validation_requests: list[ToolValidationRequest], nested_memory_fragments_messages: list[ChatMessage], parent_trace: SpanBuilder):
        """
        Stream the chunks that interrupt the agentic loop to request tool call validations.

        Three items are yielded, in this order:
        - a memory fragment holding the current iteration number, the messages generated during the turn
          followed by the memory fragment messages of the interrupted tools, and the sources collected so far
        - the validation requests themselves
        - a footer whose finish reason is "tool_validation_requests"

        :param validation_requests: the validation requests to send
        :param nested_memory_fragments_messages: memory fragment messages to store after the generated messages
        :param parent_trace: span on which the "DKU_AGENT_PENDING_TOOL_CALLS_VALIDATIONS" event is recorded
        """
        validation_event = parent_trace.event("DKU_AGENT_PENDING_TOOL_CALLS_VALIDATIONS")
        validation_event.attributes["nbPendingValidationRequests"] = len(validation_requests)

        # State of the loop at the time of the interruption
        loop_state = {
            "agentLoopIteration": self.iteration_number,
            "messages": self.generated_messages + nested_memory_fragments_messages,
            "stashedSources": self.all_sources,
        }
        yield {"chunk": {"type": "content", "memoryFragment": loop_state}}

        yield {"chunk": {"type": "content", "toolValidationRequests": validation_requests}}

        logger.info("Some tool calls require human approval, agentic loop exits here")
        yield {"footer": {"finishReason": "tool_validation_requests"}}