import logging
import traceback
import json
import os
import asyncio

from concurrent.futures import ThreadPoolExecutor

from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.utils import watch_stdin
from dataiku.core import debugging

from govern.core.handler import Handler, set_handler
from govern.core.migration_handler import MigrationHandler, set_migration_handler
from govern.core.action_handler import ActionHandler, set_action_handler
from govern.core.artifact_action_handler import ArtifactActionHandler, set_artifact_action_handler
from govern.core.autogovernance_handler import AutogovernanceHandler, set_autogovernance_handler

from dataikuapi import GovernClient

logger = logging.getLogger('govern_python_server')


def _artifact_of(*enriched_artifacts):
    """Return the 'artifact' dict of the first enriched artifact that is not None.

    Only used to extract identifiers for log lines. An absent or null enriched
    artifact, or a null 'artifact' entry, yields an empty dict so that building
    a log message cannot abort the command.
    """
    for enriched_artifact in enriched_artifacts:
        if enriched_artifact is not None:
            return enriched_artifact.get('artifact') or {}
    return {}


def script_runner(command):
    """Execute the script carried by ``command`` and return a result dict.

    ``command`` is a dict whose 'type' entry selects the handler to install
    and the keys copied into the result:

    - 'logical-hook-request': 'artifact' (JSON string or None),
      'fieldMessages', 'status', 'artifactIdsToUpdate', 'message'
    - 'migration-path-request': 'migratedArtifact' (JSON string or None),
      'status', 'message'
    - 'artifact-action-script-request' and 'action-script-request':
      'status', 'message'
    - 'autogovernance-script-request': 'scriptOutput' (JSON string or None),
      'status', 'message'

    An unknown type, or any exception raised while preparing or running the
    script, produces ``{'status': 'ERROR', 'message': ...}`` instead of
    raising; exceptions are also logged with their stack trace.

    NOTE(review): the script is run with ``exec`` in a fresh, empty globals
    dict. This assumes the script text comes from a trusted source - confirm
    with the caller.
    """
    ret = {}
    try:
        command_type = command.get('type')
        ticket_secret = command.get('ticketSecret')
        client = GovernClient('http://127.0.0.1:%s' % os.environ.get('DKU_GOVERNSERVER_PORT'), internal_ticket=ticket_secret)
        auth_ctx_identifier = command.get('authCtxIdentifier')
        if command_type == 'logical-hook-request':
            hook_phase = command.get('hookPhase')
            new_enriched_artifact = command.get('newEnrichedArtifact')
            existing_enriched_artifact = command.get('existingEnrichedArtifact')
            handler = Handler(hook_phase, new_enriched_artifact, existing_enriched_artifact, auth_ctx_identifier, client)
            # presumably makes the handler reachable from the script; see govern.core.handler
            set_handler(handler)
            hook = command.get('hook')
            script = hook.get('script')

            # --- Pre-Execution Logging ---
            hook_name = hook.get('name')
            # for log IDs, use the existing enriched artifact if there is one,
            # otherwise the new one; a null value counts as "no artifact"
            artifact_for_log = _artifact_of(existing_enriched_artifact, new_enriched_artifact)
            artifact_id = artifact_for_log.get('id')
            bpv_id = str(artifact_for_log.get('blueprintVersionId'))
            cmd_log = "'%s' command: hook phase '%s', hook name '%s', artifact ID '%s', blueprint version ID '%s', user '%s'" % (command_type, hook_phase, hook_name, artifact_id, bpv_id, auth_ctx_identifier)
            log_start(cmd_log)

            exec(script, {})

            ret['artifact'] = json.dumps(handler.artifact.json) if handler.artifact is not None else None
            ret['fieldMessages'] = handler.fieldMessages
            ret['status'] = handler.status
            ret['artifactIdsToUpdate'] = handler.artifactIdsToUpdate
            ret['message'] = handler.message

            # --- Post-Execution Logging ---
            log_end(cmd_log)
        elif command_type == 'migration-path-request':
            source_enriched_artifact = command.get('sourceEnrichedArtifact')
            target_enriched_blueprint_version = command.get('targetEnrichedBlueprintVersion')
            handler = MigrationHandler(source_enriched_artifact, target_enriched_blueprint_version, auth_ctx_identifier, client)
            set_migration_handler(handler)
            migration = command.get('migrationPath')
            script = migration.get('script')

            # --- Pre-Execution Logging ---
            migration_path_id = migration.get('id')
            artifact_id = _artifact_of(source_enriched_artifact).get('id')
            bpv_id_from = str(migration.get('blueprintVersionIdFrom'))
            bpv_id_to = str(migration.get('blueprintVersionIdTo'))
            cmd_log = "'%s' command: migration path ID '%s', artifact ID '%s', blueprint version ID from: '%s', blueprint version ID to: '%s', user '%s'" % (command_type, migration_path_id, artifact_id, bpv_id_from, bpv_id_to, auth_ctx_identifier)
            log_start(cmd_log)
            exec(script, {})

            ret['migratedArtifact'] = json.dumps(handler.target_artifact.json) if handler.target_artifact is not None else None
            ret['status'] = handler.status
            ret['message'] = handler.message

            # --- Post-Execution Logging ---
            log_end(cmd_log)

        elif command_type == 'artifact-action-script-request':
            enriched_artifact = command.get('enrichedArtifact')
            params = command.get('params')
            handler = ArtifactActionHandler(enriched_artifact, auth_ctx_identifier, params, client)
            set_artifact_action_handler(handler)
            script = command.get('script')

            # --- Pre-Execution Logging ---
            action_id = command.get('actionId')
            artifact_for_log = _artifact_of(enriched_artifact)
            artifact_id = artifact_for_log.get('id')
            bpv_id = str(artifact_for_log.get('blueprintVersionId'))
            cmd_log = "'%s' command: action ID '%s', artifact ID '%s', blueprint version ID: '%s', user '%s'" % (command_type, action_id, artifact_id, bpv_id, auth_ctx_identifier)
            log_start(cmd_log)

            exec(script, {})

            ret['status'] = handler.status
            ret['message'] = handler.message

            # --- Post-Execution Logging ---
            log_end(cmd_log)

        elif command_type == 'action-script-request':
            params = command.get('params')
            handler = ActionHandler(auth_ctx_identifier, params, client)
            set_action_handler(handler)
            script = command.get('script')

            # --- Pre-Execution Logging ---
            action_id = command.get('actionId')
            cmd_log = "'%s' command: action ID: '%s', user: '%s'" % (command_type, action_id, auth_ctx_identifier)
            log_start(cmd_log)

            exec(script, {})

            ret['status'] = handler.status
            ret['message'] = handler.message

            # --- Post-Execution Logging ---
            log_end(cmd_log)

        elif command_type == 'autogovernance-script-request':
            enriched_artifact = command.get('enrichedArtifact')
            handler = AutogovernanceHandler(enriched_artifact, client)
            set_autogovernance_handler(handler)
            script = command.get('script')

            # --- Pre-Execution Logging ---
            artifact_id = _artifact_of(enriched_artifact).get('id')
            cmd_log = "'%s' command: artifact ID '%s'" % (command_type, artifact_id)
            log_start(cmd_log)

            exec(script, {})

            ret['scriptOutput'] = json.dumps(handler.script_output.json) if handler.script_output is not None else None
            ret['status'] = handler.status
            ret['message'] = handler.message

            # --- Post-Execution Logging ---
            log_end(cmd_log)
        else:
            ret['status'] = 'ERROR'
            ret['message'] = 'Unexpected error: unknown command type \'' + str(command_type) + '\'. Contact an administrator to investigate.'

    except Exception as ex:
        # Top-level boundary: any failure (including errors raised by the
        # executed script) is reported back in the result rather than raised.
        ex_str = str(ex)
        ex_trace = str(''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)))
        error_msg = 'Caught exception while running python script: ' + ex_str + '\nStack trace: ' + ex_trace
        logger.error(error_msg)
        ret['status'] = 'ERROR'
        ret['message'] = error_msg

    return ret

def log_start(cmd_log):
    """Log at INFO level that the command described by ``cmd_log`` is starting."""
    message = "Starting " + cmd_log
    logger.info(message)
def log_end(cmd_log):
    """Log at INFO level that the command described by ``cmd_log`` has ended."""
    message = "End of " + cmd_log
    logger.info(message)

class GovernServer:
    """Command server whose handler runs each command through ``script_runner``."""

    async def handler(self, command):
        """Async generator yielding the single result of ``script_runner(command)``.

        A dedicated single-worker thread pool is created per command and shut
        down once the generator leaves the ``with`` block.
        """
        loop = asyncio.get_running_loop()
        # run the user code in a thread to release the main event loop
        with ThreadPoolExecutor(max_workers=1) as pool:
            result = await loop.run_in_executor(pool, script_runner, command)
            yield result


def log_exception(loop, context):
    """Event-loop exception handler: log the error, its context and its stack trace.

    ``context`` is the dict handed over by the loop; when it carries no
    'exception' entry, an ``Exception`` is built from its 'message' entry so
    the same log layout can be used. ``loop`` is not used.
    """
    error = context.get('exception')
    if error is None:
        error = Exception(context.get('message'))
    pieces = [
        'Caught exception: ', str(error),
        '\nContext: ', str(context),
        '\nStack trace: ', ''.join(traceback.format_exception(type(error), error, error.__traceback__)),
    ]
    logger.error(''.join(pieces))


if __name__ == '__main__':
    # Set LOGLEVEL=DEBUG in the environment to debug
    LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()
    logging.basicConfig(
        level=LOGLEVEL,
        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s',
    )
    debugging.install_handler()

    logger.info('Starting Govern python execution kernel')

    # NOTE(review): presumably ties this process's lifetime to its stdin - confirm in dataiku.base.utils
    watch_stdin()

    async def start_server():
        """Connect the Java link and serve incoming commands with ``GovernServer.handler``."""
        # Route exceptions reported to the loop through log_exception
        loop = asyncio.get_event_loop()
        loop.set_exception_handler(log_exception)

        port, secret, server_cert = parse_javalink_args()
        server = GovernServer()
        link = AsyncJavaLink(port, secret, server_cert=server_cert, connectivity_test_timeout=None)

        await link.connect()
        await link.serve(server.handler)

    asyncio.run(start_server(), debug=True)
