import asyncio
import base64
import logging
import os
import socket
import ssl
import struct
import tempfile
import threading
import time
from asyncio import CancelledError
from typing import Any
from typing import Callable
from typing import Optional
from typing import Tuple
from typing import Type
from typing import Union

from dataiku.core import dkujson
from dataiku.core import intercom

# Module-wide logger, and the codec for the 4-byte big-endian signed integer that prefixes every block.
logger = logging.getLogger(name=__name__)
INT32_STRUCT = struct.Struct(">i")


class AsyncBlockIO(object):
    """
    Length-prefixed block I/O on top of an asyncio reader/writer pair.

    Every block is framed as a big-endian signed 32-bit length (INT32_STRUCT) followed by that many bytes.
    Strings are UTF-8 encoded and JSON values are serialized with dkujson before being framed.
    Subclasses are responsible for assigning `_reader` / `_writer` once a connection exists.
    """

    def __init__(self):
        self._writer: Optional[asyncio.StreamWriter] = None
        self._reader: Optional[asyncio.StreamReader] = None

    async def send_block(self, block: bytes) -> None:
        """
        Send one length-prefixed block and wait for the writer's buffer to drain.

        :raises IOError: if no writer is attached
        """
        # Fail fast before anything is written, so we never emit a length prefix without its payload.
        if self._writer is None:
            raise IOError("No writer")
        await self._send_int(len(block))
        await self._send_raw_bytes(block)
        await self._writer.drain()

    async def send_json(self, obj: Any, cls=None) -> None:
        """Serialize `obj` with dkujson (optionally with encoder class `cls`) and send it as a string block."""
        await self.send_string(dkujson.dumps(obj, cls=cls))

    async def send_string(self, str_val: str) -> None:
        """Send `str_val` UTF-8 encoded as a single block."""
        await self.send_block(str_val.encode("utf-8"))

    async def _send_int(self, value: int) -> None:
        """Write `value` as a 4-byte big-endian signed integer (not flushed)."""
        await self._send_raw_bytes(INT32_STRUCT.pack(value))

    async def _send_raw_bytes(self, data: bytes) -> None:
        """Write raw bytes to the writer buffer (not flushed). Raises IOError if no writer is attached."""
        if self._writer is None:
            raise IOError("No writer")
        self._writer.write(data)

    async def read_block(self, read_timeout: Optional[float] = None) -> bytes:
        """
        Read one length-prefixed block.

        :param read_timeout: max seconds without receiving any byte (see `_read_raw_bytes`); None waits forever
        :raises IOError: on timeout, or if the length prefix is negative (corrupted / desynchronized stream)
        :raises EOFError: if the stream ends before the block is complete
        """
        block_size = await self._read_int(read_timeout)
        if block_size < 0:
            # A negative length can only come from a corrupted or desynchronized stream; reading "0 bytes"
            # would silently return an empty block and surface later as a confusing decoding error.
            raise IOError(f"Invalid block size received: {block_size}")
        return await self._read_raw_bytes(block_size, read_timeout)

    async def read_json(self, read_timeout: Optional[float] = None) -> Any:
        """Read one string block and parse it with dkujson."""
        return dkujson.loads(await self.read_string(read_timeout))

    async def read_string(self, read_timeout: Optional[float] = None) -> str:
        """Read one block and decode it as UTF-8."""
        data = await self.read_block(read_timeout)
        return data.decode("utf-8")

    async def _read_int(self, read_timeout: Optional[float]) -> int:
        """Read a 4-byte big-endian signed integer."""
        data = await self._read_raw_bytes(4, read_timeout)
        return INT32_STRUCT.unpack(data)[0]

    async def _read_raw_bytes(self, size: int, read_timeout: Optional[float]) -> bytes:
        """
        Read exactly `size` bytes from the reader.

        The read timeout does not apply to the whole read operation but to each individual read operation on the socket.
        In short, if no bytes have been received for `read_timeout` seconds, a TimeoutError will be raised. The purpose
        of this timeout is to detect situations where the connection or the other side of the link is dead or hang, and
        not to put a time limit on the high-level operation that is being performed (e.g. read a huge block).

        On timeout an IOError is raised; on end of stream an EOFError is raised. In both cases, and on any other
        read error, `_test_connectivity()` is invoked first.
        """

        chunks = []
        current_size = 0

        while current_size < size:
            if self._reader is None:
                raise IOError("No reader")
            try:
                chunk = await asyncio.wait_for(self._reader.read(size - current_size), timeout=read_timeout)
            except asyncio.TimeoutError:
                self._test_connectivity()
                raise IOError(f"No data received after {read_timeout}s ({current_size}/{size} bytes read)")
            except Exception:
                self._test_connectivity()
                raise

            if not chunk:
                self._test_connectivity()
                raise EOFError("Could not read data (end of stream)")

            chunks.append(chunk)
            current_size += len(chunk)

        return b"".join(chunks)

    def _test_connectivity(self):
        """
        Test connectivity with remote server,
        meant to be overridden in children classes
        """
        pass


class AbstractAsyncSocketBlockLink(AsyncBlockIO):
    """
    Base class for socket-backed block links.

    Adds serving requests through an AsyncServer, closing the link, and `async with` support.
    Subclasses establish the connection by assigning `_reader` / `_writer`.
    """

    def __init__(self):
        super().__init__()
        self._async_server = None

    async def serve(self, async_handler, read_timeout: float = 60, keepalive_interval: float = 5):
        """
        Run an AsyncServer on this link and return when it stops.

        :param async_handler: passed to AsyncServer; called with each request payload and iterated with `async for`
        :param read_timeout: seconds without incoming data before a read fails
        :param keepalive_interval: seconds of responder idleness before an ALIVE message is sent
        :raises IOError: if a server has already been started on this link (it is never reset)
        """
        if self._async_server is not None:
            raise IOError("Server is already running")
        self._async_server = AsyncServer(self,
                                         handler=async_handler,
                                         read_timeout=read_timeout,
                                         keepalive_interval=keepalive_interval)
        return await self._async_server.serve()

    async def close(self):
        """
        Close the writer (if any) and detach both streams.

        The streams are detached before awaiting `wait_closed()`, so even if that raises the link is left
        in a consistent closed state and a second call is a no-op.
        """
        logger.info("Closing async block link...")
        self._reader = None
        writer, self._writer = self._writer, None
        if writer is not None:
            writer.close()
            await writer.wait_closed()
        logger.info("Closed async block link")

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()


# This class is not used in DSS, it is only here to support writing unit tests for the server/client communication
class AsyncSocketBlockLinkServer(AbstractAsyncSocketBlockLink):
    """
    Single-client TCP server side of a block link.

    `listen()` binds an ephemeral port on `host` and returns it. `accept()` waits up to `timeout` seconds
    until the first client has connected and sent `secret` as a string block. Further connections are rejected.
    """

    def __init__(self, secret, timeout=60, host=None):
        super().__init__()
        self.secret = secret
        self.timeout = timeout
        self.host = host or socket.gethostname()
        self.server = None
        # Completed by _accept with None on successful authentication, or with the error that occurred.
        # NOTE(review): asyncio.Future() without a running loop is deprecated on recent Python versions;
        # presumably this class is always instantiated from inside a coroutine — confirm.
        self.future = asyncio.Future()

    async def listen(self):
        """Start listening on an ephemeral port of `self.host` and return that port."""
        logger.info("Starting server...")
        serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        serversocket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        serversocket.setblocking(False)
        serversocket.bind((self.host, 0))
        serversocket.listen(1)

        self.server = await asyncio.start_server(self._accept, sock=serversocket)
        host, port = serversocket.getsockname()
        logger.info(f"Server is listening on {host}:{port}")
        return port

    async def close(self):
        """Close the client connection (if any) and then the listening server."""
        await super().close()
        if self.server is not None:
            self.server.close()
            await self.server.wait_closed()

    async def accept(self):
        """Wait for the first client to authenticate; re-raises its failure, or times out after `timeout` seconds."""
        await asyncio.wait_for(self.future, timeout=self.timeout)

    async def _accept(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Connection callback: keep the first client, check its secret and complete `self.future`."""
        logger.info("Client is connecting ...")
        # Only accept the first connection, ignore others, we don't care, it's used for testing only
        if self._writer is not None or self._reader is not None:
            logger.error("Client already connected, ignore new connection")
            writer.close()
            return

        self._writer = writer
        self._reader = reader

        sock = writer.get_extra_info("socket")
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        try:
            received_secret = await self.read_string(read_timeout=self.timeout)
            if received_secret != self.secret:
                # Go through the failure path: the future gets exactly one outcome and the
                # unauthenticated client is disconnected instead of being kept as "connected".
                raise IOError("Invalid secret")
            logger.info(f"Client {writer.get_extra_info('peername')} is connected")
            # The future may already be done, e.g. cancelled by wait_for when accept() timed out.
            if not self.future.done():
                self.future.set_result(None)
        except Exception as e:
            if not self.future.done():
                self.future.set_exception(e)
            await self.close()


class AsyncSocketBlockLinkClient(AbstractAsyncSocketBlockLink):
    """
    Client side of a block link.

    Connects to `host:port` (over TLS when `server_cert_path` is given) and authenticates by sending `secret`.
    When a connection or read fails, a best-effort connectivity test against the DSS backend is started in the
    background (see `_test_connectivity`). Pass `connectivity_test_timeout=None` to disable it.
    """

    def __init__(self, host: str, port: int, secret: str, connection_timeout=60,
                 server_cert_path: Optional[str] = None, connectivity_test_timeout: Optional[int] = 10):
        super().__init__()
        self.host = host
        self.port = port
        self.secret = secret
        self.connection_timeout = connection_timeout
        # Certificate file to trust for TLS; None means plain TCP.
        self.server_cert_path = server_cert_path

        self.connectivity_test_timeout = connectivity_test_timeout
        # "host (ip)" label used in log messages; refined by connect() once the host is resolved.
        self.host_with_resolved_addr = host

    async def connect(self):
        """
        Resolve the host, open the connection and send the secret.

        :raises IOError: if opening the connection times out or the connection/handshake fails; the
            connectivity test is started before raising and any partially opened connection is closed
        """
        ip_addr = await asyncio.get_running_loop().run_in_executor(
            None, socket.gethostbyname, self.host
        )
        self.host_with_resolved_addr = f"{self.host} ({ip_addr})" if self.host != ip_addr else self.host
        logger.info(f"Connecting to {self.host_with_resolved_addr} at port {self.port} (timeout={self.connection_timeout}s"
                    f", ssl={'enabled' if self.server_cert_path else 'disabled'})")

        ssl_context = None
        if self.server_cert_path is not None:
            # The path points to a certificate *file*, so it must be given as `cafile` (`capath` expects a
            # directory of hashed certificates). Because a cafile is supplied, the system default CAs are not
            # loaded: only this certificate is trusted.
            ssl_context = ssl.create_default_context(cafile=self.server_cert_path)
        try:
            # Note: the timeout covers opening the connection only, not sending the secret.
            self._reader, self._writer = await asyncio.wait_for(
                asyncio.open_connection(
                    ip_addr,
                    self.port,
                    ssl=ssl_context,
                    server_hostname=self.host if ssl_context else None
                ),
                timeout=self.connection_timeout
            )
            sock = self._writer.get_extra_info("socket")
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

            logger.info("Sending secret")
            await self.send_string(self.secret)
            logger.info(f"Connected to {self.host_with_resolved_addr} at port {self.port}")
        except asyncio.TimeoutError as e:
            self._discard_connection()
            self._test_connectivity()
            raise IOError(f"No data received after {self.connection_timeout} seconds") from e
        except Exception as e:
            self._discard_connection()
            self._test_connectivity()
            raise IOError(f"Connection failed: {str(e)}") from e

    def _discard_connection(self):
        """Drop a partially established connection so a failed connect() does not leak an open socket."""
        writer = self._writer
        self._reader = None
        self._writer = None
        if writer is not None:
            writer.close()

    def _test_connectivity(self):
        """
        Ask the DSS backend, in a background thread, to test connectivity to this link's host/port.

        Outcome is only logged; failures of the test itself are logged as warnings and never raised.
        """
        if self.connectivity_test_timeout is None:
            # fuse to disable the connectivity test in Govern
            # intercom.backend_get_call won't work in Govern because remote run logic is not implemented
            return

        def test_connectivity():
            host_info = f"(resolved host: {self.host_with_resolved_addr}, port: {self.port})"
            payload = {"executionId": os.getenv("DKU_EXECUTION_ID"), "host": self.host, "port": self.port}

            logger.info(f"Sending API call to DSS backend to test connectivity {host_info}")
            try:
                response = intercom.backend_get_call("test-connectivity", params=payload, timeout=self.connectivity_test_timeout)
                logger.info(f"Successfully reached DSS backend, got response: {response} {host_info}")
            except Exception:
                logger.warning(f"DSS backend could not be reached {host_info}", exc_info=True)

        # Using a regular threading.Thread here because:
        # - we don't want to await the outcome synchronously to allow the exceptions to be raised immediately
        #   while the connectivity test starts running in the background.
        # - without the need to await the outcome, there is no reason to involve the ioloop and it's less bug-prone not to involve it
        # - thread is not daemonic, so the interpreter will still not exit by itself until the connectivity test has completed
        threading.Thread(target=test_connectivity).start()


class AsyncJavaLink(AsyncSocketBlockLinkClient):
    """
    Client link to the DSS backend host (env var DKU_BACKEND_HOST, default "localhost").

    `server_cert` may be empty/None or "NA" (TLS disabled), a certificate string, or the same string
    base64-encoded with a "b64:" prefix. A provided certificate is written to a temporary file whose path
    is used as the client's `server_cert_path`.
    """

    def __init__(self, port, secret, server_cert=None, connectivity_test_timeout: Optional[int] = 10):
        cert_file_path = None
        if server_cert and server_cert != "NA":
            logger.info("cert found")
            cert_file_path = self._write_cert_to_temp_file(server_cert)
        elif server_cert:
            logger.info("cert is NA")
        else:
            logger.info("no cert found")

        backend_host = os.getenv("DKU_BACKEND_HOST", "localhost")
        super().__init__(backend_host, port, secret,
                         server_cert_path=cert_file_path,
                         connectivity_test_timeout=connectivity_test_timeout)

    @staticmethod
    def _write_cert_to_temp_file(server_cert):
        """Write the (optionally "b64:"-prefixed) certificate to a new temporary file and return its path."""
        # NOTE(review): delete=False means this file is never removed by this code.
        with tempfile.NamedTemporaryFile(prefix="encrypted-rpc-cert-", delete=False) as cert_file:
            if server_cert.startswith("b64:"):
                server_cert = base64.b64decode(server_cert[4:]).decode("utf8")
            cert_file.write(server_cert.encode("utf-8"))
            return cert_file.name


class AsyncServer(object):
    """
    Async server on top of a block link. All messages are JSON.

    Request protocol:
    - {"requestId": string, "action": "START"} + request payload      // send a request
    - {"requestId": string, "action": "CANCEL"}                       // cancel a request
    - {"action": "PING"}                                              // no-op, just to keep the TCP socket alive
    - {"action": "STOP"}                                              // tell server to shut down

    Response protocol:
    - {"requestId": string, "state": "RUNNING"} + response payload  // send a response to a request (can be sent multiple times)
    - {"requestId": string, "state": "CANCELLED"}                   // tell that a request was properly cancelled
    - {"requestId": string, "state": "SUCCEEDED"}                   // tell that a request is done
    - {"requestId": string, "state": "FAILED"} + error payload      // tell that a request failed; the payload is
                                                                    // {"message": ..., "pythonExceptionPath": ...}
    - {"state": "ALIVE"}                                            // no-op, just to keep the TCP socket alive
    - {"state": "STOPPED"}                                          // sent once when the responder loop stops
    """

    def __init__(self, async_link: AbstractAsyncSocketBlockLink, handler, read_timeout, keepalive_interval):
        self.keepalive_interval = keepalive_interval
        self.async_link = async_link
        # Items are (request_id, state, json_payload) tuples, or None to stop the responder loop.
        self.response_queue = asyncio.Queue()
        # Called with each request payload; the result is consumed with `async for`.
        self.handler = handler
        # request id -> task running handle_request() for that request
        self.active_tasks = {}
        self.read_timeout = read_timeout

    async def receiver_loop(self):
        """
        Read requests from the link and dispatch them until STOP is received or an error occurs.

        On exit, for whatever reason, all active request tasks are cancelled and the responder loop is told to stop.
        """
        try:
            while True:
                request = await self.async_link.read_json(read_timeout=self.read_timeout)
                request_id = request.get("requestId")
                request_action = request.get("action")
                if request_action == "CANCEL":
                    if request_id is None:
                        raise ValueError("request id cannot be null")
                    task = self.active_tasks.get(request_id)
                    if task is not None:
                        task.cancel()
                elif request_action == "START":
                    if request_id is None:
                        raise ValueError("request id cannot be null")
                    # The request payload follows the START header as a separate block.
                    data = await self.async_link.read_json(read_timeout=self.read_timeout)
                    task = asyncio.create_task(
                        self.handle_request(request_id, data)
                    )
                    task.add_done_callback(self.get_clean_task_callback(request_id))
                    self.active_tasks[request_id] = task
                elif request_action == "STOP":
                    logger.info("Stopping receiver loop")
                    break
                elif request_action == "PING":
                    continue
                else:
                    raise ValueError("Invalid request action: {}".format(request_action))
        except Exception:
            logger.exception("Unexpected error in receiver loop")
            raise
        finally:
            try:
                # whatever had happened, cancel all the tasks before stopping
                # consider using a payload in STOP message if you want another behavior
                # Iterate over a snapshot so the dict is never mutated while being iterated.
                for request_id, task in list(self.active_tasks.items()):
                    logger.info("Cancelling request %s", request_id)
                    task.cancel()
            except Exception:
                logger.exception("Error when cancelling requests")
                raise
            finally:
                self.response_queue.put_nowait(None)

    async def responder_loop(self):
        """
        Send queued responses until a None sentinel is dequeued, then send STOPPED.

        An ALIVE message is sent whenever no response was queued for `keepalive_interval` seconds.
        """
        while True:
            try:
                response = await asyncio.wait_for(self.response_queue.get(), timeout=self.keepalive_interval)
            except asyncio.TimeoutError:
                # FIXME(clean): could be replaced by proper TCP keep alive configuration
                # See comment https://github.com/dataiku/dip/pull/28758#issuecomment-2082977085
                await self.async_link.send_json({"state": "ALIVE"})
                continue
            if response is None:
                logger.info("Stopping responder loop")
                break
            request_id, request_type, json_payload = response
            await self.async_link.send_json({"requestId": request_id, "state": request_type})
            if request_type in ["RUNNING", "FAILED"]:
                await self.async_link.send_string(json_payload)
        await self.async_link.send_json({"state": "STOPPED"})

    async def serve(self):
        """
        Run the receiver and responder loops until both finish.

        :raises Exception: "AsyncServer crashed" if either loop ended with an exception (each error is logged)
        """
        logger.info("Starting AsyncServer")
        loop_names = ("receiver", "responder")
        results = await asyncio.gather(self.receiver_loop(), self.responder_loop(), return_exceptions=True)
        should_raise = False
        for loop_name, result in zip(loop_names, results):
            if isinstance(result, Exception):
                should_raise = True
                logger.error("Error in %s loop", loop_name, exc_info=result)
        if should_raise:
            raise Exception("AsyncServer crashed")

    async def handle_request(self, request_id, request_payload):
        """
        Run the handler for one request and enqueue its responses.

        Each item produced by the handler is enqueued as RUNNING; normal completion as SUCCEEDED, cancellation
        as CANCELLED and any other error as FAILED with an error payload. A FatalException additionally stops
        the responder loop and starts a background thread that kills the process after a delay.
        """
        try:
            async for response_payload in self.handler(request_payload):
                try:
                    self.enqueue_response(request_id, "RUNNING", response_payload)
                except Exception:
                    logger.exception("Payload could not be sent")
                    # Let the original error bubble to fail the request
                    raise
            self.enqueue_response(request_id, "SUCCEEDED", None)
        except CancelledError:
            logger.info("Cancelled request")
            self.enqueue_response(request_id, "CANCELLED", None)
        except FatalException as e:
            message = self._error_payload(e)
            logger.exception("Request %s failed with fatal error: %s", request_id, message["message"])
            self.enqueue_response(request_id, "FAILED", message)
            # Stop gracefully the link
            self.response_queue.put_nowait(None)
            # Kill eventually the process after some time
            t = threading.Thread(target=schedule_shutdown, args=(message["message"],))
            t.start()
        except BaseException as e:
            logger.exception("Error while handling request")
            self.enqueue_response(request_id, "FAILED", self._error_payload(e))

    @staticmethod
    def _error_payload(error):
        """Build the FAILED payload: the error message and the fully qualified exception class name."""
        cls = error.__class__
        return {
            "message": str(error),
            "pythonExceptionPath": cls.__module__ + "." + cls.__qualname__
        }

    def enqueue_response(self, request_id, request_type, payload):
        """Serialize `payload` with dkujson and queue it for the responder loop."""
        json_payload = dkujson.dumps(payload)
        self.response_queue.put_nowait((request_id, request_type, json_payload))

    def get_clean_task_callback(self, request_id):
        """Return a done-callback that removes the finished task from `active_tasks`."""
        def clean_task(task=None, *args, **kwargs):
            current = self.active_tasks.get(request_id)
            # Only remove the entry if it still refers to the task that finished: a newer task may have been
            # registered under the same request id, and dropping it would make it impossible to cancel.
            if current is not None and (task is None or current is task):
                del self.active_tasks[request_id]

        return clean_task


def schedule_shutdown(reason):
    """
    Log `reason`, wait for a grace period, then hard-kill the process with exit code 1.

    The delay in seconds is read from DKU_ASYNC_LINK_SHUTDOWN_DELAY (default 30). It is started in a
    background thread by AsyncServer.handle_request; `os._exit` terminates the whole process immediately,
    without running cleanup handlers.
    """
    default_delay = 30
    raw_delay = os.environ.get("DKU_ASYNC_LINK_SHUTDOWN_DELAY", str(default_delay))
    try:
        # Clamp at 0: time.sleep() raises on negative values.
        countdown_seconds = max(0, int(raw_delay))
    except ValueError:
        # This runs in a background thread: an uncaught error here would only kill the thread, and the
        # process would never be shut down. Fall back to the default delay instead.
        logger.warning("Invalid DKU_ASYNC_LINK_SHUTDOWN_DELAY value %r, using %ss", raw_delay, default_delay)
        countdown_seconds = default_delay
    logger.info(
        "Impossible to recover from fatal error. Starting countdown {countdown}s to kill process. Error: {reason}".format(
            reason=reason, countdown=countdown_seconds))
    time.sleep(countdown_seconds)
    logger.info("Shutdown process")
    os._exit(1)


class FatalException(Exception):
    """
    Error from which the server cannot recover.

    When a request handler raises it, AsyncServer reports the request as FAILED, stops its responder
    loop and schedules a delayed hard shutdown of the process.
    """
    pass


def call_with_fatal_exception(call: Callable,
                              exception_type: Union[Type[Exception], Tuple[Type[Exception], ...]]) -> Any:
    """
    Invoke `call()` and return its result, escalating selected errors to FatalException.

    :param call: zero-argument callable to invoke
    :param exception_type: exception class, or tuple of classes, to treat as fatal
    :raises FatalException: chained from the original error when `call` raises an instance of `exception_type`;
        any other exception propagates unchanged
    """
    try:
        result = call()
    except exception_type as original_error:
        raise FatalException(f"Fatal exception: {original_error!s}") from original_error
    return result
