import asyncio
import json
import logging
import os
from abc import abstractmethod
from typing import AsyncIterator

import transformers
from vllm import AsyncLLMEngine, AsyncEngineArgs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels

from dataiku.base.utils import package_is_at_least
from dataiku.huggingface.vllm_backend.utils import get_model_config
from dataiku.base.async_link import FatalException
from dataiku.huggingface.pipeline import ModelPipeline
from dataiku.huggingface.types import SingleCommand, SingleResponse
from dataiku.huggingface.types import ModelSettings
from dataiku.huggingface.env_collector import extract_model_architecture

logger = logging.getLogger(__name__)


class ModelPipelineVLLM(ModelPipeline[SingleCommand, SingleResponse]):
    engine_client: AsyncLLMEngine

    def __init__(
            self,
            model_name_or_path,
            base_model_name_or_path,
            model_settings: ModelSettings,
    ):
        """
        Prepare the pipeline: decide which model to load and read its transformers config.

        :param model_name_or_path: model to load, or the LoRA adapter location when
            ``base_model_name_or_path`` is provided
        :param base_model_name_or_path: base model of a LoRA adapter, ``None`` for a regular model
        :param model_settings: settings kept on the instance for later use
        """
        super().__init__()
        self.model_settings = model_settings
        self.log_stat_task = None  # no stat logging task is running yet

        # Best-effort probe of the compiled MoE extension: a failure is only logged
        try:
            import vllm._moe_C
        except Exception:
            logger.exception("Some VLLM features like MoE models may not work on this OS, make sure to use Almalinux 9.")

        if base_model_name_or_path is None:
            # Regular model: loaded as is, without any adapter
            self.model_to_load = model_name_or_path
            self.lora_path = None
        else:
            # LoRA adapter: the base model is loaded and the adapter path is resolved separately
            self.model_to_load = base_model_name_or_path
            self.lora_path = self._resolve_lora_path(model_name_or_path)
            logger.info(f"Loading LoRA adapter from local path: {self.lora_path}")

        logger.info("Loading model config")
        self.transformers_model_config = transformers.PretrainedConfig.from_pretrained(self.model_to_load)
        logger.info("Model config loaded")

        model_architecture = extract_model_architecture(self.transformers_model_config)
        self.model_tracking_data["model_architecture"] = model_architecture
        self.model_tracking_data["used_engine"] = "vllm"

    @abstractmethod
    async def build_openai_server(self, openai_serving_models: OpenAIServingModels):
        """
        Subclass hook awaited by ``initialize_model`` once the ``OpenAIServingModels`` has been created
        (and, for LoRA adapters, after its static LoRAs have been initialized).

        :param openai_serving_models: the serving-models object built by ``initialize_model``
        """
        raise NotImplementedError

    @abstractmethod
    async def run_test_query(self):
        """
        Subclass hook awaited as the last step of ``initialize_model``.
        """
        raise NotImplementedError

    @abstractmethod
    async def run_query(self, request: SingleCommand) -> AsyncIterator[SingleResponse]:
        """
        Subclass hook producing the response(s) for one request.

        ``run_single_async`` iterates over it with ``async for`` and forwards every yielded item.

        :param request: the command to process
        """
        raise NotImplementedError
        # Unreachable on purpose: the presence of `yield` makes this stub an async generator function,
        # consistent with the annotated AsyncIterator return type
        yield

    def load_model(self, load_params: dict):
        """
        Apply the vLLM compatibility patches, merge the user overrides and create the engine.

        The two monkey-patches are tagged with a ``_dku_patched`` attribute so that calling this method
        several times in the same process does not wrap an already patched function again.

        :param load_params: keyword arguments passed to ``AsyncEngineArgs``. Entries of the
            ``overriddenSettings`` model setting (a JSON object, when set) take precedence over them.
        :raises FatalException: when ``overriddenSettings`` cannot be parsed or merged
        """
        logger.info(f"Loading model with args {load_params}")

        import vllm
        if hasattr(vllm.engine, "arg_utils") and hasattr(vllm.engine.arg_utils, "_warn_or_fallback"):
            # remove this hack after next vllm bump (not necessary anymore after https://github.com/vllm-project/vllm/pull/23298)
            import vllm.engine.arg_utils
            _original_warn_or_fallback = vllm.engine.arg_utils._warn_or_fallback

            if getattr(_original_warn_or_fallback, "_dku_patched", False):
                # A previous call already installed the patch: wrapping again would only stack wrappers
                logger.info("vllm _warn_or_fallback is already patched")
            else:
                def _patched_warn_or_fallback(feature_name: str):
                    if feature_name == "Engine in background thread":
                        logger.info("Overriding vllm _is_v1_supported_oracle to allow running engine V1 in background thread")
                        return False
                    return _original_warn_or_fallback(feature_name)

                _patched_warn_or_fallback._dku_patched = True
                vllm.engine.arg_utils._warn_or_fallback = _patched_warn_or_fallback
        else:
            logger.info("No need to patch vllm _warn_or_fallback")

        if hasattr(vllm, 'config') and hasattr(vllm.config, 'SchedulerConfig') and hasattr(vllm.config.SchedulerConfig, 'verify_max_model_len'):
            # Workaround https://github.com/vllm-project/vllm/issues/28981 that breaks vLLM on Mac
            # The underlying issue has been noticed in vLLM 0.11.1 but I don't know exactly from which version it is present
            # Patch the very class that the guard above inspected rather than re-importing it from a submodule
            scheduler_config_cls = vllm.config.SchedulerConfig
            _origin_verify_max_model_len = scheduler_config_cls.verify_max_model_len

            if getattr(_origin_verify_max_model_len, "_dku_patched", False):
                logger.info("vllm SchedulerConfig.verify_max_model_len() is already patched")
            else:
                logger.info("Patching vllm SchedulerConfig.verify_max_model_len() to autofix max_num_batched_tokens when chunked prefill is disabled (eg. on Mac/ARM)")

                def _patched_verify_max_model_len(scheduler_config, max_model_len: int):
                    batched_tokens = scheduler_config.max_num_batched_tokens
                    if batched_tokens is not None and batched_tokens < max_model_len and not scheduler_config.enable_chunked_prefill:
                        logger.info(f"Autofixing SchedulerConfig.max_num_batched_tokens from {batched_tokens} to {max_model_len} to match model max length since chunked prefill is disabled")
                        scheduler_config.max_num_batched_tokens = max_model_len
                    return _origin_verify_max_model_len(scheduler_config, max_model_len)

                _patched_verify_max_model_len._dku_patched = True
                scheduler_config_cls.verify_max_model_len = _patched_verify_max_model_len
        else:
            logger.info("Could not patch vllm SchedulerConfig.verify_max_model_len(): method not found")

        raw_overridden_settings = self.model_settings.get("overriddenSettings")
        if raw_overridden_settings:
            try:
                overridden_settings = json.loads(raw_overridden_settings)
                # Also fails (and is reported below) when the JSON value is not an object
                load_params = {
                    **load_params,
                    **overridden_settings
                }
            except Exception as e:
                raise FatalException(f"Tried to add custom settings but their parsing failed, provide a valid JSON to be taken into account, {e}") from e
            logger.info(f"Overriding settings: {overridden_settings}")

        self.engine_client = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**load_params))
        logger.info("Model loaded")

    async def initialize_model(self):
        """
        Finish the setup of a loaded engine: build the OpenAI serving layer, start the periodic
        stat logging when the V1 engine is used, then run the test query.
        """
        # inspired from https://github.com/vllm-project/vllm/blob/5fbbfe9a4c13094ad72ed3d6b4ef208a7ddc0fd7/vllm/entrypoints/openai/api_server.py#L192
        await self.engine_client.reset_mm_cache()
        vllm_model_config = await get_model_config(self.engine_client)

        # -------------------------------
        # OPENAI SERVER
        from vllm.entrypoints.openai.serving_models import OpenAIServingModels, LoRAModulePath, BaseModelPath

        with_lora = bool(self.lora_path)

        # With LoRA, 'base_model' designates the base model (which is not queried);
        # without LoRA, 'model' is directly the ID of the queried model
        base_model_path = BaseModelPath(
            name="base_model" if with_lora else "model",
            model_path="fake_path",
        )

        import vllm
        serving_models_args = [self.engine_client]
        if not package_is_at_least(vllm, '0.11.1'):
            # The model config positional argument was removed in vLLM 0.11.1
            serving_models_args.append(vllm_model_config)

        # Add a model ID 'model' that will be queried when LoRA is enabled
        lora_modules = [LoRAModulePath(name="model", path=self.lora_path)] if with_lora else []

        openai_serving_models = OpenAIServingModels(
            *serving_models_args,
            base_model_paths=[base_model_path],
            lora_modules=lora_modules,
        )

        if with_lora:
            await openai_serving_models.init_static_loras()

        await self.build_openai_server(openai_serving_models)

        # -------------------------------
        # START THE STAT LOGGER (V1 engine doesn't log anything by default)
        from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine

        if not isinstance(self.engine_client, V1AsyncLLMEngine):
            logger.info("VLLM V0 is being used")
        else:
            logger.info("VLLM V1 is being used, starting stat logger")

            async def _log_stats_periodically():
                from vllm import envs
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await self.engine_client.do_log_stats()

            # Keep a reference on the task in the instance
            self.log_stat_task = asyncio.create_task(_log_stats_periodically())
        # -------------------------------

        await self.run_test_query()

    async def run_single_async(
            self, request: SingleCommand
    ) -> AsyncIterator[SingleResponse]:
        """
        Run one request through ``run_query`` and forward everything it yields.

        Errors showing that the vLLM engine is dead are converted to ``FatalException``; any other
        error is re-raised unchanged.

        :param request: the command to process
        :raises FatalException: when vLLM reports a dead engine, or when a query fails while the
            engine client is errored and not running anymore
        """
        # The "engine is dead" exception types only depend on the installed vLLM version: resolve them
        # once per instance instead of re-importing (and logging) on every single request
        fatal_exceptions = getattr(self, "_fatal_engine_exceptions", None)
        if fatal_exceptions is None:
            from vllm.v1.engine.exceptions import EngineDeadError  # engine V1
            exceptions_to_catch = [EngineDeadError]

            try:
                # Does not exist in vLLM 0.11.0
                from vllm.engine.async_llm_engine import AsyncEngineDeadError  # engine V0
            except ImportError:
                logger.info("Recent vLLM version, no need to catch AsyncEngineDeadError")
            else:
                exceptions_to_catch.append(AsyncEngineDeadError)

            fatal_exceptions = tuple(exceptions_to_catch)
            self._fatal_engine_exceptions = fatal_exceptions

        try:
            async for resp_or_chunk in self.run_query(request):
                yield resp_or_chunk
        except fatal_exceptions as err:
            raise FatalException("Fatal exception: {0}".format(str(err))) from err
        except Exception as err:
            # inspired from https://github.com/vllm-project/vllm/blob/5fbbfe9a4c13094ad72ed3d6b4ef208a7ddc0fd7/vllm/entrypoints/launcher.py#L97-L107
            # in some cases, the engine may die while handling a request but the fatal error is not properly propagated
            # so when a query fails, we check on the engine
            if self.engine_client.errored and not self.engine_client.is_running:
                raise FatalException("Engine client failed") from err
            raise

    @staticmethod
    def _resolve_lora_path(lora_path):
        """
        Can be removed when the minimum vLLM requirement is >=0.5.3

        Prior to vLLM 0.5.3, LoRA adapters were only supported if downloaded to the local machine (not HF model IDs).
        This method replicates the behaviour of vLLM 0.5.3, downloading the LoRA adapter from HF if needed
        https://github.com/vllm-project/vllm/pull/6234/files#diff-7c04dc096fc35387b6759bdc036f747fd9d8cf21bb7b9f2f69b2d57492b59ba1R114-R154
        """
        from huggingface_hub import snapshot_download

        if os.path.isabs(lora_path):
            return lora_path

        if lora_path.startswith('~'):
            return os.path.expanduser(lora_path)

        if os.path.exists(lora_path):
            return os.path.abspath(lora_path)

        # If the path doesn't exist locally, assume it's a Hugging Face repo.
        logger.info("Downloading LoRA adapter from HuggingFace, model id: {hf_id}".format(hf_id=lora_path))
        return snapshot_download(repo_id=lora_path)
