import logging
import sys
import torch
import transformers

from dataiku.base.utils import package_is_exactly, package_is_at_least
from dataiku.huggingface.pipeline import QuantizationMode

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# True when the interpreter reports the "darwin" platform (macOS).
IS_MAC = sys.platform == "darwin"


def is_vllm_supported(hf_handling_mode, model_name_or_path, base_model_name_or_path, model_settings, expected_vllm_version):
    """Tell whether vLLM can be used for the given model in the current environment.

    The checks run in this order: handling mode, vLLM importability,
    quantization mode, CUDA devices (skipped on macOS) and model architecture.
    A mismatch between the installed vLLM version and ``expected_vllm_version``
    only logs a warning; it does not make the function return False.

    Args:
        hf_handling_mode: Handling mode name. Only ``"TEXT_EMBEDDING"`` and
            names starting with ``"TEXT_GENERATION_"`` are eligible.
        model_name_or_path: Model identifier used to load the configuration
            when ``base_model_name_or_path`` is None.
        base_model_name_or_path: Identifier of the base model; takes
            precedence over ``model_name_or_path`` when not None.
        model_settings: Mapping holding a ``"quantizationMode"`` entry that
            names a ``QuantizationMode`` member.
        expected_vllm_version: vLLM version the installed package is compared
            against.

    Returns:
        True when every check passes, False otherwise. Errors raised while
        importing vLLM or while probing quantization, devices or architecture
        are logged and reported as False.
    """
    logger.info("Checking if VLLM is supported")
    if not (hf_handling_mode.startswith("TEXT_GENERATION_") or hf_handling_mode == "TEXT_EMBEDDING"):
        logger.info("Handling mode is neither TEXT_GENERATION nor TEXT_EMBEDDING, vLLM not supported")
        return False

    # Text embedding is broken with vLLM on version 0.10.1.1 on macOS; it is only
    # allowed here from vLLM 0.11.1 onwards.
    if IS_MAC and hf_handling_mode == "TEXT_EMBEDDING":
        try:
            import vllm
            if not package_is_at_least(vllm, "0.11.1"):
                logger.info("Text embedding with vLLM is currently not supported on MacOS")
                return False
        except Exception:
            logger.info("vLLM cannot be imported")
            return False

    try:
        from vllm.model_executor.models import ModelRegistry
        import vllm
        vllm_version = vllm.__version__
        logger.info("VLLM version: " + vllm_version)
    except ImportError:
        logger.info("VLLM is not installed")
        return False
    except Exception:
        # Importing a package runs its initialization code, which can fail with
        # errors other than ImportError. Treat those as "not supported" too, the
        # same way every other probing error in this function is handled,
        # instead of letting the capability check itself crash.
        logger.exception("VLLM is installed but cannot be imported, assuming it is not supported")
        return False

    if not package_is_exactly(vllm, expected_vllm_version):
        logger.warning(
            f"Installed version of 'vllm' (version={vllm_version}) is incompatible, "
            f"expected version is {expected_vllm_version}. You may experience issues."
        )

    try:
        quantization_mode = QuantizationMode[model_settings["quantizationMode"]]
        if quantization_mode == QuantizationMode.Q_8BIT:
            logger.info("Quantization mode {mode} not supported by VLLM".format(mode=quantization_mode))
            return False

        # The CUDA requirements are not enforced on macOS.
        if not IS_MAC and not _cuda_devices_support_vllm():
            return False

        logger.info("Checking if VLLM is supported with this model architecture")
        supported_architectures = ModelRegistry.get_supported_archs()
        transformers_model_config = transformers.PretrainedConfig.from_pretrained(
            base_model_name_or_path if base_model_name_or_path is not None else model_name_or_path
        )
        # A configuration may declare no architecture at all (None or empty list):
        # report that explicitly rather than through an IndexError/TypeError trace.
        architectures = getattr(transformers_model_config, "architectures", None)
        if not architectures:
            logger.info("Model configuration does not declare any architecture, vLLM not supported")
            return False

        architecture = architectures[0]
        if architecture not in supported_architectures:
            logger.info(
                "Model architecture {architecture} not supported by VLLM".format(
                    architecture=architecture
                )
            )
            return False

    except Exception:
        logger.exception(
            "Error while checking if VLLM is supported, assuming it is not supported"
        )
        return False

    logger.info("VLLM is supported")
    return True


def _cuda_devices_support_vllm():
    """Return True when CUDA is usable and every visible device has compute capability major >= 7."""
    if not torch.cuda.is_available():
        # vllm does not support CPU inference yet: https://github.com/vllm-project/vllm/pull/1028
        logger.info("CUDA is not available, vLLM not supported")
        return False

    device_count = torch.cuda.device_count()
    if device_count == 0:
        logger.info("No CUDA device found, vLLM not supported")
        return False

    # Every device must qualify: the first insufficient one rejects vLLM.
    for i in range(device_count):
        capability_level = torch.cuda.get_device_capability(i)
        device_name = torch.cuda.get_device_name(i)
        if capability_level[0] < 7:
            logger.info(
                f"CUDA device {i} ({device_name}) has compute capability {capability_level}: VLLM is not supported"
            )
            return False
        logger.info(
            f"CUDA device {i} ({device_name}) has compute capability {capability_level}: VLLM is supported by device"
        )
    return True


async def get_model_config(async_llm):
    """Return the model config of ``async_llm`` regardless of how it exposes it.

    The ``model_config`` attribute is used when the object has one; otherwise
    the ``get_model_config()`` coroutine method is awaited.
    """
    if not hasattr(async_llm, 'model_config'):
        # vLLM < 0.11.1: the config is only reachable through an awaitable getter.
        legacy_config = await async_llm.get_model_config()
        return legacy_config

    # vLLM >= 0.11.1: the config is a plain attribute.
    return async_llm.model_config


async def get_vllm_config(async_llm):
    """Return the vLLM config of ``async_llm`` regardless of how it exposes it.

    The ``vllm_config`` attribute is used when the object has one; otherwise
    the ``get_vllm_config()`` coroutine method is awaited.
    """
    if not hasattr(async_llm, 'vllm_config'):
        # vLLM < 0.11.1: the config is only reachable through an awaitable getter.
        legacy_config = await async_llm.get_vllm_config()
        return legacy_config

    # vLLM >= 0.11.1: the config is a plain attribute.
    return async_llm.vllm_config