import asyncio
import logging
import os
import sys
import threading
from concurrent.futures import Future

from fastmcp import Client
from fastmcp.client.transports import StdioTransport

from dataiku.llm.agent_tools.mcp.config import MCPClientConfig


# Module logger; also used from FastMCPClient's background event-loop thread (ready check / watchdog)
logger = logging.getLogger(__name__)


class FastMCPClient:
    """ Synchronous facade over an async FastMCP client

    A dedicated daemon thread runs an asyncio event loop. The public methods submit coroutines to that loop
    and block until they complete. A watchdog coroutine on the same loop periodically re-opens the client
    session and rebuilds the client when that fails with a RuntimeError.
    """

    def __init__(self, config):
        """
        :param config: tool configuration dict; this class reads "command", "args", "env",
                       "subtoolsEnabledByDefault" and "subtoolsStateOverride" from it, and it is also
                       passed to MCPClientConfig
        :raises Exception: whatever prevented building the client or connecting to / pinging the server
        """
        self.tool_config = config
        self.client_config = MCPClientConfig(self.tool_config)
        self.loop = None
        self.thread = None
        self._start()

    def _start(self):
        """ Run a background thread with an asyncio event loop that runs MCP commands as async tasks

        Blocks until the client has successfully pinged the server once. If initialization fails, the
        background loop is stopped and closed (so the watchdog does not keep retrying) and the error is
        re-raised in the calling thread.
        """
        _ready_future = Future()

        def _shutdown_loop(loop):
            """ Cancel the tasks still pending on a stopped loop, let them unwind, then close the loop """
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()
            if pending:
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            loop.close()

        def _run_loop(loop):
            async def ready_check(future):
                """ Check that the client can connect to the server then become ready to accept commands
                """
                try:
                    async with self.raw_client:
                        await self.raw_client.ping()
                        future.set_result(True)
                except Exception as exc:
                    logger.exception("MCP client initialization failed", exc_info=exc)
                    # The error may come from closing the session after the result was already published:
                    # initialization did succeed then, so don't touch the future (it would raise) or the loop.
                    if not future.done():
                        future.set_exception(exc)
                        # The constructor is about to raise: nobody will ever use this loop, so stop it
                        # instead of letting the watchdog retry the server forever.
                        loop.stop()

            async def watchdog(interval: int):
                """ Routinely check that the server connection is still alive, restart one if needed
                """
                while True:
                    try:
                        await asyncio.sleep(interval)
                        async with self.raw_client:
                            pass
                    except RuntimeError as exc:
                        logger.exception("MCP client healthcheck failed, restarting it", exc_info=exc)
                        try:
                            self._rebuild_raw_client()
                        except Exception as rebuild_exc:
                            # An exception escaping this handler would end the watchdog for good
                            logger.exception("MCP client rebuild failed", exc_info=rebuild_exc)
                    except Exception as exc:
                        logger.exception("MCP client healthcheck failed", exc_info=exc)

            try:
                self._rebuild_raw_client()
            except Exception as exc:
                logger.exception("MCP client initialization failed", exc_info=exc)
                # The loop never ran; release its resources before reporting the failure
                loop.close()
                _ready_future.set_exception(exc)
                return

            asyncio.set_event_loop(loop)
            asyncio.run_coroutine_threadsafe(ready_check(_ready_future), loop)
            asyncio.run_coroutine_threadsafe(watchdog(self.client_config.watchdog_interval), loop)
            try:
                # Only returns if the loop is stopped (initialization failure)
                loop.run_forever()
            finally:
                _shutdown_loop(loop)

        self.loop = asyncio.new_event_loop()
        # daemon=True: this thread must not keep the interpreter alive on exit
        self.thread = threading.Thread(
            name="fastmcp-client-asyncio-loop",
            target=_run_loop,
            args=(self.loop,),
            daemon=True,
        )
        logger.info("MCP Client starting loop")
        self.thread.start()
        logger.info("MCP Client loop started")
        _ready_future.result()
        logger.info("MCP Client initialization done")

    def _rebuild_raw_client(self):
        """ Build or rebuild the raw FastMCP transport and client

        A "python" command is replaced by the current interpreter (sys.executable). In that case the working
        directory defaults to $DKU_CODE_ENV_RESOURCES_PATH when it names an existing directory. A cwd set in
        the client config always takes precedence over that default.
        """
        if self.tool_config["command"] != "python":
            command = self.tool_config["command"]
            default_cwd = None
        else:
            command = sys.executable
            # default to using the code env resources directory as current working dir if it exists, this lets users run the server using a relative path
            resources_base_dir = os.getenv("DKU_CODE_ENV_RESOURCES_PATH")
            # The variable may be unset: os.path.isdir(None) raises TypeError rather than returning False
            if resources_base_dir and os.path.isdir(resources_base_dir):
                default_cwd = resources_base_dir
            else:
                default_cwd = None

        transport = StdioTransport(
            command=command,
            args=self.tool_config["args"],
            env=self.tool_config["env"],
            cwd=self.client_config.cwd or default_cwd,
        )

        logger.info("Creating MCP Client with transport: %s", transport)
        kwargs = {
            "init_timeout": self.client_config.init_timeout,
            "roots": self.client_config.roots,
        }
        # Only pass a request timeout when one is configured (falsy means "use the client default")
        if timeout := self.client_config.request_timeout:
            kwargs["timeout"] = timeout
        self.raw_client = Client(
            transport,
            **kwargs,
        )

    def list_tools(self):
        """ Return the tools exposed by the MCP server (blocks until the background loop answers) """
        async def _list_tools(client):
            async with client:
                return await client.list_tools()

        future = asyncio.run_coroutine_threadsafe(_list_tools(self.raw_client), self.loop)
        return future.result()

    def call_tool(self, name, arguments):
        """ Call a subtool on the MCP server and return its result (blocks until the call completes)

        :param name: subtool name on the MCP server
        :param arguments: arguments forwarded as-is to the subtool
        :raises ValueError: if the subtool is disabled for this tool configuration
        """
        # An explicit per-subtool override wins over the default in both directions;
        # a missing (or None) override falls back to the default.
        override = self.tool_config["subtoolsStateOverride"].get(name)
        enabled = self.tool_config["subtoolsEnabledByDefault"] if override is None else override
        if not enabled:
            raise ValueError(f"subtool '{name}' is invalid or disabled")

        async def _call_tool(client, name, arguments):
            async with client:
                return await client.call_tool(name, arguments)

        future = asyncio.run_coroutine_threadsafe(_call_tool(self.raw_client, name, arguments), self.loop)
        return future.result()
