import asyncio
import copy
import logging
import traceback
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import cast, Type

from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import watch_stdin, get_clazz_in_code
from dataiku.core import debugging
from dataiku.core.import_alias import register_python_import_alias
from dataiku.llm.python import BaseLLM
from dataiku.llm.python.types import StartAgentServerCommand
from dataiku.llm.python.processing.completion_processor import CompletionProcessor

logger = logging.getLogger(__name__)


def _redact_in_place(data, fields_to_redact):
    """Replace, in place, the value of every dict key listed in *fields_to_redact* with "**redacted**".

    Recurses through nested dicts and lists; any other value (including None) is left untouched.
    A matching key has its whole value replaced, even when that value is itself a nested structure.
    """
    if isinstance(data, dict):
        for key, value in data.items():
            if key in fields_to_redact:
                # Rebinding an existing key does not change the dict size, so it is safe while iterating
                data[key] = "**redacted**"
            else:
                _redact_in_place(value, fields_to_redact)
    elif isinstance(data, list):
        for item in data:
            _redact_in_place(item, fields_to_redact)


def get_redacted_start_command(command):
    """Return a version of a start command that is safe to log.

    The names of the secret parameters are read from ``command["pluginConfig"]["dkuPasswordParams"]``.
    Every matching key found anywhere in ``config`` and ``pluginConfig`` has its value replaced by
    "**redacted**", and the ``dkuPasswordParams`` list itself is dropped from the result.

    The input command is never mutated: redaction happens on a deep copy. When there is nothing
    to redact, the original command object is returned as-is.
    """
    # `or {}` also covers a pluginConfig key that is present but null
    plugin_config = command.get("pluginConfig") or {}
    password_params = plugin_config.get("dkuPasswordParams") if isinstance(plugin_config, dict) else None
    if not password_params:
        return command

    redacted_command = copy.deepcopy(command)
    # Redact config even if absent (no-op on None) and pluginConfig in every case, so secrets
    # held in pluginConfig never leak just because the command has no "config" entry
    _redact_in_place(redacted_command.get("config"), password_params)
    redacted_plugin_config = redacted_command["pluginConfig"]
    _redact_in_place(redacted_plugin_config, password_params)
    redacted_plugin_config.pop("dkuPasswordParams", None)
    return redacted_command


class PythonLLMServer:
    """Serve agent/LLM commands received over the Java link.

    A ``start-agent-server`` command must come first. It resolves the agent class and builds the
    :class:`CompletionProcessor`. After that, ``process-completion-query`` commands are delegated
    to the processor.
    """

    processor: CompletionProcessor  # built by start()
    started: bool  # becomes True at the end of a successful start()
    executor: ThreadPoolExecutor  # runs start() and is handed to the CompletionProcessor
    run_counter: int  # 1-based number of the current completion query, only used in log banners

    def __init__(self):
        self.started = False
        self.executor = ThreadPoolExecutor(32)
        self.run_counter = 1

    async def handler(self, command: StartAgentServerCommand):
        """Handle one command, yielding its response(s).

        - ``start-agent-server``: yields the result of :meth:`start`, which runs in the executor.
        - ``process-completion-query``: yields a single response. When ``command["stream"]`` is
          truthy, it instead yields one item per element of ``process_query_stream``.

        Raises:
            Exception: on an unknown command type, or on a completion query received before
                the server was started.
        """
        command_type = command["type"]

        if command_type == "start-agent-server":
            logger.info("PythonLLMServer handler received start command: %s", get_redacted_start_command(command))
            # start() is synchronous; running it in the executor keeps the event loop free meanwhile
            yield await asyncio.get_running_loop().run_in_executor(
                self.executor, self.start, command
            )

        elif command_type == "process-completion-query":
            logger.debug("PythonLLMServer handler received completion query command")
            if not self.started:
                # Without this check, self.processor would fail with an opaque AttributeError
                raise Exception("Received a completion query before the server was started")
            logger.info("\n===============  Start completion query - run %s ===============", self.run_counter)
            try:
                if command["stream"]:
                    async for resp in self.processor.process_query_stream(command):
                        yield resp
                else:
                    yield await self.processor.process_query(command)
                logger.info("\n=============== End completion query - run %s ===============", self.run_counter)
            finally:
                # Incremented even when the query fails, so each run number is used only once
                self.run_counter += 1
        else:
            raise Exception("Unknown command type: %s" % command_type)

    def start(self, start_command: StartAgentServerCommand):
        """Resolve the agent class and build the completion processor.

        The class comes from user ``code`` when that key is present and not None. Otherwise it
        comes from one of the known built-in ``pyClazz`` identifiers.

        Returns:
            ``{"ok": True}`` once started.

        Raises:
            RuntimeError: if the server was already started.
            Exception: if no BaseLLM implementation can be resolved.
        """
        # Explicit check rather than `assert`, which is stripped when Python runs with -O
        if self.started:
            raise RuntimeError("Already started")

        code = start_command.get("code")
        if code is not None:
            if should_patch_langchain_legacy_imports(code):
                register_python_import_alias("langchain.", "langchain_classic.")
            clazz = cast(Type[BaseLLM], get_clazz_in_code(code, BaseLLM, strict_module=True))
        else:
            # .get(): a missing pyClazz reaches the explicit error below instead of raising KeyError
            py_clazz = start_command.get("pyClazz")
            if py_clazz == "dataiku.llm.python.tools_using.ToolUsingAgent":
                from .tools_using import ToolsUsingAgent
                clazz = ToolsUsingAgent
            elif py_clazz == "dataiku.llm.python.tools_using_2.ToolUsingAgent":
                from .tools_using_2 import ToolsUsingAgent
                clazz = ToolsUsingAgent
            elif py_clazz == "dataiku.llm.python.blocks_graph.BlocksGraphAgent":
                from .blocks_graph.agent import BlocksGraphAgent
                clazz = BlocksGraphAgent
            else:
                raise Exception("Missing BaseLLM implementation (pyClazz=%r)" % (py_clazz,))

        # `or {}` also covers keys that are present but null
        self.processor = CompletionProcessor(
            clazz,
            self.executor,
            start_command.get("config") or {},
            start_command.get("pluginConfig") or {},
            trace_name="DKU_AGENT_CALL",
        )
        self.started = True
        return {"ok": True}

def log_exception(loop, context):
    """Exception handler for the asyncio event loop: log the error, its context and its traceback.

    Installed through ``loop.set_exception_handler``; the ``loop`` argument is not used.
    When ``context`` carries no exception object, one is built from ``context["message"]``.
    """
    error = context.get("exception")
    if error is None:
        error = Exception(context.get("message"))

    message_parts = [
        "Caught exception: %s" % (error,),
        "Context: %s" % (context,),
        "Stack trace: %s" % "".join(traceback.format_exception(type(error), error, error.__traceback__)),
    ]
    logger.error("\n".join(message_parts))


def should_patch_langchain_legacy_imports(code):
    """Tell whether legacy ``langchain.*`` imports should be aliased to ``langchain_classic.*``.

    This hacky aliasing is only enabled on the internal RAG code env, which is detected from the
    interpreter path. It also requires the installed langchain to be >= 1.0.

    Args:
        code: the user code about to be loaded. It is currently unused and kept for the call signature.

    Returns:
        True when the aliasing should be registered, False otherwise.
    """
    # sys.executable may be None or "" when Python cannot determine its own executable path
    if "INTERNAL_retrieval_augmented_generation" not in (sys.executable or ""):
        return False
    try:
        import langchain
    except ImportError:  # parent of ModuleNotFoundError; also covers an installed but unimportable package
        return False
    major = str(getattr(langchain, "__version__", "")).split(".", 1)[0]
    return major.isdecimal() and int(major) >= 1


if __name__ == "__main__":
    # Process id and thread name in every record help tell executor threads apart
    logging.basicConfig(
        level=logging.DEBUG,
        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s',
    )
    debugging.install_handler()
    watch_stdin()

    async def start_server():
        """Install the loop exception handler, connect the Java link and serve commands on it."""
        running_loop = asyncio.get_running_loop()
        running_loop.set_exception_handler(log_exception)

        port, secret, server_cert = parse_javalink_args()
        java_link = AsyncJavaLink(port, secret, server_cert=server_cert)
        llm_server = PythonLLMServer()
        await java_link.connect()
        await java_link.serve(llm_server.handler)

    asyncio.run(start_server(), debug=True)
