import copy
import inspect
import sys
import json
import calendar, datetime, time
import traceback, logging
import asyncio
import json
import sys
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor

from dataiku.base.async_link import AsyncJavaLink
from dataiku.llm.tracing import new_trace
from dataiku.core import debugging
from dataikuapi.dss.llm import DSSLLMStreamedCompletionChunk, DSSLLMStreamedCompletionFooter, DSSLLMCompletionResponse

from dataiku.base.utils import watch_stdin, get_clazz_in_code, get_json_friendly_error, get_argspec
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.core import dkuio
from dataiku.core.dataset import Dataset
from .base import BaseAgentTool

import pandas as pd, numpy as np


def get_redacted_command(command):
    """Return a deep copy of ``command`` that is safe to log.

    Values in ``command["config"]`` whose keys are listed in
    ``command["pluginConfig"]["dkuPasswordParams"]`` are replaced by
    ``"**redacted**"``, and the ``dkuPasswordParams`` list itself is removed
    from the copy. The input ``command`` is never mutated.

    Missing or null ``config`` / ``pluginConfig`` entries are tolerated: this
    function is called just to log every incoming command, so it must not
    raise on unexpected shapes.

    :param command: the command dict received from the Java side.
    :returns: a redacted deep copy of ``command``.
    """
    redacted_command = copy.deepcopy(command)
    if "config" not in redacted_command:
        return redacted_command
    config = redacted_command["config"]

    # Guard against an explicit null pluginConfig, which .get(..., {}) would not catch.
    plugin_config = redacted_command.get("pluginConfig")
    if not isinstance(plugin_config, dict):
        return redacted_command

    # Names of plugin params whose type is PASSWORD.
    password_params = plugin_config.get("dkuPasswordParams")
    if not password_params:
        return redacted_command

    if isinstance(config, dict):
        for key in password_params:
            if key in config:
                config[key] = "**redacted**"
    del plugin_config["dkuPasswordParams"]

    return redacted_command


class PythonAgentToolServer:
    """Serve agent-tool commands received from the Java side.

    A "start" command instantiates a ``BaseAgentTool`` subclass found in the
    supplied code. Later "invoke", "get-descriptor" and "load-sample-query"
    commands are forwarded to that instance. Each command's method runs on
    ``self.executor`` so blocking tool code does not run on the event loop.
    """

    def __init__(self, max_workers=32):
        """Create a server that has not been started yet.

        :param max_workers: size of the thread pool that runs command
            methods (defaults to 32).
        """
        self.started = False
        self.executor = ThreadPoolExecutor(max_workers)

    async def handler(self, command):
        """Dispatch one command to its synchronous method on the thread pool.

        This is an async generator that yields exactly one value: whatever
        the target method returns.

        :param command: dict whose "type" selects the method to run.
        :raises Exception: if ``command["type"]`` is not a known type.
        """
        logging.info("PythonAgentToolServer handler received command: %s", get_redacted_command(command))
        command_handlers = {
            "start": self.start,
            "invoke": self.invoke,
            "get-descriptor": self.get_descriptor,
            "load-sample-query": self.load_sample_query,
        }
        method = command_handlers.get(command["type"])
        if method is None:
            raise Exception("Unknown command type: %s" % command["type"])
        loop = asyncio.get_running_loop()
        yield await loop.run_in_executor(self.executor, method, command)

    def start(self, start_command):
        """Instantiate and configure the tool class defined in ``start_command``.

        :param start_command: dict with "code" (passed to
            ``get_clazz_in_code`` to locate a ``BaseAgentTool`` subclass) and
            optional "config" / "pluginConfig" dicts.
        :returns: ``{"ok": True}`` once the tool is ready.
        :raises AssertionError: if the server was already started.
        :raises NotImplementedError: if "code" is missing or null.
        """
        assert not self.started, "Already started"

        code = start_command.get("code")
        if code is None:
            # Only code-based tools are supported here.
            raise NotImplementedError("not implemented")

        clazz = get_clazz_in_code(code, BaseAgentTool, strict_module=True)
        self.instance = clazz()

        # set_config is optional on the tool class.
        if hasattr(self.instance, "set_config"):
            self.instance.set_config(start_command.get("config", {}), start_command.get("pluginConfig", {}))

        self.started = True
        return {"ok": True}

    def invoke(self, process_command):
        """Run the tool on ``process_command["input"]`` inside a new trace.

        An ``Exception`` raised by the tool is logged and converted into an
        ``{"error": "<Type>: <message>"}`` result instead of propagating.
        The finished trace is attached to the result under "trace".

        :raises AssertionError: if the server has not been started.
        """
        assert self.started, "Not started"
        tool_input = process_command.get("input")

        trace = new_trace("PYTHON_AGENT_TOOL_CALL")
        # ``with`` guarantees the trace is exited even if something escapes
        # the ``except Exception`` below (e.g. KeyboardInterrupt).
        with trace:
            try:
                logging.info("Invoking tool")
                result = self.instance.invoke(tool_input, trace)
                logging.info("Got Tool result")
            except Exception as e:
                logging.exception("Tool failed")
                result = {"error": type(e).__name__ + ": " + str(e)}

        # NOTE(review): assumes the tool returns a dict; any other return type fails here.
        result["trace"] = trace.to_dict()
        return result

    def get_descriptor(self, process_command):
        """Return the tool instance's descriptor for ``process_command["tool"]``.

        :raises AssertionError: if the server has not been started.
        """
        assert self.started, "Not started"
        tool = process_command.get("tool")

        ret = self.instance.get_descriptor(tool)
        logging.info("Got Tool descriptor: %s", ret)
        return ret

    def load_sample_query(self, process_command):
        """Return the tool instance's sample query for ``process_command["tool"]``.

        :raises AssertionError: if the server has not been started.
        """
        assert self.started, "Not started"
        tool = process_command.get("tool")

        ret = self.instance.load_sample_query(tool)
        logging.info("Got sample query: %s", ret)
        return ret

def log_exception(loop, context):
    """Asyncio exception handler: log the error, its context dict and its traceback.

    When ``context`` carries no "exception" entry, a plain ``Exception`` is
    built from ``context["message"]`` so that something can still be logged.

    :param loop: the event loop reporting the error (unused).
    :param context: the asyncio error-context dict.
    """
    error = context.get("exception")
    if error is None:
        error = Exception(context.get("message"))

    stack_trace = "".join(traceback.format_exception(type(error), error, error.__traceback__))
    parts = [
        "Caught exception: {}".format(error),
        "Context: {}".format(context),
        "Stack trace: {}".format(stack_trace),
    ]
    logging.error("\n".join(parts))


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG, format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s')
    debugging.install_handler()
    watch_stdin()

    async def start_server():
        """Connect to the Java side via AsyncJavaLink and serve commands with PythonAgentToolServer."""
        # Inside a coroutine, get_running_loop() is the idiomatic way to reach the loop.
        # Errors reported to the loop's exception handler go through log_exception.
        asyncio.get_running_loop().set_exception_handler(log_exception)

        port, secret, server_cert = parse_javalink_args()
        link = AsyncJavaLink(port, secret, server_cert=server_cert)
        server = PythonAgentToolServer()
        await link.connect()
        await link.serve(server.handler)

    # debug=True enables asyncio debug mode (extra checks and logging at some runtime cost).
    asyncio.run(start_server(), debug=True)