import asyncio
import logging
import os
import re
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import dataiku
import dku_graphrag.query.query_monkey_patch_dataiku
import dku_graphrag.utils.config_instance as config_instance
import pandas as pd
from dataiku.llm.agent_tools import BaseAgentTool
from dku_graphrag.utils.graphrag_config import get_graphrag_config
from graphrag.api.query import global_search, local_search


def ensure_event_loop():
    """
    Returns an open event loop for the current thread, creating one if needed.

    A new loop is created and installed as the current loop when
    ``asyncio.get_event_loop()`` raises ``RuntimeError`` (no current loop) or
    when the current loop has already been closed, since a closed loop cannot
    be used with ``run_until_complete``.

    :return: An event loop that is not closed.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None

    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop


def find_recipe(graph_data: Dict[str, Any], successor_id: str) -> Dict[str, Any]:
    """
    Finds the flow-graph node that lists a given id among its successors.

    :param graph_data: Flow graph data; nodes are read from its "nodes" mapping.
    :param successor_id: The id to look for in each node's "successors" list.
    :return: The first node (in mapping order) whose successors contain ``successor_id``.
    :raises KeyError: If no node has ``successor_id`` as a successor.
    """
    for node in graph_data.get("nodes", {}).values():
        # "successors" may be missing or null on a node; treat both as empty
        if successor_id in (node.get("successors") or ()):
            return node

    raise KeyError(successor_id)


class GraphragSearch(BaseAgentTool):
    """
    Agent tool that answers queries against a GraphRAG index stored in a managed folder.

    ``set_config`` looks up the flow node that lists the index folder among its
    successors, reuses that recipe's ``customConfig`` to build the GraphRAG
    configuration, and loads the index parquet tables. ``invoke`` then runs a
    local or global GraphRAG search and returns the answer with its sources.
    """

    def set_config(self, config, plugin_config):
        """
        Configures the tool and loads the GraphRAG index tables into memory.

        :param config: Tool settings. Keys read here: ``verbose_mode``,
            ``index_folder_id``, ``search_type`` (default ``"local"``),
            ``default_community_level`` (default ``0``) and ``response_type``
            (default ``"multiple paragraphs"``). ``content_description`` is
            read later by ``get_descriptor``.
        :param plugin_config: Plugin-level settings; not used by this tool.
        :raises KeyError: If no flow node has the index folder as a successor.
        """
        self.logger = logging.getLogger(__name__)
        self.config = config
        # force=True replaces any handlers already installed on the root logger
        logging.basicConfig(
            level=logging.DEBUG if config.get("verbose_mode", False) else logging.INFO,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            force=True,
        )

        # Retrieve the customConfig of the recipe whose output is the index folder
        project = dataiku.api_client().get_default_project()
        index_folder_id = config.get("index_folder_id")
        flow = project.get_flow()
        graph = flow.get_graph()
        recipe_json = find_recipe(graph.data, index_folder_id)
        recipe_name = recipe_json.get("ref")
        recipe = project.get_recipe(recipe_name)
        recipe_settings = recipe.get_settings()
        recipe_definition = recipe_settings.get_recipe_raw_definition()
        custom_config = recipe_definition.get("params").get("customConfig")

        self.search_type = config.get("search_type", "local")
        self.default_community_level = config.get("default_community_level", 0)
        self.response_type = config.get("response_type", "multiple paragraphs")

        output_folder = dataiku.Folder(index_folder_id)

        if dataiku.base.remoterun.is_running_remotely() or output_folder.get_info().get("type", None) != "Filesystem":
            # Download remote files into a local temporary directory.
            # NOTE(review): this directory is never removed by this class —
            # confirm whether cleanup is handled elsewhere.
            self.folder_path = Path(tempfile.mkdtemp())
            for file_name in output_folder.list_paths_in_partition():
                # Files under /input and /prompts are not copied locally
                if file_name.startswith(("/input", "/prompts")):
                    continue
                local_file_path = self.folder_path / file_name.lstrip("/")
                # exist_ok avoids the check-then-create race of exists() + makedirs()
                local_file_path.parent.mkdir(parents=True, exist_ok=True)
                with output_folder.get_download_stream(file_name) as f_remote, open(local_file_path, "wb") as f_local:
                    shutil.copyfileobj(f_remote, f_local)
        else:
            self.folder_path = Path(output_folder.get_path())

        self.graphrag_config = get_graphrag_config(custom_config, self.folder_path)
        config_instance.config_value = custom_config

        vector_store = self.graphrag_config.embeddings.vector_store
        if "db_uri" in vector_store:
            # Anchor the configured 'db_uri' under the local index folder
            vector_store["db_uri"] = str(self.folder_path / vector_store["db_uri"])
            self.logger.debug(f"****** db_uri: {vector_store['db_uri']}")
        else:
            self.logger.warning("Error: 'db_uri' key not found in vector_store")

        index_output_folder = self.folder_path / Path(self.graphrag_config.storage.base_dir)
        self.nodes = pd.read_parquet(index_output_folder / "create_final_nodes.parquet")
        self.entities = pd.read_parquet(index_output_folder / "create_final_entities.parquet")
        self.communities = pd.read_parquet(index_output_folder / "create_final_communities.parquet")
        self.community_reports = pd.read_parquet(index_output_folder / "create_final_community_reports.parquet")

        # These tables are only loaded for local search
        self.text_units = None
        self.relationships = None
        self.covariates_df = None
        if self.search_type == "local":
            self.text_units = pd.read_parquet(index_output_folder / "create_final_text_units.parquet")
            self.relationships = pd.read_parquet(index_output_folder / "create_final_relationships.parquet")
            covariates_path = index_output_folder / "create_final_covariates.parquet"
            # covariates are optional
            if covariates_path.exists():
                self.covariates_df = pd.read_parquet(covariates_path)

        self.logger.info(
            f"Tool config initialized with: search_type={self.search_type}, folder_path ={self.folder_path}, default_community_level={self.default_community_level},  response_type ={self.response_type}"
        )

    def get_descriptor(self, tool):
        """
        Returns the tool descriptor: a description plus the JSON schema of its input.

        :param tool: Not used by this implementation.
        :return: A dict with ``description`` and ``inputSchema``; the schema
            requires a single string property named ``query``.
        """
        return {
            "description": f"Searches a knowledge bank with the following description:\n{self.config.get('content_description', 'no description provided')}\nReturns a text result",
            "inputSchema": {
                "$id": "https://dataiku.com/agents/tools/search/input",
                "title": "Input for the search tool",
                "type": "object",
                "properties": {"query": {"type": "string", "description": "The query string"}},
                "required": ["query"],
            },
        }

    async def search(self, query: str):
        """
        Runs a GraphRAG search over the loaded index tables.

        Uses ``global_search`` when ``search_type`` is ``"global"`` and
        ``local_search`` for any other value.

        :param query: The query string.
        :return: The ``(response, context)`` pair returned by the GraphRAG search call.
        """
        self.logger.info(f"Tool search: search_type={self.search_type}, query ={query}")
        if self.search_type == "global":
            # NOTE(review): community_level is None here rather than
            # self.default_community_level — confirm this is intentional.
            response, context = await global_search(
                config=self.graphrag_config,
                nodes=self.nodes,
                entities=self.entities,
                communities=self.communities,
                community_reports=self.community_reports,
                community_level=None,
                dynamic_community_selection=False,
                response_type=self.response_type,
                query=query,
            )
        else:
            response, context = await local_search(
                config=self.graphrag_config,
                nodes=self.nodes,
                entities=self.entities,
                community_reports=self.community_reports,
                text_units=self.text_units,
                relationships=self.relationships,
                covariates=self.covariates_df,
                community_level=self.default_community_level,
                response_type=self.response_type,
                query=query,
            )
        return response, context

    def invoke(self, input, trace):
        """
        Runs the search for one tool call and formats the result.

        :param input: Tool call payload; the query is read from ``input["input"]["query"]``.
        :param trace: Not used by this implementation.
        :return: A dict with ``output`` (the search response) and ``sources``
            (one entry whose ``items`` hold the text of each context source).
        :raises KeyError: If the payload has no ``input`` or ``query`` key.
        """
        args = input["input"]
        query = args["query"]

        self.logger.info(f"search query: {query}")

        # search() is a coroutine; drive it to completion on this thread's loop
        loop = ensure_event_loop()
        response, context = loop.run_until_complete(self.search(query))

        self.logger.info(f" ==== context: {context}")

        # A context without a "sources" entry yields no source items
        sources = context.get("sources")
        if sources is None:
            sources = []

        source_items = [
            {
                "type": "SIMPLE_DOCUMENT",
                "url": "",
                "title": "",
                "textSnippet": item["text"],
            }
            for item in sources
        ]

        return {
            "output": response,
            "sources": [{"toolCallDescription": f"Performed  Search for: {query}", "items": source_items}],
        }
