from langchain_core.messages import HumanMessage, ToolMessage

import dataiku
from dataiku.llm.python import BaseLLM


class MyLLM(BaseLLM):
    """Custom LLM that answers a prompt with help from project agent tools.

    The last message of the query is sent to a chat model bound to the agent
    tools named in ``TOOL_NAMES``. If the model requests tool calls, each one
    is executed and the results are sent back to the model for a final answer.
    """

    # Full LLM id of the chat model; replace the placeholder connection name
    # with a real OpenAI connection of your instance.
    LLM_ID = "openai:YOUR_OPENAI_CONNECTION_NAME:gpt-4o-mini"

    # Names (not ids) of the project agent tools exposed to the model.
    TOOL_NAMES = ("Get Customer Info", "Get Company Info")

    def process(self, query, settings, trace):
        """Answer the last message of *query*, calling agent tools as needed.

        :param query: request payload; only ``query["messages"][-1]["content"]``
            is used as the prompt.
        :param settings: unused.
        :param trace: unused.
        :returns: ``{"text": <model answer>}``.
        :raises ValueError: if a name in ``TOOL_NAMES`` matches no agent tool
            of the default project.
        """
        prompt = query["messages"][-1]["content"]

        project = dataiku.api_client().get_default_project()
        tools = self._load_tools(project)
        tools_by_name = {tool.name: tool for tool in tools}

        llm = project.get_llm(self.LLM_ID).as_langchain_chat_model()
        llm_with_tools = llm.bind_tools(tools)

        messages = [HumanMessage(prompt)]
        ai_msg = llm_with_tools.invoke(messages)

        # The model answered directly: skip the second round trip, which
        # would otherwise ask the model to respond to its own answer.
        if not ai_msg.tool_calls:
            return {"text": ai_msg.content}

        messages.append(ai_msg)
        # Answer every tool call with a ToolMessage carrying its id, in order
        # (OpenAI's chat API rejects assistant tool calls left unanswered).
        messages.extend(
            self._run_tool_call(tools_by_name, tool_call)
            for tool_call in ai_msg.tool_calls
        )

        final_resp = llm_with_tools.invoke(messages)
        return {"text": final_resp.content}

    def _load_tools(self, project):
        """Return LangChain structured tools for ``TOOL_NAMES``, in that order.

        The project's agent tool list is fetched once and indexed by name.

        :raises ValueError: if a name has no matching agent tool.
        """
        ids_by_name = {}
        for info in project.list_agent_tools():
            # Keep the first tool listed under a given name.
            ids_by_name.setdefault(info.get("name"), info.get("id"))

        # If you know the tool IDs, you can pass them to get_agent_tool directly.
        tools = []
        for name in self.TOOL_NAMES:
            tool_id = ids_by_name.get(name)
            if tool_id is None:
                raise ValueError(f"Agent tool {name!r} not found in the default project")
            tools.append(project.get_agent_tool(tool_id).as_langchain_structured_tool())
        return tools

    @staticmethod
    def _run_tool_call(tools_by_name, tool_call):
        """Execute one model-requested tool call and wrap its output.

        An unknown tool name yields an error ToolMessage rather than running
        some other tool, so the model can see the mistake and recover.
        """
        tool = tools_by_name.get(tool_call["name"])
        if tool is None:
            content = f"Error: unknown tool {tool_call['name']!r}."
        else:
            content = tool.invoke(tool_call["args"])
        return ToolMessage(tool_call_id=tool_call["id"], content=content)
