import asyncio
import logging
import os
import sys
import traceback

from concurrent.futures import ThreadPoolExecutor
from pprint import pformat
from typing import AsyncIterator, cast, Optional
from uuid import uuid4

from dataiku.base.async_link import AsyncJavaLink
from dataiku.base.gpu_utils import log_nvidia_smi, log_shm_size
from dataiku.base.socket_block_link import parse_javalink_args
from dataiku.base.utils import watch_stdin
from dataiku.core import dkujson
from dataiku.huggingface.env_collector import collect_env
from dataiku.huggingface.model_path_utils import model_and_base_model_name_or_path_manager
from dataiku.huggingface.pipeline import create_mock_pipeline, create_model_pipeline, ModelPipeline
from dataiku.huggingface.types import HuggingFaceKernelCommand, ProcessSingleCommandModel, ProcessSingleCommand, ProcessSingleResponse, StartCommand
from dataiku.huggingface.utils import copy_request_for_logging, enable_hf_transfer, log_hf_debug_info
from dataiku.base.remoterun import is_running_remotely
# Hugging Face related imports NEED to be made after we set the environment variables HF_HUB_ENABLE_HF_TRANSFER and TRANSFORMERS_OFFLINE.
# These imports are therefore done inside HuggingFaceServer.start. cf https://app.shortcut.com/dataiku/story/190851/make-sure-transformers-offline-is-set-early-enough

logger = logging.getLogger(__name__)


def log_exception(loop, context):
    """Log an event-loop error together with its context and formatted stack trace.

    Has the ``(loop, context)`` signature of an asyncio exception handler; ``loop`` is not used.
    When ``context`` has no ``'exception'`` entry, an ``Exception`` is built from its ``'message'``
    entry so that the same log layout is produced in both cases.
    """
    error = context.get('exception')
    if error is None:
        error = Exception(context.get('message'))

    # Render the traceback up front so the log call below stays readable.
    stack_trace = ''.join(traceback.format_exception(type(error), error, error.__traceback__))
    logger.error(f"Caught exception: {error}\nContext: {context}\nStack trace: {stack_trace}")


class HuggingFaceServer:
    """Command handler that loads a Hugging Face model pipeline and serves queries against it.

    ``handler`` is the entry point: it accepts a single ``start`` command (which loads the
    pipeline through ``start``), then serves query and introspection commands.
    """

    # Minimum (major, minor) Python version for local Hugging Face model inference.
    MIN_PYTHON_VERSION = (3, 9)

    def __init__(self):
        # True as soon as a "start" command has been received (even if loading is still in progress).
        self.start_command_received = False
        # Event loop running `handler`; captured there so `start` can install it in its worker thread.
        self.event_loop = None
        # True once the pipeline is loaded and initialized; `process_query` asserts on it.
        self.started = False
        self.start_executor = ThreadPoolExecutor(1)  # with Ray + engine V0, the executor needs to be kept alive for the whole lifecycle of the model, otherwise Ray processes are killed on executor shutdown
        # Environment data gathered by `collect_env` during `start`; stays empty if collection failed.
        self.env_data = {}
        self.model_pipeline: Optional[ModelPipeline] = None

    @staticmethod
    def _is_supported_python(version_info) -> bool:
        """Return whether ``version_info`` (``sys.version_info`` or a similar sequence) is at least 3.9.

        The ``(major, minor)`` pair is compared as a tuple: comparing the two components
        independently would wrongly reject a future major version with a small minor (e.g. 4.0).
        """
        return tuple(version_info[:2]) >= HuggingFaceServer.MIN_PYTHON_VERSION

    def start(self, start_command: StartCommand):
        """Configure the process environment and load the model pipeline described by ``start_command``.

        This call is blocking; ``handler`` runs it in ``start_executor``. The order of the steps
        matters: environment variables are set before ``huggingface_hub`` is imported, which is why
        that import is local to this method.

        On success ``self.model_pipeline`` is set and ``self.env_data`` holds the collected
        environment data (left untouched if the collection fails, which is only logged).

        :param start_command: the start command; the settings read with ``[...]`` below are mandatory.
        :raises AssertionError: if ``self.event_loop`` is not set, or if ``vllm`` or
            ``huggingface_hub`` has already been imported.
        """
        # Fix required to be able to start engine V1 in background thread with python 3.9 (https://github.com/vllm-project/vllm/issues/18816)
        assert self.event_loop is not None
        asyncio.set_event_loop(self.event_loop)

        python_version_info = sys.version_info
        python_version = f"{python_version_info.major}.{python_version_info.minor}.{python_version_info.micro}"
        if self._is_supported_python(python_version_info):
            logger.info(f"Python version: {python_version}")
        else:
            logger.warning(f"Python version {python_version} is not supported. Local Hugging Face model inference requires Python >= 3.9.")
        hf_model_name = start_command.get('hfModelName')

        hf_handling_mode = start_command["hfHandlingMode"]
        use_dss_model_cache = start_command["useDSSModelCache"]
        batch_size = start_command["batchSize"]
        model_settings = start_command["modelSettings"]
        supports_image_inputs = start_command["supportsImageInputs"]
        expected_vllm_version = start_command["vllmVersion"]

        # VLLM_USE_V1 must be set before vllm is first imported
        # otherwise it won't be taken into account
        assert 'vllm' not in sys.modules
        if model_settings.get("vllmEngine") == "V0":
            os.environ["VLLM_USE_V1"] = "0"
        elif model_settings.get("vllmEngine") == "V1":
            os.environ["VLLM_USE_V1"] = "1"

        # make sure huggingface_hub is not imported otherwise all the following code would be no-op
        assert 'huggingface_hub' not in sys.modules
        log_shm_size(logger)
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cuda_visible_devices:
            logger.info("Custom environment variable CUDA_VISIBLE_DEVICES set to: %s", cuda_visible_devices)
        else:
            logger.info("Custom environment variable CUDA_VISIBLE_DEVICES not set: defaulting to using all available GPUs")
        log_nvidia_smi(display_topo_info=True, logger=logger, additional_message="'nvidia-smi topo -m' topology information of the system")
        log_nvidia_smi(display_topo_info=False, logger=logger, additional_message="'nvidia-smi' before loading the model")
        if use_dss_model_cache:
            os.environ["TRANSFORMERS_OFFLINE"] = "1"  # Disable HF model cache. Must set this before `from transformers import pipeline`
            logger.info("Using DSS model cache : \"TRANSFORMERS_OFFLINE\"=1")
            # TODO @llm: Note that a version.txt file containing a single number is still written at $TRANSFORMERS_CACHE/version.txt
        enable_hf_transfer(logger)

        # Deliberately imported here, after the environment variables above have been set.
        from huggingface_hub import login

        log_hf_debug_info(logger)

        if not use_dss_model_cache:
            # Only log in to the Hugging Face hub when an API key was provided.
            hf_api_key = start_command.get("hfApiKey")
            if hf_api_key:
                login(token=hf_api_key)

        try:
            # Collection of env data is done:
            # - Before loading the model, otherwise HF tokenizers will complain with:
            #   "The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...".
            # - After configuring HF hub, since we do set env variables that must be configured importing 'huggingface_hub'.
            logger.info("Collecting env data")
            self.env_data = collect_env()
            logger.info(f"Environment data {self.env_data}")
        except Exception:
            # Best effort: env data is only used for tracking, so a failure must not prevent the model from loading.
            logger.exception("Failed to collect env data")

        with model_and_base_model_name_or_path_manager(start_command, model_settings) as (model_name_or_path, base_model_name_or_path, refiner_name_or_path):
            if start_command['fakeLLMServer']:
                self.model_pipeline = create_mock_pipeline(hf_handling_mode, model_settings)
            else:
                model_was_finetuned_with_dss = bool(start_command.get("savedModelId"))
                model_settings["hfRefinerPath"] = refiner_name_or_path
                self.model_pipeline = create_model_pipeline(hf_handling_mode, hf_model_name, model_name_or_path, base_model_name_or_path,
                                                            use_dss_model_cache, model_settings, batch_size, model_was_finetuned_with_dss, supports_image_inputs, expected_vllm_version)
            self.model_pipeline.model_tracking_data["model_source"] = "model-cache" if use_dss_model_cache else "huggingface-hub"

        log_nvidia_smi(display_topo_info=False, logger=logger, additional_message="'nvidia-smi' after loading the model")
        logger.info("Initialization done")

    async def process_query(self, request: ProcessSingleCommandModel) -> AsyncIterator[ProcessSingleResponse]:
        """Run one query through the model pipeline and yield its responses as they are produced.

        A fresh UUID is generated and stored under ``request["id"]`` (the request is modified in
        place). If the generator is cancelled, the cancellation is logged and the generator ends
        without re-raising ``asyncio.CancelledError``.

        :param request: the query command to process.
        :raises AssertionError: if the server is not started or no pipeline is loaded.
        """
        assert self.started, "Not started"
        assert self.model_pipeline is not None
        request_id = str(uuid4())
        logger.info("Processing request %s", request_id)
        logger.debug("Request %s payload: %s", request_id, copy_request_for_logging(request))  # should not create a huge string if log level not debug

        request = cast(ProcessSingleCommand, request)
        request["id"] = request_id
        try:
            async for resp in self.model_pipeline.run_single_async(request):
                yield resp
            logger.info("Sending response complete to request %s", request_id)
        except asyncio.CancelledError:
            # Note that at this point the model pipeline might still be running (i.e. only vLLM supports real
            # cancellation)
            logger.info("Cancellation of request %s was requested", request_id)

    async def handler(self, command: HuggingFaceKernelCommand) -> AsyncIterator[ProcessSingleResponse]:
        """Dispatch one incoming command according to ``command["type"]``.

        - ``start``: loads and initializes the pipeline; yields nothing. Accepted only once.
        - ``collect-env``: yields ``{"trackingData": ...}`` built from the pipeline and env data.
        - ``get-used-engine``: yields ``{"usedEngine": ...}``.
        - ``process-*-query``: yields every response produced by ``process_query``.

        :raises AssertionError: on a second ``start`` command, or when a command needs a pipeline
            that is not loaded.
        :raises Exception: when the command type is unknown.
        """
        if command["type"] == "start":
            assert not self.start_command_received, "Start command already received"
            self.start_command_received = True

            # Never log the API key: work on a shallow copy so the real command stays intact.
            redacted_start_command: StartCommand = {**command}
            if redacted_start_command.get("hfApiKey") is not None:
                redacted_start_command["hfApiKey"] = "REDACTED"
            logger.info(f"Received start command:\n{pformat(redacted_start_command)}")
            # Run start in a different thread as it may take a long time to load the model
            running_loop = asyncio.get_running_loop()
            self.event_loop = running_loop  # required to set the event loop in the background thread running the start command
            await running_loop.run_in_executor(self.start_executor, self.start, command)
            assert self.model_pipeline is not None
            await self.model_pipeline.initialize_model()
            self.started = True
            logger.info("Ready to use model for inference. Waiting for commands")

        elif command["type"] == "collect-env":
            assert self.model_pipeline is not None
            # This command is called by the DSS backend to collect the environment data for WT1 *after* the model has been loaded.
            tracking_data = {
                "used_engine": self.model_pipeline.used_engine,
                **self.model_pipeline.model_tracking_data,
                **self.env_data
            }
            logger.info(f"Collected env data: {dkujson.dumps(tracking_data)}")
            yield {"trackingData": tracking_data}
        elif command["type"] == "get-used-engine":
            assert self.model_pipeline is not None
            logger.info(f"Getting the running engine {self.model_pipeline.used_engine}")
            yield {"usedEngine": self.model_pipeline.used_engine}
        elif command["type"] in ["process-embedding-query", "process-completion-query", "process-image-generation-query", "process-reranking-query"]:
            async for resp in self.process_query(command):
                yield resp
        else:
            raise Exception(f"Unknown command type: {command['type']}")


if __name__ == "__main__":
    # Script entry point: configure logging and telemetry opt-outs, then serve commands over the Java link.
    LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper()  # Set LOGLEVEL=DEBUG to debug
    logging.basicConfig(
        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s',
        level=LOGLEVEL,
    )

    os.environ.update({
        # Disable vLLM usage tracking
        # See https://github.com/vllm-project/vllm/blob/v0.4.0.post1/vllm/usage/usage_lib.py#L39
        "DO_NOT_TRACK": "1",
        "VLLM_NO_USAGE_STATS": "1",
        # Disable Ray's usage tracking (Ray is currently a dependency of vLLM, but they might remove it in the future)
        # See https://github.com/ray-project/ray/blob/ray-2.10.0/python/ray/_private/usage/usage_lib.py#L384
        "RAY_USAGE_STATS_ENABLED": "0",
    })

    logger.info("Starting HF server")
    watch_stdin()

    async def start_server():
        """Connect to the Java backend, serve commands until the link is done, then wind down."""
        asyncio.get_running_loop().set_exception_handler(log_exception)

        link_port, link_secret, link_cert = parse_javalink_args()
        java_link = AsyncJavaLink(link_port, link_secret, server_cert=link_cert)
        hf_server = HuggingFaceServer()
        await java_link.connect()
        await java_link.serve(hf_server.handler)

        if not is_running_remotely():
            # vllm spawns children processes, and we rely on the Java backend to clean the whole process tree
            # so we need this python process to stay alive long enough after the link is closed for the Java to kill it along with its whole process tree
            # => we add a sleep here to prevent premature interpreter exit
            logger.info("Serving is done, waiting for the Java backend to kill the process tree...")
            await asyncio.sleep(30)

            # should never be reached
            logger.warning("HF server process should have been killed but is still alive, resources may leak upon interpreter exit")
            return

        # In container execution mode, we are not responsible for cleaning up resources => we can exit immediately
        logger.info("Exiting immediately in container execution mode")

        # Force-exit because there might be non-daemon threads still running (e.g. vLLM engine threads)
        os._exit(0)

    asyncio.run(start_server())
