import json
import logging
from typing import Any

import dataiku
from dataiku.llm.agent_tools import BaseAgentTool

from solutions.graph.explorer_builder import ExplorerMetadataManager
from solutions.graph.queries.explorer_data import run_llm_query
from solutions.graph.queries.params import RunLlmCypherParams


class GraphSearch(BaseAgentTool):
    """Agent tool that answers free-text queries against a single graph snapshot.

    The graph is located in a Dataiku managed folder, which must contain
    exactly one graph snapshot. Each invocation delegates to
    ``run_llm_query`` with the configured LLM. Judging by the
    ``RunLlmCypherParams`` type and the ``cypher_query`` key read from its
    result, that helper presumably has the LLM produce a Cypher query and
    executes it (NOTE(review): confirm against ``run_llm_query``).
    """

    def set_config(self, config, plugin_config):
        """Initialise the tool from its configuration.

        Args:
            config: Tool settings. Read here: ``verbose_mode`` (bool, default
                False), ``db_folder_id`` and ``llm_id``. ``content_description``
                and ``db_query_timeout_seconds`` are read by other methods.
            plugin_config: Plugin-level settings (not used).

        Raises:
            ValueError: If ``db_folder_id`` is missing/empty, or the folder does
                not contain exactly one graph snapshot.
        """
        self.logger = logging.getLogger(__name__)
        self.config = config
        # NOTE(review): force=True replaces the handlers of the *root* logger for
        # the whole process each time a tool instance is configured. Kept for
        # backward compatibility.
        logging.basicConfig(
            level=logging.DEBUG if config.get("verbose_mode", False) else logging.INFO,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            force=True,
        )

        db_folder_id = config.get("db_folder_id")
        if not db_folder_id:
            raise ValueError("The 'db_folder_id' setting is required.")
        self.db_folder = dataiku.Folder(db_folder_id)

        # The folder must hold exactly one graph snapshot; its id selects the
        # graph to query. Use an explicit raise rather than `assert`, which is
        # stripped when Python runs with -O.
        graph_ids = ExplorerMetadataManager._get_snapshot_id_list(self.db_folder)
        if len(graph_ids) != 1:
            raise ValueError(
                f"There should be only one graph in the folder (found {len(graph_ids)})."
            )
        # next(iter(...)) works for any iterable container and leaves it unmodified.
        self.graph_id = next(iter(graph_ids))

        # Resolve the configured LLM in the default project of the current context.
        client = dataiku.api_client()
        project = client.get_default_project()
        self.llm = project.get_llm(config.get("llm_id"))

    def get_descriptor(self, tool) -> dict[str, Any]:  # type: ignore
        """Return the tool descriptor: a description and the JSON Schema of its input.

        The input is an object with a single required string property ``query``.
        """
        return {
            "description": f"Searches a graph database with the following description. \n {self.config.get('content_description', 'no description provided')} \n Returns a JSON object",
            "inputSchema": {
                "$id": "https://dataiku.com/agents/tools/search/input",
                "title": "Input for graph search",
                "type": "object",
                "properties": {"query": {"type": "string", "description": "The query string"}},
                "required": ["query"],
            },
        }

    def invoke(self, input, trace) -> dict[str, Any]:  # type: ignore
        """Run a graph search for the query in ``input["input"]["query"]``.

        Args:
            input: Invocation payload; the query string is read from
                ``input["input"]["query"]``.
            trace: Tracing handle supplied by the framework (not used here).

        Returns:
            A dict with ``output`` (the query result serialised as JSON) and
            ``sources`` (a single entry describing the generated Cypher query).
        """
        query = input["input"]["query"]

        self.logger.info("Graph search query: %s", query)

        result = run_llm_query(
            params=RunLlmCypherParams(graph_id=self.graph_id, query=query),
            db_folders=[self.db_folder],
            llm=self.llm,
            timeout_seconds=self.config.get("db_query_timeout_seconds", 60),
        )

        return {
            "output": json.dumps(result),
            "sources": [{"toolCallDescription": f"Graph search: {result['cypher_query']}", "items": []}],
        }
