# New LangGraph-free implementation of the Agentic loop
# 
# By switching to explicit imperative code that we control, rather than have it mediated through a LangGraph
# graph, we can better control the execution flow, notably to perform various kinds of dynamic tool selection,
# early-abort, etc.
#

import logging, json, copy, time, hashlib
from dataclasses import dataclass

import dataiku
from dataiku.llm.python import BaseLLM
from dataikuapi.dss.agent_tool import DSSAgentTool
from dataikuapi.dss.llm_tracing import SpanBuilder

from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Dict

from dataikuapi.dss.llm import DSSLLMStreamedCompletionFooter
from dataikuapi.dss.langchain.llm import _parse_tool_call_chunks


logger = logging.getLogger("dku.agents.tools_using_2")


class InternalCompletionChunk:
    """
    One streamed piece of an LLM completion, as handed from the completion runner to the agent loop.

    :param text: text carried by the chunk
    :param tool_call_chunks: tool call chunks carried by the chunk
    :param artifacts: artifacts carried by the chunk
    """

    def __init__(self, text, tool_call_chunks, artifacts):
        # Plain value holder: the three parts are stored as received, without any copy or validation
        self.text, self.tool_call_chunks, self.artifacts = text, tool_call_chunks, artifacts


class PreparedTool(object):
    """
    A Dataiku agent tool, prepared for use by the agent loop.

    A Dataiku tool maps to a Dataiku object and can contain subtools. self.llm_tools holds the tools as they
    are presented to the LLM: one entry per subtool when the descriptor is flagged "multiple", a single entry
    otherwise. Every LLM-facing name embeds a short hash of the tool configuration.
    """

    def __init__(self, tool: DSSAgentTool, tool_config: Dict):
        self.descriptor = tool.get_descriptor()
        extra_description = tool_config.get("additionalDescription", "")
        config_hash = PreparedTool.get_tool_hash(tool_config)

        # The descriptor of the tool(s), in a form that completion.settings["tools"] takes
        self.llm_tools: List[LLMTool] = []

        if not self.descriptor.get("multiple", False):
            # Simple tool: a single LLM tool, named <tool name>__<config hash>
            single_name = self.descriptor["name"] + '__' + config_hash
            self.llm_tools.append(
                PreparedTool._as_llm_tool(single_name, self.descriptor, extra_description, tool, None)
            )
            return

        # Multiple tool: one LLM tool per subtool, named <tool name>__<config hash>__<subtool name>
        for subtool_descriptor in self.descriptor.get("subtools", []):
            exposed_name = self.descriptor["name"] + '__' + config_hash + "__" + subtool_descriptor["name"]
            self.llm_tools.append(
                PreparedTool._as_llm_tool(exposed_name, subtool_descriptor, extra_description, tool, subtool_descriptor["name"])
            )

    @staticmethod
    def _as_llm_tool(llm_tool_name, descriptor, additional_description, dku_tool, dku_subtool_name):
        """
        Build the LLM-facing "function" tool for one tool or subtool descriptor.

        The description is the descriptor's own description followed by the additional description from the
        tool configuration, separated by a blank line and stripped of surrounding whitespace.
        """
        full_description = (descriptor.get("description", "") + "\n\n" + additional_description).strip()
        return LLMTool(
            llm_tool_name=llm_tool_name,
            llm_descriptor={
                "type": "function",
                "function": {
                    "name": llm_tool_name,
                    "description": full_description,
                    "parameters": descriptor["inputSchema"],
                },
            },
            dku_tool=dku_tool,
            dku_subtool_name=dku_subtool_name,
        )

    @staticmethod
    def get_tool_hash(tool_config: Dict):
        """
        Return a 6-character hexadecimal fingerprint of a tool configuration.

        The configuration is serialized to JSON with sorted keys, so two dicts with the same content
        give the same hash regardless of key order.
        """
        canonical = json.dumps(tool_config, sort_keys=True, ensure_ascii=True)
        digest = hashlib.sha1(canonical.encode("utf-8")).hexdigest()
        return digest[:6]


@dataclass
class LLMTool:
    """
    This class represents a tool as defined by an LLM
    """
    # Name under which the tool is exposed to the LLM; also the key used to look the tool up when the LLM calls it
    llm_tool_name: str
    # Descriptor of the tool, in the form of an entry of completion.settings["tools"]
    llm_descriptor: Dict
    # The Dataiku tool that is run when the LLM calls this tool
    dku_tool: DSSAgentTool
    # Name of the subtool to run within dku_tool; None when the Dataiku tool has no subtools
    dku_subtool_name: Optional[str]


def _tool_calls_from_chunks(chunks):
    """
    Helper function to reassemble several tool calls from a LLM Mesh list of tool call chunks.

    Chunks carrying an "index" are merged per index: the first non-empty "id" and "name" seen for an index
    are kept, and the "args" fragments are concatenated in arrival order. Chunks without an index cannot be
    aggregated, so each of them is assumed to be a complete tool call.

    The input chunks are never modified: merging and parsing are done on shallow copies, so the same list of
    chunks can safely be reused (or processed again) by the caller.

    :param chunks: list of tool call chunk dicts (keys read: "index", "id", "name", "args")
    :return: list of tool call dicts holding "id", "name" and "args" (the parsed JSON arguments, {} when no
             arguments were received). Non-indexed tool calls come first, then indexed ones in order of
             first appearance. The "index" key is removed.
    :raises Exception: if a reassembled tool call has no id, no name, or arguments that are not valid JSON
    """
    if not chunks:
        return []

    standalone_calls = []  # chunks without index, taken as complete tool calls
    calls_by_index = {}    # index -> tool call being reassembled (insertion-ordered)

    for chunk in chunks:
        index = chunk.get("index")
        if index is None:
            # tool call does not have an index, we won't be able to aggregate chunks, so assume it's full
            standalone_calls.append(dict(chunk))
            continue

        merged = calls_by_index.get(index)
        if merged is None:
            # First chunk for this index: start from a copy, so that the caller's chunk is left untouched
            calls_by_index[index] = dict(chunk)
            continue

        # Later chunk for this index: fill in what is still missing and append the arguments fragment
        if merged.get("id") is None and chunk.get("id") is not None:
            merged["id"] = chunk["id"]
        if merged.get("name") is None and chunk.get("name") is not None:
            merged["name"] = chunk["name"]
        if chunk.get("args"):
            if merged.get("args") is None:
                merged["args"] = ""
            merged["args"] += chunk["args"]

    tool_calls_list = standalone_calls + list(calls_by_index.values())

    # Check the tools calls are valid
    for tool_call in tool_calls_list:
        if not tool_call.get("id"):
            raise Exception(f"Tool call missing id: {tool_call}")
        if not tool_call.get("name"):
            raise Exception(f"Tool call missing name: {tool_call}")
        if not tool_call.get("args"):
            # No arguments received at all: treat as an empty JSON object
            tool_call["args"] = "{}"
        try:
            # strict=False tolerates control characters inside JSON strings
            tool_call["args"] = json.loads(tool_call["args"], strict=False)
        except json.JSONDecodeError as e:
            raise Exception(f"Invalid tool call arguments format: {tool_call}") from e
        tool_call.pop("index", None)

    return tool_calls_list


class ToolsUsingAgent(BaseLLM):
    """
    Agent that answers queries by looping between an LLM and the tools listed in its configuration.

    The agent object itself only holds the project handle and the configuration; everything that belongs to
    a single request lives in a ToolsUsingAgentTurn.
    """

    def __init__(self):
        super().__init__()
        self.project = dataiku.api_client().get_default_project()

    def set_config(self, config, unused):
        """Store the agent configuration; the second argument is ignored."""
        self.config = config

    def load_tools(self, messages, context, tools_cache: Dict[str, PreparedTool]) -> Dict[str, LLMTool]:
        """
        Return the LLM tools of every configured tool, keyed by LLM tool name.

        :param messages: not used by the current implementation
        :param context: not used by the current implementation
        :param tools_cache: prepared tools keyed by tool configuration hash; tools that are not in the cache
                            yet are fetched from the project, prepared, and added to it
        :return: dict of LLM tool name -> LLMTool
        """
        llm_tools = {}

        for used_tool in self.config["tools"]:
            tool_ref = used_tool.get("toolRef", False)
            if not tool_ref:
                logger.warning(f"Ignoring tool with empty toolRef: {used_tool}")
                continue

            logger.info(f"Will use tool {tool_ref}")

            cache_key = PreparedTool.get_tool_hash(used_tool)
            try:
                prepared_tool = tools_cache[cache_key]
            except KeyError:
                # Not prepared yet: fetch the tool, prepare it and remember it for later calls
                prepared_tool = PreparedTool(self.project.get_agent_tool(tool_ref), used_tool)
                tools_cache[cache_key] = prepared_tool

            llm_tools.update({llm_tool.llm_tool_name: llm_tool for llm_tool in prepared_tool.llm_tools})

        return llm_tools

    def process_stream(self, query, settings, trace: SpanBuilder):
        """
        Stream the answer to a query.

        The agent must serve multiple requests at the same time, so nothing about a request is stored on the
        agent: a dedicated ToolsUsingAgentTurn handles the whole request/turn, even when the turn runs
        several tool calling iterations.

        NB: the query settings are intentionally ignored, the settings from the agent are used instead.
        """
        turn = ToolsUsingAgentTurn(self, query)

        for streamed_item in turn.process_stream(trace):
            yield streamed_item


class ToolsUsingAgentTurn(object):
    """
    This class holds everything that happens during a single turn of the agent (i.e. a single call to the agent through the LLM Mesh API).

    The turn may itself represent several iterations of the main react loop, as represented by the big "while True" in the main method
    """

    def __init__(self, agent: ToolsUsingAgent, query):
        """
        Set up the state of a single turn.

        :param agent: the agent serving the turn; its config must hold a "dkuProperties" list of
                      {"name": ..., "value": ...} entries
        :param query: the query being served; "messages" is required, "context" is optional
        """
        self.agent = agent

        # Messages and context exactly as received in the query
        self.initial_messages = query["messages"]
        self.initial_context = query.get("context", {})

        # Working context of the turn: a deep copy, so that the context of the query is never modified
        self.context = copy.deepcopy(self.initial_context)

        # Flatten the list of name/value entries into a name -> value dict
        properties = {}
        for dku_property in agent.config["dkuProperties"]:
            properties[dku_property["name"]] = dku_property["value"]
        self.dku_properties = properties

        # For later use :)
        # self.context_upsert =

        # Intermediate messages that were created during processing of this turn (mostly, these are tool_call / tool_output messages)
        self.generated_messages = []

    def process_stream(self, trace: SpanBuilder):
        """
        Run the agentic loop for this turn and stream the resulting chunks.

        Every iteration calls the LLM with the system prompt, the initial messages and the messages generated
        so far during the turn. If the LLM requests tool calls, the tools are run in parallel, the tool calls
        and their outputs are appended to the generated messages, and the loop starts over. The loop ends when
        the LLM requests no tool call; a footer carrying the sources collected from the tools is then emitted.

        :param trace: span under which one "DKU_AGENT_ITERATION" subspan is opened per iteration
        :return: generator of dicts: "chunk" entries (events, text, artifacts) and a final "footer" entry
        :raises Exception: if the loop exceeds dku.agents.maxLoopIterations (default 15), or if
                           dku.agents.maxParallelToolExecutions (default 8) is lower than 1
        """
        yield {"chunk": {"type": "event", "eventKind": "AGENT_GETTING_READY"}}

        llm = self.agent.project.get_llm(self.agent.config["llmId"])
        all_sources = []
        tools_cache = {}

        # TODO @human-in-the-loop validated tool calls
        # if we have some validated tool calls
        #    add the tool call(s) to the messages
        #    call the tools
        #    add the results to the messages

        # The system prompt only depends on the agent config, so it is resolved once for the whole turn.
        # A missing, None or empty "systemPromptAppend" falls back to the default prompt.
        system_prompt = self.agent.config.get("systemPromptAppend")
        if not system_prompt:
            # TODO What do we do? The config is called "systemPromptAppend" but shouldn't it be systemPrompt ?
            system_prompt = """You are a helpful and versatile AI assistant.

Your goal is to assist the user as accurately as possible, by determining the best course of action. You can leverage the tools at your disposal when necessary to answer the user's request.

Instructions:
- Reasoning: Analyze the request and determine a clear, step-by-step plan.
- Tools: Use the tools at your disposal to gather necessary information or perform actions.
- Answer: Provide a direct, complete, and final answer once you have sufficient information."""

        max_loop_iterations = int(self.dku_properties.get("dku.agents.maxLoopIterations", 15))
        iteration_number = 0
        while True:
            iteration_number += 1
            if iteration_number > max_loop_iterations:
                raise Exception(f"Agent exceeded max number of loop iterations ({max_loop_iterations})")

            with trace.subspan("DKU_AGENT_ITERATION") as iteration_trace:
                logger.info(f"Starting agent iteration: {iteration_number}")

                with iteration_trace.subspan("DKU_AGENT_LLM_CALL") as llm_trace:
                    yield {"chunk": {"type": "event", "eventKind": "AGENT_THINKING", "eventData": {}}}

                    # Loading the tools here every iteration, in preparation for conditional blocks (when the tools can change each iteration)
                    current_tools = self.agent.load_tools(self.initial_messages, self.context, tools_cache)

                    completion = llm.new_completion()

                    # Ignore query settings, use config settings instead.
                    # Work on a private copy: "tools" is written into these settings just below, and the agent
                    # config is shared by every turn served by the agent, so it must not be modified.
                    completion._settings = copy.deepcopy(self.agent.config.get("completionSettings") or {})

                    completion.with_context(self.context)

                    completion.settings["tools"] = [llm_tool.llm_descriptor for llm_tool in current_tools.values()]

                    # Build the messages for the LLM call
                    completion.with_message(system_prompt, "system")

                    # Add history
                    completion.cq["messages"].extend(self.initial_messages)

                    # Add messages already generated during the turn
                    completion.cq["messages"].extend(self.generated_messages)

                    accumulated_tool_call_chunks = []
                    for ichunk in self._run_completion(completion, llm_trace):

                        if ichunk.text is not None:
                            yield {"chunk": {"text": ichunk.text}}

                        if ichunk.artifacts:
                            yield {"chunk": {"artifacts": ichunk.artifacts}}

                        if ichunk.tool_call_chunks:
                            accumulated_tool_call_chunks.extend(ichunk.tool_call_chunks)

                # Reassemble tool calls
                tool_calls = _tool_calls_from_chunks(accumulated_tool_call_chunks)

                if not tool_calls:
                    logger.info("No tools to call, agentic loop is done")

                    # Emit sources
                    yield {
                        "footer": {
                            "additionalInformation": {
                                "sources": all_sources
                            }
                        }
                    }

                    break

                logger.info(f"I have {len(tool_calls)} tool(s) to call")

                with iteration_trace.subspan("DKU_AGENT_TOOL_CALLS") as tools_trace:

                    # Generate the assistant message with tool calls that we'll add in the history for the next run of the loop
                    tool_calls_msg = {
                        "role": "assistant",
                        "toolCalls": [
                            {
                                "id": tool_call["id"],
                                "index": i,
                                "type": "function",
                                "function": {
                                    "name": tool_call["name"],
                                    "arguments": json.dumps(tool_call["args"])
                                }
                            }
                            for i, tool_call in enumerate(tool_calls)
                        ]
                    }

                    self.generated_messages.append(tool_calls_msg)

                    # Actually call the tools and generate the tool message that we'll add in the history for the next run of the loop
                    tool_outputs_msg = {
                        "role": "tool",
                        "toolOutputs": []
                    }

                    # Call the tools in parallel
                    num_parallel_threads = int(self.dku_properties.get("dku.agents.maxParallelToolExecutions", 8))
                    if num_parallel_threads < 1:
                        raise Exception(f"Property dku.agents.maxParallelToolExecutions must be > 0 (received: {num_parallel_threads})")

                    # Runs in a worker thread; returns the tool call along with its results, so that
                    # results can be matched back to the call they belong to
                    def run_tool(tool_call):
                        (output_dict, sources, artifacts, tool_trace) = self._call_one_tool(current_tools, tool_call, self.context)
                        return tool_call, output_dict, sources, artifacts, tool_trace

                    futures = []
                    # Leaving the "with" block waits for all submitted tool calls to complete
                    with ThreadPoolExecutor(num_parallel_threads) as executor:
                        for tool_call in tool_calls:
                            yield {"chunk": {"type": "event", "eventKind": "AGENT_TOOL_START", "eventData": {"toolName": tool_call["name"]}}}
                            futures.append(executor.submit(run_tool, tool_call))

                    # Collect results in submission order, i.e. in the order of the tool calls
                    for future in futures:
                        tool_call, output_dict, sources, artifacts, tool_trace = future.result()

                        # Add tool output as a message in the messages chain
                        tool_outputs_msg["toolOutputs"].append({
                            "callId": tool_call["id"],
                            "output": json.dumps(output_dict)
                        })

                        # Store sources, to emit in the streaming footer
                        all_sources.extend(sources or [])

                        # Stream artifacts inline (nothing to stream when the tool produced none)
                        if artifacts:
                            yield {"chunk": {"type": "content", "artifacts": artifacts}}

                        # Append the trace from the tool to the agent's trace
                        if tool_trace:
                            tools_trace.append_trace(tool_trace)

                    self.generated_messages.append(tool_outputs_msg)

    def _run_completion(self, completion, trace: SpanBuilder):
        """
        Execute a completion in streaming mode and yield one InternalCompletionChunk per streamed item.

        For every streamed item, the "text", "toolCalls" and "artifacts" entries of its data are read.
        Text defaults to an empty string and artifacts to an empty list when absent; tool calls are parsed
        into tool call chunks (empty list when there are none). When the item is the streaming footer,
        its trace is appended to the given trace.

        :param completion: the completion query to execute
        :param trace: span receiving the trace of the LLM call
        """
        logger.info("Calling LLM")

        logger.debug(f"About to run completion: {completion.cq}")
        logger.debug(f"With settings: {completion.settings}")

        for streamed_item in completion.execute_streamed():
            payload = streamed_item.data

            raw_tool_calls = payload.get("toolCalls")
            parsed_tool_call_chunks = _parse_tool_call_chunks(raw_tool_calls) if raw_tool_calls else []

            if isinstance(streamed_item, DSSLLMStreamedCompletionFooter):
                trace.append_trace(streamed_item.trace)

            # TODO : If we have sources from the LLM, then return these in the footer (can have these when the LLM Mesh query tool returns them (sc-270692))

            yield InternalCompletionChunk(
                payload.get("text", ""),
                parsed_tool_call_chunks,
                payload.get("artifacts", []),
            )

        logger.info("LLM call done")

    def _call_one_tool(self, current_tools: Dict[str, LLMTool], call, context):
        """
        Run a single tool call requested by the LLM. Never raises: failures are reported in the output.

        :param current_tools: the LLM tools currently available, keyed by LLM tool name
        :param call: the tool call, of the form returned by _tool_calls_from_chunks ("name", "args", ...)
        :param context: context passed to the tool
        :return: tuple (output_dict, sources, artifacts, trace).
                 On success: the tool's "output", its "sources" and "artifacts" (empty lists when absent or
                 None), and its "trace" (None when absent).
                 On failure (unknown tool, error reported by the tool, or any exception):
                 output_dict is {"error": <message>}, sources and artifacts are empty, and trace is a
                 locally built "DKU_MANAGED_TOOL_CALL" span describing the failure.
        """
        output_dict = {}
        sources = []
        artifacts = []
        trace = SpanBuilder("DKU_MANAGED_TOOL_CALL")  # Create a fake trace, in case tool errors
        trace.begin(int(time.time() * 1000))

        llm_tool = None
        try:
            logger.info(f"Calling tool: {call}")

            llm_tool = current_tools.get(call["name"])
            if llm_tool is None:
                # The LLM asked for a tool that we did not offer: report it in plain words, since the
                # message is sent back to the LLM as the tool output
                raise Exception(f"Unknown tool: {call['name']}")

            tool_output = llm_tool.dku_tool.run(call["args"], context, subtool_name=llm_tool.dku_subtool_name)

            if tool_output.get("error"):
                raise Exception(tool_output["error"])

            logger.info(f"Finished tool call: {call}")

            output_dict = tool_output["output"]
            # "or []" also covers sources/artifacts explicitly set to None by the tool
            sources = tool_output.get("sources") or []
            artifacts = tool_output.get("artifacts") or []
            trace = tool_output.get("trace", None)

        except Exception as e:
            logger.exception(f"Tool call failed: {call}")

            # TODO @new-agent-loop Add a flag to the tool, to say whether errors are fatal or not.
            #                      Add UI to set the flag when adding the tool to the agent.
            #                      If the flag is set, then throw here (or return the error here, and throw later, to keep the trace).

            output_dict = {"error": str(e)}

            # Try to add more info to the trace
            try:
                trace.inputs["input"] = call.get("args", {})
                trace.outputs["error"] = str(e)
                if llm_tool is not None:
                    trace.attributes["toolId"] = llm_tool.dku_tool.tool_id
                    trace.attributes["toolProjectKey"] = llm_tool.dku_tool.project_key
                    trace.attributes["toolType"] = llm_tool.dku_tool.get_settings()._settings["type"]
            except Exception:
                # Best effort only. Deliberately not bound to a name: binding it to "e" would shadow the
                # tool error above and unbind "e" once this handler ends.
                logger.exception("Error getting more info about tool call failure")

            trace.end(int(time.time() * 1000))

        return output_dict, sources, artifacts, trace
