import logging
import asyncio
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor

from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import watch_stdin, get_clazz_in_code
from dataiku.core import debugging
from dataiku.llm.tracing import new_trace
from .base import BaseGuardrail


class PythonGuardrailServer:
    """Run a user-provided guardrail in response to commands from the Java link.

    Two command types are understood by ``handler``: ``"start"`` builds the
    guardrail instance from inline code (once), and ``"process"`` runs it on a
    single input. Both are executed on a thread pool so that blocking guardrail
    code does not stall the asyncio event loop.
    """

    def __init__(self):
        self.started = False
        # Guardrail instance created by ``start``; None until then.
        self.instance = None
        # Worker threads for the blocking ``start`` / ``process`` calls.
        self.executor = ThreadPoolExecutor(32)

    async def handler(self, command):
        """Run one command on the thread pool and yield its result.

        :param command: dict whose ``"type"`` is ``"start"`` or ``"process"``;
            the whole dict is forwarded to the matching method.
        :raises Exception: if the command type is not recognised.
        """
        # Only the type is logged: the full command may carry guardrail code,
        # user prompts/responses and the guardrail bypass token.
        logging.info("PythonGuardrailServer handler received command of type: %s", command.get("type"))
        if command["type"] == "start":
            target = self.start
        elif command["type"] == "process":
            target = self.process
        else:
            raise Exception("Unknown command type: %s" % command["type"])
        loop = asyncio.get_running_loop()
        yield await loop.run_in_executor(self.executor, target, command)

    def start(self, start_command):
        """Instantiate and configure the guardrail class found in the command's code.

        :param start_command: dict; ``"code"`` is the Python source searched for a
            ``BaseGuardrail`` subclass. ``"config"`` and ``"pluginConfig"``
            (default ``{}``) are handed to the instance's ``set_config`` when the
            instance defines one.
        :raises AssertionError: if the server was already started.
        :raises Exception: if no code is provided (other sources are not implemented).
        """
        assert not self.started, "Already started"

        code = start_command.get("code")
        if code is None:
            raise Exception("not implemented")
        clazz = get_clazz_in_code(code, BaseGuardrail, strict_module=True)
        self.instance = clazz()

        if hasattr(self.instance, "set_config"):
            self.instance.set_config(start_command.get("config", {}), start_command.get("pluginConfig", {}))

        self.started = True

    def process(self, process_command):
        """Run the guardrail on one input and return its result with a trace.

        :param process_command: dict whose ``"input"`` entry is passed to the
            guardrail's ``process``; ``input["bypassToken"]`` must be present.
        :returns: the guardrail's result dict, completed with default PASS
            verdicts (see ``_fill_default_pass``) and a ``"trace"`` entry.
        :raises AssertionError: if ``start`` has not run yet.
        """
        assert self.started, "Not started"
        input_ = process_command.get("input")

        trace = new_trace("PYTHON_GUARDRAIL_CALL")
        trace.__enter__()
        try:
            logging.info("Processing on guardrail")

            import dataikuapi.dss.llm
            bypass_ls = dataikuapi.dss.llm._dku_bypass_guardrail_ls

            # Expose the bypass token to dataikuapi while the guardrail runs; it
            # is removed again even if the guardrail fails.
            bypass_ls.current_bypass_token = input_["bypassToken"]
            try:
                result = self.instance.process(input_, trace)
            finally:
                del bypass_ls.current_bypass_token

            logging.info("Guardrail done processing")
        except BaseException as exc:
            # Previously the trace was left open on failure; close it with the
            # error details (standard context-manager protocol), then propagate.
            trace.__exit__(type(exc), exc, exc.__traceback__)
            raise
        trace.__exit__(None, None, None)

        self._fill_default_pass(input_, result)
        result["trace"] = trace.to_dict()
        return result

    @staticmethod
    def _fill_default_pass(input_, result):
        """Add a PASS verdict for each stage present in the input that the guardrail left unset.

        A query verdict is defaulted for a completion query without a response
        yet, for an embedding query, and for an image-generation query without a
        response yet. A response verdict is defaulted whenever a completion or
        image-generation response is present. Verdicts already in ``result`` are
        never overwritten; ``result`` is mutated in place.
        """
        needs_query_verdict = (
            ("completionQuery" in input_ and "completionResponse" not in input_)
            or "embeddingQuery" in input_
            or ("imageGenerationQuery" in input_ and "imageGenerationResponse" not in input_)
        )
        needs_response_verdict = "completionResponse" in input_ or "imageGenerationResponse" in input_

        if needs_query_verdict:
            result.setdefault("queryGuardrailResponse", {"action": "PASS"})
        if needs_response_verdict:
            result.setdefault("responseGuardrailResponse", {"action": "PASS"})

def log_exception(loop, context):
    """asyncio exception handler: log the error, its context and its stack trace.

    :param loop: the event loop reporting the error (unused).
    :param context: asyncio error context dict; when it has no ``"exception"``,
        an ``Exception`` built from its ``"message"`` is reported instead.
    """
    error = context.get("exception")
    if error is None:
        error = Exception(context.get("message"))

    stack = "".join(traceback.format_exception(type(error), error, error.__traceback__))
    message = "Caught exception: {}\nContext: {}\nStack trace: {}".format(error, context, stack)
    logging.error(message)


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG, format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s')
    debugging.install_handler()
    watch_stdin()

    async def start_server():
        """Connect an AsyncJavaLink (parameters from parse_javalink_args) and serve guardrail commands on it."""
        # Inside a coroutine the running loop is the one to configure;
        # get_running_loop() is the documented way to obtain it here.
        asyncio.get_running_loop().set_exception_handler(log_exception)

        port, secret, server_cert = parse_javalink_args()
        link = AsyncJavaLink(port, secret, server_cert=server_cert)
        server = PythonGuardrailServer()
        await link.connect()
        await link.serve(server.handler)

    asyncio.run(start_server(), debug=True)