import logging
import asyncio
import traceback
import copy
from concurrent.futures import ThreadPoolExecutor

from dataiku.base.async_link import AsyncJavaLink
from dataiku.llm.python.server import get_redacted_start_command
from dataiku.llm.tracing import new_trace
from dataiku.core import debugging

from dataiku.base.utils import watch_stdin, get_clazz_in_code
from dataiku.base.socket_block_link import parse_javalink_args
from .base import BaseAgentTool


# Configure root logging for this server process. force=True removes and closes any
# handlers already attached to the root logger, so this format always wins.
logging.basicConfig(
    force=True,
    level=logging.DEBUG,
    format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s',
)
logger = logging.getLogger(__name__)


def get_redacted_command(command):
    """Return a view of ``command`` that is safe to write to the logs.

    - If ``command["input"]`` is a dict with a ``"context"`` entry, every
      context value is replaced by ``"REDACTED"``. A non-dict context is
      replaced wholesale by ``"REDACTED"``.
    - Otherwise, if the command has a ``"config"`` entry, the result of
      ``get_redacted_start_command`` is further redacted. Every key listed in
      ``config["dkuPasswordParams"]`` is masked in the ``env`` and
      ``dkuProperties`` sections, and the ``dkuPasswordParams`` list itself is
      dropped.
    - Any other command is returned unchanged (the same object).

    The containers of ``command`` are never modified. Every dict that is
    altered is a fresh copy, so the live command stays intact for the code
    that actually executes it.
    """
    command_input = command.get("input")
    # isinstance guard: a None or string "input" must not make logging crash.
    if isinstance(command_input, dict) and "context" in command_input:
        # The context has already been printed nicely and redacted by the Java part.
        # Here, just redact everything. Only the containers on the path to
        # "context" are copied; deep-copying the whole input just to log it is wasteful.
        context = command_input["context"]
        redacted_input = dict(command_input)
        if isinstance(context, dict):
            redacted_input["context"] = {key: "REDACTED" for key in context}
        else:
            redacted_input["context"] = "REDACTED"
        redacted_command = dict(command)
        redacted_command["input"] = redacted_input
        return redacted_command

    if "config" in command:
        redacted_command = get_redacted_start_command(command)

        # local MCP tools
        password_params = (command.get("config") or {}).get("dkuPasswordParams")
        if password_params:
            # Copy before masking. We cannot tell from here whether
            # get_redacted_start_command shares containers with `command`.
            redacted_command = dict(redacted_command)
            config = dict(redacted_command.get("config") or {})
            for section in ("env", "dkuProperties"):
                values = config.get(section)
                # Sections may be absent; only mask the ones that exist.
                if isinstance(values, dict):
                    config[section] = {
                        key: "**redacted**" if key in password_params else value
                        for key, value in values.items()
                    }
            config.pop("dkuPasswordParams", None)
            redacted_command["config"] = config
        return redacted_command

    return command

class PythonAgentToolServer:
    """Execute agent-tool commands handed to ``handler`` as dicts keyed by ``"type"``.

    Supported types are ``start``, ``invoke``, ``get-descriptor``,
    ``describe-tool-call`` and ``load-sample-query``. Each command's work runs
    on a thread pool so that blocking tool code does not stall the asyncio
    event loop.
    """

    def __init__(self):
        self.started = False
        # The tool implementation; set by start().
        self.instance = None
        self.executor = ThreadPoolExecutor(32)
        # Sequence number used only to label the "Start/End running tool" log lines.
        self.run_counter = 1

    async def handler(self, command):
        """Run one command on the executor and yield its single result.

        Raises:
            Exception: if ``command["type"]`` is not a known command type.
        """
        logger.info("PythonAgentToolServer handler received command: %s", get_redacted_command(command))

        command_type = command["type"]
        loop = asyncio.get_running_loop()

        if command_type == "invoke":
            # Reserve this run's number before the first await. Several
            # invocations may be in flight at once (32-worker executor). If the
            # counter were re-read after the await, this run could log another
            # run's number on its End line.
            run_number = self.run_counter
            self.run_counter += 1
            logger.info("\n===============  Start running tool - run %s ===============", run_number)
            result = await loop.run_in_executor(self.executor, self.invoke, command)
            # Log before yielding so the End line does not depend on the
            # consumer resuming this generator after receiving the result.
            logger.info("\n=============== End running tool - run %s ===============", run_number)
            yield result
            return

        simple_commands = {
            "start": self.start,
            "get-descriptor": self.get_descriptor,
            "describe-tool-call": self.describe_tool_call,
            "load-sample-query": self.load_sample_query,
        }
        method = simple_commands.get(command_type)
        if method is None:
            raise Exception("Unknown command type: %s" % command_type)
        yield await loop.run_in_executor(self.executor, method, command)

    def start(self, start_command):
        """Instantiate and configure the tool. May only succeed once.

        The class comes either from user ``code`` (located by
        ``get_clazz_in_code`` as a ``BaseAgentTool`` subclass) or from a known
        ``pyClazz``. If the instance has ``set_config``, it is passed the
        command's ``config`` and ``pluginConfig`` (``{}`` when absent).

        Returns:
            ``{"ok": True}``.
        Raises:
            NotImplementedError: for an unknown ``pyClazz``.
            Exception: if the command has neither ``code`` nor ``pyClazz``.
        """
        assert not self.started, "Already started"

        code = start_command.get("code")
        py_clazz = start_command.get("pyClazz")
        if code is not None:
            clazz = get_clazz_in_code(code, BaseAgentTool, strict_module=True)
            self.instance = clazz()
        elif py_clazz is not None:
            if py_clazz == "dataiku.llm.agent_tools.mcp.generic_stdio.GenericStdioMCPClient":
                from .mcp.generic_stdio import GenericStdioMCPClient
                self.instance = GenericStdioMCPClient()
            else:
                raise NotImplementedError(f"Unknown pyClazz for agent tools server: {py_clazz}")
        else:
            raise Exception("not implemented")

        if hasattr(self.instance, "set_config"):
            self.instance.set_config(start_command.get("config", {}), start_command.get("pluginConfig", {}))

        self.started = True
        return { "ok" : True }

    def invoke(self, process_command):
        """Run the tool on ``process_command["input"]`` inside a new trace.

        Exceptions raised by the tool are logged and turned into an
        ``{"error": "<Type>: <message>"}`` result instead of propagating. The
        serialized trace is attached to the result under ``"trace"``.
        """
        assert self.started, "Not started"
        tool_input = process_command.get("input")

        trace = new_trace("PYTHON_AGENT_TOOL_CALL")
        # Tool exceptions are caught inside the block, so __exit__ normally
        # sees no exception. `with` also guarantees the trace is closed if
        # something escapes the except clause.
        with trace:
            trace.attributes["class"] = type(self.instance).__name__
            try:
                logger.info("Invoking tool")
                result = self.instance.invoke(tool_input, trace)
                logger.info("Got Tool result")
            except Exception as e:
                logger.exception("Tool failed")
                result = {"error": type(e).__name__ + ": " + str(e)}

        result["trace"] = trace.to_dict()
        return result

    def get_descriptor(self, process_command):
        """Return the tool instance's descriptor for ``process_command["tool"]``."""
        assert self.started, "Not started"
        tool = process_command.get("tool")

        ret = self.instance.get_descriptor(tool)
        logger.info("Got Tool descriptor: %s", ret)
        return ret

    def describe_tool_call(self, process_command):
        """Return ``{"description": ...}`` describing a call with the given tool, descriptor and input."""
        assert self.started, "Not started"
        tool = process_command.get("tool")
        descriptor = process_command.get("descriptor")
        tool_input = process_command.get("input")

        tool_call_description = self.instance.describe_tool_call(tool, descriptor, tool_input)
        logger.info("Got tool call description: %s", tool_call_description)
        return {"description": tool_call_description}

    def load_sample_query(self, process_command):
        """Return the tool instance's sample query for ``process_command["tool"]``."""
        assert self.started, "Not started"
        tool = process_command.get("tool")

        ret = self.instance.load_sample_query(tool)
        # Use the module logger like the other methods (was the root `logging` module).
        logger.info("Got sample query: %s", ret)
        return ret

def log_exception(loop, context):
    """Event-loop exception handler: log the error, its context dict and its stack trace.

    Args:
        loop: the event loop reporting the error (not used).
        context: the asyncio error context. Its ``"exception"`` entry is used
            when present; otherwise an ``Exception`` is built from ``"message"``.
    """
    error = context.get("exception")
    if error is None:
        # Not every loop error carries an exception object; wrap the message
        # so it can be formatted like one.
        error = Exception(context.get("message"))
    stack_trace = "".join(traceback.format_exception(type(error), error, error.__traceback__))
    message = "Caught exception: {}\nContext: {}\nStack trace: {}".format(error, context, stack_trace)
    logger.error(message)


if __name__ == "__main__":
    debugging.install_handler()
    watch_stdin()

    async def start_server():
        """Connect the Java link and serve commands through PythonAgentToolServer.handler."""
        # Inside a coroutine this is the same loop that get_event_loop() would return.
        running_loop = asyncio.get_running_loop()
        running_loop.set_exception_handler(log_exception)

        port, secret, server_cert = parse_javalink_args()
        java_link = AsyncJavaLink(port, secret, server_cert=server_cert)
        tool_server = PythonAgentToolServer()
        await java_link.connect()
        await java_link.serve(tool_server.handler)

    asyncio.run(start_server(), debug=True)