import sys, json, os, socket, time, logging, traceback, signal
import threading
import requests
import zmq
from tornado import ioloop
from zmq.eventloop.zmqstream import ZMQStream
from dataiku.core.intercom import backend_json_call, backend_void_call
from .zmq_utils import Forwarder, ROUTER_DEALER_Forwarder, PUB_SUB_Forwarder, REQ_REP_Forwarder
from dataiku.base import remoterun
        
class ServerSideForwarder(object):
    """Relay a local Jupyter kernel connection to a kernel started remotely by the backend.

    The Jupyter server talks to the ports listed in ``connection_file`` on 127.0.0.1.
    ``initialize`` binds forwarders on those ports, asks the backend
    (``jupyter/start-remote-kernel``) to start the kernel in a remote container/pod with
    a connection file rewritten to point at the forwarders' kernel-facing ports, and
    blocks until the remote side reports in. ``start`` then runs the tornado IOLoop
    (which the forwarders presumably rely on — see ``zmq_utils``).

    Loss of heartbeat, a terminal remote-kernel state, or an IOLoop failure terminates
    the whole process via ``os._exit``.
    """

    def __init__(self, connection_file, env_lang, env_name, project_key, bundle_id, container_conf):
        """Store the launch parameters; no I/O happens until ``initialize``.

        :param connection_file: path of the JSON Jupyter connection file of the local side
        :param env_lang: sent to the backend as ``envLang``
        :param env_name: sent to the backend as ``envName``
        :param project_key: sent to the backend as ``projectKey``
        :param bundle_id: sent to the backend as ``bundleId``
        :param container_conf: sent to the backend as ``containerConf``
        """
        self.connection_file = connection_file
        self.env_name = env_name
        self.env_lang = env_lang
        self.project_key = project_key
        self.bundle_id = bundle_id
        self.container_conf = container_conf
        # Forwarders created by initialize(); referenced here so their lifetime does not
        # depend on whatever references the forwarder implementation happens to keep.
        self.socket_forwarders = []

    def hb_forwarder(self):
        """Answer heartbeat pings from the remote kernel overseer on ``self.callback_socket``.

        Runs in a daemon thread started by ``initialize`` once the remote kernel has
        reported in. Every received message is answered with ``pong``. If nothing
        arrives for ``2 * freq`` seconds, or the socket fails, the process exits with
        status 1.
        """
        logging.info("Start heart beat listener")
        freq = 10  # seconds; the poll below tolerates twice this long without a message
        poller = zmq.Poller()
        poller.register(self.callback_socket, zmq.POLLIN)
        try:
            while True:
                # Poller.poll takes its timeout in milliseconds
                events = poller.poll(2 * freq * 1000)
                if not events:
                    logging.info("Heartbeat stopped")
                    os._exit(1)
                # the socket in the events cannot be anything but self.callback_socket
                self.callback_socket.recv()
                self.callback_socket.send('pong'.encode('utf8'))
        except (IOError, zmq.ZMQError):
            # ZMQError does not derive from IOError: without it a zmq failure would only
            # kill this thread and leave the forwarder alive with nobody answering pings.
            logging.error("Error heartbeating, exiting")
            traceback.print_exc()
            os._exit(1)
        finally:
            self.callback_socket.close()

    def _allowed_port_range(self):
        """Return ``(min_port, max_port, max_tries)`` from the backend's allowed Jupyter port range.

        When the backend does not enforce a range, ``min_port`` and ``max_port`` are
        None and ``max_tries`` is 100. No validation of the range is done here; an empty
        or inverted range makes the later random-port binds fail.
        """
        allowed_port_range = backend_json_call("jupyter/allowed-port-range")
        if not allowed_port_range['enabled']:
            return None, None, 100
        min_port = allowed_port_range['start']
        max_port = allowed_port_range['end']
        # allow as many random bind attempts as there are ports in the range
        return min_port, max_port, max_port - min_port

    def _start_socket_forwarders(self, connection, min_port, max_port):
        """Start one forwarder per Jupyter channel present in ``connection`` and rewrite its port.

        Each forwarder binds 127.0.0.1 on the original port (Jupyter-server side) and a
        port on all interfaces (kernel side). ``connection`` is updated in place so every
        ``<channel>_port`` entry holds the kernel-facing port. Returns the forwarders.
        """
        # channel -> zmq socket pattern used to relay it
        forwarder_classes = {
            'hb': REQ_REP_Forwarder,
            'shell': ROUTER_DEALER_Forwarder,
            'iopub': PUB_SUB_Forwarder,
            'stdin': ROUTER_DEALER_Forwarder,
            'control': ROUTER_DEALER_Forwarder,
        }

        def printout(m):
            logging.info(m)

        forwarders = []
        for channel in ['shell', 'iopub', 'stdin', 'control', 'hb']:
            key = '%s_port' % channel
            local_port = connection.get(key, None)
            if local_port is None or local_port == 0:
                continue  # channel not used by this connection file
            # remote port None means "bind to a random port" (within min/max when given)
            forwarder = forwarder_classes[channel]('127.0.0.1', local_port, '0.0.0.0', None, channel,
                                                   printout, True, True, min_port, max_port)
            forwarders.append(forwarder)
            # the remote kernel must connect to the port the forwarder actually bound
            connection[key] = forwarder.remote_port
            logging.info("Relay port %s to %s on type %s" % (local_port, forwarder.remote_port, channel))
        return forwarders

    def _bind_control_sockets(self, min_port, max_port, max_tries):
        """Create ``self.callback_socket`` (REP) and ``self.signaling_socket`` (PUB), bound on all interfaces.

        Uses the allowed port range when one is given, else ports 10000-30000 with 100
        tries. Returns ``(callback_port, signal_port)``.
        """
        context = zmq.Context()
        self.callback_socket = context.socket(zmq.REP)
        self.signaling_socket = context.socket(zmq.PUB)
        if min_port is None or max_port is None:
            min_port, max_port, max_tries = 10000, 30000, 100
        callback_port = self.callback_socket.bind_to_random_port('tcp://*', min_port=min_port, max_port=max_port, max_tries=max_tries)
        signal_port = self.signaling_socket.bind_to_random_port('tcp://*', min_port=min_port, max_port=max_port, max_tries=max_tries)
        return callback_port, signal_port

    def initialize(self):
        """Set up the relay, start the remote kernel and wait until it reports in.

        Reads the connection file, starts the socket forwarders, binds the
        callback/signaling sockets, asks the backend to start the remote kernel, then
        blocks on ``self.callback_socket`` for the remote kernel's startup message.
        Installs SIGINT handlers and starts the pod-polling and heartbeat threads.
        Exceptions from the backend calls or socket binds propagate to the caller.
        """
        with open(self.connection_file, 'r') as f:
            local_connection_file = json.load(f)

        min_port, max_port, max_tries = self._allowed_port_range()

        # start the forwarding (zmq-wise), ie relaying the sockets in the connection file;
        # afterwards the connection info holds the ports opened for the remote kernel
        self.socket_forwarders = self._start_socket_forwarders(local_connection_file, min_port, max_port)

        # and open new sockets for the comm with the remote kernel overseer (ie runner.py in the container)
        callback_port, signal_port = self._bind_control_sockets(min_port, max_port, max_tries)
        local_connection_file['relayPort'] = callback_port
        local_connection_file['signalPort'] = signal_port

        remote_kernel = backend_json_call("jupyter/start-remote-kernel", data={
            "contextProjectKey" : remoterun.get_env_var("DKU_CURRENT_PROJECT_KEY"),
            "connectionFile" : json.dumps(local_connection_file),
            "projectKey" : self.project_key,
            "bundleId" : self.bundle_id,
            "envLang" : self.env_lang,
            "envName" : self.env_name,
            "containerConf" : self.container_conf
        })

        logging.info("Started, got : %s" % json.dumps(remote_kernel))
        self.batch_id = remote_kernel['id']

        # The Jupyter kernel shutdown sequence is to send, in order SIGINT, SIGTERM, SIGKILL
        #
        # Before starting to block on remote kernel ACK, we install a SIGINT handler
        # that will tell the backend to destroy the pod.
        # Without that, the SIGINT would kill the forwarder, but leave the pod running
        def signal_handler_abort_pod(signum, frame):
            sys.stderr.write("Server side forwarder got signal %s, need to tell backend to destroy pod\n" % signum)
            backend_void_call("jupyter/abort-remote-kernel", data={
                "batchId" : self.batch_id
            })
            sys.stderr.write("Server side forwarder sent abort command\n")
            sys.exit(-signum)

        signal.signal(signal.SIGINT, signal_handler_abort_pod)

        # start the thread that polls the backend-side thread, to kill this process whenever that thread dies
        # this has to be started before we block on the remote kernel ACK
        self.start_wait_for_remote_kernel_death()

        # block until the remote end has started its kernel
        message = self.callback_socket.recv()
        logging.info("Got %s" % message)
        self.callback_socket.send('ok'.encode('utf8'))

        # start the heartbeating
        hb_thread = threading.Thread(name="forwarder-watcher", target=self.hb_forwarder)
        hb_thread.daemon = True
        hb_thread.start()

        # Now that we have started, SIGINT must instead be transmitted to the remote kernel, as
        # it is used for the interrupt command
        def sigint_handler_forward_to_remote_kernel(signum, frame):
            print('Signal handler called with signal %s' % signum)
            self.signaling_socket.send('sigint'.encode('utf8'))
            print("Continuing")

        signal.signal(signal.SIGINT, sigint_handler_forward_to_remote_kernel)

        # Now that we are connected, we don't need to abort the pod anymore upon kernel shutdown
        # because the loss of connection will cause the kernel to die
        # (We could also redirect SIGTERM to signal_handler_abort_pod, but it's not strictly needed)

    def start(self):
        """Run the tornado IOLoop until it stops; any failure exits the process with status 1."""
        # ioloop is synchronous; polling the state of the remote container via the backend is in a thread
        # the polling of the remote kernel is started earlier, so that the forwarder doesn't hang
        # on waiting the ACK from the remote kernel startup
        logging.info("Starting IOLoop")
        try:
            ioloop.IOLoop.instance().start()
        except BaseException:
            # deliberately catch everything: whatever stops the loop abnormally must take
            # the whole process (including daemon threads) down immediately
            logging.error("IOLoop failure")
            traceback.print_exc()
            os._exit(1)

    def wait_for_remote_kernel_death(self):
        """Poll the backend every 10 seconds until the remote kernel reaches a terminal state.

        Exits the process with status 1 for ``dead``/``failed``, 0 for ``success``, and
        1 if polling itself fails. Never returns normally.
        """
        def get_status():
            logging.debug("Polling pod state")
            remote_kernel_status = backend_json_call("jupyter/poll-remote-kernel", data={
                "batchId" : self.batch_id
            })
            logging.debug("Polled pod state, got : %s" % json.dumps(remote_kernel_status))
            return remote_kernel_status

        try:
            while True:
                status = get_status()
                state = status.get("state", None)
                if state in ("dead", "success", "failed"):
                    break
                time.sleep(10)

            logging.info("End of pod polling loop, last status is %s" % status)
            os._exit(1 if state in ("dead", "failed") else 0)
        except BaseException:
            # any polling failure is treated like a dead kernel
            logging.error("Polling ended in failure")
            traceback.print_exc()
            os._exit(1)

    def start_wait_for_remote_kernel_death(self):
        """Run ``wait_for_remote_kernel_death`` in a daemon thread."""
        t = threading.Thread(target=self.wait_for_remote_kernel_death)
        t.daemon = True
        t.start()
