import logging
import sys
import traceback
import json
import os

import multiprocessing as mp
from multiprocessing import Pool, TimeoutError

from dataiku.base.socket_block_link import JavaLink, parse_javalink_args
from dataiku.base.utils import watch_stdin
from dataiku.core import debugging

from govern.core.handler import Handler, set_handler
from govern.core.migration_handler import MigrationHandler, set_migration_handler

from dataikuapi import GovernClient

def scriptRunner(command):
    """Execute one Govern script request and collect its outcome.

    This function is submitted to a ``multiprocessing.Pool`` worker by
    ``GovernServer.start`` so that a runaway script can be killed on timeout.
    Script failures are never raised to the caller. Any ``Exception``, or a
    ``SystemExit`` triggered by the script, is turned into an error result.

    :param command: request dict. Keys read:
        - ``ticketSecret``: internal ticket passed to ``GovernClient``
        - ``type``: ``"logical-hook-request"`` or ``"migration-path-request"``
        - hook requests: ``hookPhase``, ``newEnrichedArtifact``,
          ``existingEnrichedArtifact`` and ``hook`` (a dict with ``script``)
        - migration requests: ``sourceEnrichedArtifact``,
          ``targetEnrichedBlueprintVersion`` and ``migrationPath``
          (a dict with ``script``)
    :returns: result dict that always contains ``status`` and ``message``.
        ``status`` is ``'ERROR'`` on failure. Hook requests also return
        ``artifact`` (JSON string or ``None``), ``fieldMessages`` and
        ``artifactIdsToUpdate``. Migration requests also return
        ``migratedArtifact`` (JSON string or ``None``). Keys set before a
        failure are kept in the result.
    """
    ret = {}
    try:
        ticketSecret = command.get('ticketSecret')
        client = GovernClient('http://127.0.0.1:%s' % os.environ.get('DKU_GOVERNSERVER_PORT'), internal_ticket=ticketSecret)

        if command.get('type') == "logical-hook-request":
            hookPhase = command.get('hookPhase')
            newEnrichedArtifact = command.get('newEnrichedArtifact')
            existingEnrichedArtifact = command.get('existingEnrichedArtifact')
            handler = Handler(hookPhase, newEnrichedArtifact, existingEnrichedArtifact, client)
            # Publish the handler so the script can reach it (the script gets empty globals).
            set_handler(handler)
            hook = command.get('hook')
            script = hook.get('script')

            # NOTE(review): the script text comes from the incoming command and runs
            # with the worker's full privileges; it must only originate from trusted
            # (presumably admin-defined) Govern configuration — confirm upstream.
            exec(script, {})
            ret['artifact'] = json.dumps(handler.artifact.json) if handler.artifact is not None else None
            ret['fieldMessages'] = handler.fieldMessages
            ret['status'] = handler.status
            ret['artifactIdsToUpdate'] = handler.artifactIdsToUpdate
            ret['message'] = handler.message
        elif command.get('type') == "migration-path-request":
            source_enriched_artifact = command.get('sourceEnrichedArtifact')
            target_enriched_blueprint_version = command.get('targetEnrichedBlueprintVersion')
            handler = MigrationHandler(source_enriched_artifact, target_enriched_blueprint_version, client)
            # Publish the handler so the script can reach it (the script gets empty globals).
            set_migration_handler(handler)
            migration = command.get('migrationPath')
            script = migration.get('script')

            # NOTE(review): same trust requirement as for hook scripts above.
            exec(script, {})
            ret['migratedArtifact'] = json.dumps(handler.target_artifact.json) if handler.target_artifact is not None else None
            ret['status'] = handler.status
            ret['message'] = handler.message
        else:
            ret['status'] = 'ERROR'
            ret['message'] = 'Unexpected error: unknown command type "' + str(command.get('type')) + '". Contact an administrator to investigate.'

    # SystemExit is caught too. Otherwise a script calling exit()/sys.exit() would kill
    # the pool worker without producing a result, and the caller would only see a timeout.
    except (Exception, SystemExit) as ex:
        # Drop the caret/tilde marker lines Python 3.11+ inserts under source lines,
        # so the message keeps the meaningful traceback line.
        stacktrace = [line for line in traceback.format_exc(limit=2).splitlines()
                      if line.strip(' ^~')]
        if len(stacktrace) < 2:
            ret['message'] = str(ex)
        else:
            # "ExceptionType: message" followed by the preceding traceback line for context.
            ret['message'] = stacktrace[-1] + stacktrace[-2]
        ret['status'] = 'ERROR'

    return ret

class GovernServer:
    """Read Govern script requests from a link and answer each one.

    Every command is executed by ``scriptRunner`` in a single-process worker
    pool, so a script that exceeds its ``timeout`` can be killed without taking
    this server down. One JSON reply is sent per loop iteration, even when
    reading or handling the command failed.
    """

    def __init__(self, link):
        # Connected link object; only read_json() and send_json() are used here.
        self.link = link

    @staticmethod
    def _error_message(exc):
        """Summarize the exception currently being handled.

        Must be called from inside an ``except`` block. The result is the final
        traceback line ("Type: message") followed by the preceding traceback line
        for context. Caret/tilde marker lines (inserted by Python 3.11+) are
        skipped. Falls back to ``str(exc)`` when the traceback is too short.
        """
        stacktrace = [line for line in traceback.format_exc(limit=2).splitlines()
                      if line.strip(' ^~')]
        if len(stacktrace) < 2:
            return str(exc)
        return stacktrace[-1] + stacktrace[-2]

    def start(self):
        """Serve commands until an exception escapes the loop.

        Each command dict must carry ``timeout`` in milliseconds. The whole dict
        is forwarded to ``scriptRunner``. Failures while handling a command are
        replied as ``{'status': 'ERROR', 'message': ...}``. The loop only ends
        when an exception escapes it, for example from ``send_json`` or a
        ``BaseException`` such as ``KeyboardInterrupt``. The worker pool is then
        terminated so no worker process is leaked.
        """
        pool = Pool(processes=1)
        try:
            while True:
                ret = {}
                try:
                    command = self.link.read_json()
                    timeout = int(command.get('timeout'))
                    async_result = pool.apply_async(scriptRunner, (command, ))
                    try:
                        ret = async_result.get(timeout / 1000)
                    except TimeoutError:
                        # The worker may be stuck inside the script: kill it and use
                        # a fresh pool for the next command.
                        pool.terminate()
                        pool = Pool(processes=1)
                        ret['status'] = 'ERROR'
                        ret['message'] = 'The script took too long to complete (timeout: ' + str(timeout) + 'ms). Contact an administrator to investigate.'

                except Exception as e:
                    ret['message'] = self._error_message(e)
                    ret['status'] = 'ERROR'

                finally:
                    self.link.send_json(ret)
        finally:
            # Release the worker process(es) however the loop ends.
            pool.terminate()


def serve(port, secret, server_cert=None):
    """Connect a JavaLink and run a GovernServer on it, closing the link afterwards.

    :param port: port handed to ``JavaLink``
    :param secret: secret handed to ``JavaLink``
    :param server_cert: optional server certificate handed to ``JavaLink``
    """
    java_link = JavaLink(port, secret, server_cert=server_cert)
    java_link.connect()

    server = GovernServer(java_link)
    try:
        server.start()
    finally:
        # start() only returns by raising, so the link must be released here.
        java_link.close()

if __name__ == "__main__":
    # Entry point: choose the multiprocessing start method, configure logging,
    # then serve commands over the JavaLink built from parse_javalink_args().
    # The start method must be set before any Pool is created (GovernServer.start creates one).
    if 'forkserver' in mp.get_all_start_methods():
        # We use "forkserver" because it is safer on MacOS due to this issue: https://stackoverflow.com/questions/55924761/worker-process-crashes-on-requests-get-when-data-is-put-into-input-queue-befor
        mp.set_start_method('forkserver')
    print('Using start method: ' + str(mp.get_start_method()))

    logging.basicConfig(level=logging.INFO,
                        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s')
    debugging.install_handler()

    # NOTE(review): presumably stops this process when the parent closes stdin — confirm in dataiku.base.utils.
    watch_stdin()
    # presumably parsed from the command-line arguments — verify against parse_javalink_args.
    port, secret, server_cert = parse_javalink_args()
    serve(port, secret, server_cert=server_cert)
