import asyncio
import logging
import os
import traceback
from concurrent.futures import ThreadPoolExecutor

from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.compat import ImpCompat
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import get_clazz_in_module, watch_stdin
from dataiku.core import debugging
from dataiku.project_standards import ProjectStandardsCheckSpec

logger = logging.getLogger(__name__)


class ProjectStandardsServer:
    """Serve project standards commands (kernel init, check execution).

    Commands are dispatched by :meth:`handler`. The actual work is submitted to
    a single-worker thread pool, so blocking plugin code never runs on the event
    loop thread and commands are executed one at a time.
    """

    def __init__(self):
        self.plugin_folder = None  # plugin root folder, set by _init_kernel
        self.plugin_config = None  # plugin-level config, set by _init_kernel
        self.started = False  # True once _init_kernel has succeeded
        # One worker: init and check executions are serialized.
        self.executor = ThreadPoolExecutor(1)

    def _init_kernel(self, plugin_folder, plugin_config):
        """Store the plugin folder and config; may only be called once.

        Returns:
            The ``{"type": "InitKernelServerDone"}`` response dict.

        Raises:
            AssertionError: if the kernel was already initialized.
        """
        # Explicit check rather than ``assert`` so it still holds under ``python -O``;
        # AssertionError is kept so the raised type is unchanged for callers.
        if self.started:
            raise AssertionError("Already started")
        self.plugin_folder = plugin_folder
        self.plugin_config = plugin_config
        self.started = True
        logger.info("Project standards kernel initialized")
        return {"type": "InitKernelServerDone"}

    def _execute_check(self, project_key, original_project_key, id, code_file_path, config):
        """Load a check class from the plugin's code and run it.

        Args:
            project_key: forwarded to the check constructor.
            original_project_key: forwarded to the check constructor.
            id: check identifier, used for logging.
            code_file_path: path of the check's Python file, relative to the
                plugin folder.
            config: check parameters, forwarded to the check constructor.

        Returns:
            A ``CheckResult`` dict holding either ``checkRunResult`` (the value
            returned by the check's ``run()``) or an ``error`` message. Failures
            are reported in the result, never raised.
        """
        if not self.started:
            # Without init there is no plugin folder to resolve the code file against.
            return {
                "type": "CheckResult",
                "error": "Unexpected error: project standards kernel not initialized",
            }

        # Lazy %-args: cheap, and cannot fail if ``id`` is not a string.
        logger.info("Loading python code for check %s", id)
        try:
            # The module is named after the directory containing the code file.
            module_name = os.path.basename(os.path.dirname(code_file_path))
            full_code_file_path = os.path.join(self.plugin_folder, code_file_path)
            module = ImpCompat.load_source(module_name, full_code_file_path)
            clazz = get_clazz_in_module(module, ProjectStandardsCheckSpec)
            project_standards_check = clazz(project_key, original_project_key, config, self.plugin_config)
        except Exception as e:
            # logger.exception records the traceback of ``e`` through logging.
            logger.exception("Failed to load check %s", id)
            # TODO(@project-standards): add more info about the error (stack, other)?
            return {"type": "CheckResult", "error": "Unexpected error: {}".format(e)}

        logger.info("Executing check %s", id)
        try:
            run_result = project_standards_check.run()
        except Exception as e:
            logger.exception("Error while running check %s", id)
            return {"type": "CheckResult", "error": "Error while running check: {}".format(e)}
        return {"type": "CheckResult", "checkRunResult": run_result}

    async def handler(self, command):
        """Dispatch one command and yield its single response.

        The blocking work runs on ``self.executor``.

        Raises:
            Exception: if ``command["type"]`` is not a known command.
        """
        command_type = command["type"]
        loop = asyncio.get_running_loop()
        if command_type == "InitKernelServer":
            yield await loop.run_in_executor(
                self.executor, self._init_kernel, command["pluginFolderPath"], command["pluginConfig"]
            )
        elif command_type == "ExecuteCheck":
            yield await loop.run_in_executor(
                self.executor,
                self._execute_check,
                command["projectKey"],
                command["originalProjectKey"],
                command["checkId"],
                command["codeFileRelativePath"],
                command["checkParams"],
            )
        else:
            raise Exception("Unknown command type: %s" % command_type)


def log_exception(loop, context):
    """Event-loop exception handler: log the error, its context and traceback.

    Args:
        loop: the event loop reporting the error (unused).
        context: the asyncio error-context dict; its ``"exception"`` entry is
            used when present, otherwise its ``"message"`` is wrapped in an
            ``Exception`` so both cases are formatted the same way.
    """
    error = context.get("exception")
    if error is None:
        error = Exception(context.get("message"))
    trace = "".join(traceback.format_exception(type(error), error, error.__traceback__))
    logging.error("Caught exception: {}\nContext: {}\nStack trace: {}".format(error, context, trace))


if __name__ == "__main__":
    # Logging is configured first so that everything below goes through it.
    logging.basicConfig(
        format="[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s",
        level=logging.INFO,
    )
    debugging.install_handler()
    watch_stdin()

    async def start_server():
        """Connect to the Java side and hand incoming commands to ProjectStandardsServer.handler."""
        running_loop = asyncio.get_running_loop()
        running_loop.set_exception_handler(log_exception)
        port, secret, server_cert = parse_javalink_args()
        java_link = AsyncJavaLink(port, secret, server_cert=server_cert)
        kernel = ProjectStandardsServer()
        await java_link.connect()
        await java_link.serve(kernel.handler)

    # asyncio debug mode is deliberately enabled for this process.
    asyncio.run(start_server(), debug=True)
