import os
import logging
import sys
import torch
import transformers

from typing import AsyncIterator

from dataiku.base.async_link import FatalException
from dataiku.base.utils import package_is_at_least
from dataiku.huggingface.chat_template import ChatTemplateRenderer
from dataiku.huggingface.pipeline import ModelPipeline, QuantizationMode
from dataiku.huggingface.types import ProcessSinglePromptCommand
from dataiku.huggingface.types import ProcessSinglePromptResponseText
from dataiku.huggingface.types import ToolSettings
from dataiku.huggingface.types import ChatTemplateSettings
from dataiku.huggingface.types import ModelSettings
from dataiku.huggingface.env_collector import extract_model_architecture, extract_weight_quantization_method
from dataiku.huggingface.vllm_backend.model_params import get_model_params_loader

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# True when running on macOS; supports_model() skips its CUDA device checks in that case.
IS_MAC = sys.platform == "darwin"


class ModelPipelineTextGenerationVLLM(ModelPipeline[ProcessSinglePromptCommand, ProcessSinglePromptResponseText]):
    def __init__(
            self,
            hf_handling_mode,
            model_name_or_path,
            base_model_name_or_path,  # will be None if not an adapter model
            model_settings: ModelSettings,
            with_image_input
    ):
        """
        Start the vLLM engine and prepare the chat template, tool and guided decoding settings.

        :param hf_handling_mode: handling mode, forwarded to the chat template renderer
        :param model_name_or_path: model to serve; for a LoRA model this is the adapter
        :param base_model_name_or_path: base model of the LoRA adapter, None if not an adapter model
        :param model_settings: settings mapping; the keys read directly here are "chatTemplateSettings",
            "toolSettings" and "enableJsonConstraintsInPrompt" (it is also handed to the model params loader)
        :param with_image_input: forwarded to the model params loader
        """
        super().__init__()

        # Probe the compiled MoE extension: a failure is only logged, loading carries on
        try:
            import vllm._moe_C
        except Exception:
            logger.exception("Some VLLM features like MoE models may not work on this OS, make sure to use Almalinux 9.")

        is_adapter = base_model_name_or_path is not None
        if is_adapter:
            # LoRA adapter model: the engine loads the base model, the adapter is resolved separately
            weights_source = base_model_name_or_path
            self.lora_path = self._resolve_lora_path(model_name_or_path)
            logger.info(f"Loading LoRA adapter from local path: {self.lora_path}")
        else:
            # Plain model: nothing to resolve
            weights_source = model_name_or_path
            self.lora_path = None

        logger.info("Loading model config")
        hf_config = transformers.PretrainedConfig.from_pretrained(weights_source)
        logger.info("Model config loaded")

        self.model_tracking_data["model_architecture"] = extract_model_architecture(hf_config)
        self.model_tracking_data["used_engine"] = "vllm"
        self.model_tracking_data["task"] = "text-generation"
        self.model_tracking_data["adapter"] = "lora" if is_adapter else "none"
        self.model_tracking_data["weights_quantization"] = extract_weight_quantization_method(hf_config)

        params_loader = get_model_params_loader(hf_config, model_settings, with_image_input, self.lora_path)

        engine_kwargs = {
            "disable_log_stats": False,  # Enable logging of performance metrics
            "model": weights_source,
            "ignore_patterns": [
                "original/**/*",  # avoid repeated downloading of llama's checkpoint
                "consolidated*.safetensors"  # filter out Mistral-format weights
            ],
        }
        # Loader-provided parameters come last, so they win over the defaults above
        engine_kwargs.update(params_loader.build_params())
        logger.info(f"Loading model with args {engine_kwargs}")

        from vllm import AsyncEngineArgs, AsyncLLMEngine
        engine_args = AsyncEngineArgs(**engine_kwargs)
        self.engine_client = AsyncLLMEngine.from_engine_args(engine_args)
        vllm_model_config = self.engine_client.engine.get_model_config()
        logger.info("Model loaded")

        # -------------------------------
        # CHAT TEMPLATE
        # having tool template here does not mean the model necessarily supports tool call (mistral 7b v0.1 does not for example)
        chat_template_settings: ChatTemplateSettings = model_settings.get("chatTemplateSettings", {})
        # The configured template is only used when the override flag is set; empty string means "no override"
        if chat_template_settings.get('overrideChatTemplate', False):
            chat_template_override: str = chat_template_settings.get('chatTemplate') or ""
        else:
            chat_template_override = ""

        # The tokenizer is read from model_name_or_path (the adapter itself for a LoRA model)
        logger.info("Loading tokenizer")
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=params_loader.trust_remote_code or False
        )
        logger.info("Tokenizer loaded")

        self.chat_template_renderer = ChatTemplateRenderer(
            tokenizer=tokenizer,
            hf_handling_mode=hf_handling_mode,
            supports_message_parts=True,
            chat_template_override=chat_template_override,
            vllm_model_config=vllm_model_config
        )
        # -------------------------------

        # -------------------------------
        # TOOLS
        tool_settings: ToolSettings = model_settings.get("toolSettings", {})
        enable_tools: bool = tool_settings.get('enableTools', False)
        self.tool_parser: str = tool_settings.get('toolParser') or ""

        # Tools are only on when enabled in the settings, a parser is named and the chat template handles them
        self.tools_supported = False
        if not enable_tools:
            logger.info("Tools disabled for this model")
        elif not self.tool_parser:
            logger.warning("Tools disabled: please specify a tool parser")
        elif not self.chat_template_renderer.supports_tool:
            logger.warning("Model does not support tools")
        else:
            logger.info("Model supports tools")
            logger.info(f"Using tool parser {self.tool_parser}")
            self.tools_supported = True
        # -------------------------------

        # -------------------------------
        # GUIDED GENERATION
        self.guided_decoding_backend = params_loader.guided_decoding_backend
        self.enable_json_constraints_in_prompt = model_settings.get('enableJsonConstraintsInPrompt', True)

        logger.info(f"Engine used for guided decoding: '{self.guided_decoding_backend}'")
        logger.info(f"Json constraints injected in prompt: {self.enable_json_constraints_in_prompt}")
        # -------------------------------

    async def initialize_model(self):
        """
        Build vLLM's OpenAI-style chat server on top of the engine, then run a test prompt through it.

        :raises FatalException: when the test prompt does not produce any text
        """
        from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
        from vllm.entrypoints.openai.serving_models import OpenAIServingModels, LoRAModulePath, BaseModelPath
        from dataiku.huggingface.vllm_backend.oai_mappers import generate_dummy_dss_request

        lora_enabled = bool(self.lora_path)

        # The model ID 'model' is the one that gets queried. With LoRA it designates the adapter
        # module and the base model is registered as 'base_model' (not queried); without LoRA it
        # designates the base model itself.
        served_base = BaseModelPath(name="base_model" if lora_enabled else "model", model_path="fake_path")
        lora_modules = [LoRAModulePath(name="model", path=self.lora_path)] if lora_enabled else []

        serving_models = OpenAIServingModels(
            self.engine_client,
            await self.engine_client.get_model_config(),
            base_model_paths=[served_base],
            lora_modules=lora_modules,
            prompt_adapters=None
        )

        if lora_enabled:
            await serving_models.init_static_loras()

        # Tool arguments are only passed when tools were found to be usable at construction time
        tool_kwargs = {}
        if self.tools_supported:
            tool_kwargs = {
                'enable_auto_tools': True,
                'tool_parser': self.tool_parser,
            }

        # Create the OpenAI Chat API server of vLLM
        self.openai_server = OpenAIServingChat(
            engine_client=self.engine_client,
            models=serving_models,
            model_config=await self.engine_client.get_model_config(),
            response_role="assistant",
            request_logger=None,
            chat_template=self.chat_template_renderer.get_chat_template(),
            chat_template_content_format="auto",
            **tool_kwargs
        )

        test_prompt = "Explain in simple terms what Generative AI is and why prompts matter for it"
        logger.info(f"Testing engine with basic prompt: {test_prompt}")
        dummy_request = generate_dummy_dss_request(test_prompt)

        # Drain the response stream, only the last item yielded is inspected
        last_response = None
        async for last_response in self.run_single_async(request=dummy_request):
            pass

        got_text = bool(last_response) and "text" in last_response and len(last_response["text"]) > 0
        if not got_text:
            raise FatalException(
                "Something went wrong at initialization. Engine did not return any result for basic prompt"
            )
        logger.info(f"Test prompt executed successfully: {last_response['text']}")

    @staticmethod
    def supports_model(hf_handling_mode, model_name_or_path, base_model_name_or_path, model_settings):
        """
        Tell whether the model can be served by the vLLM text generation pipeline.

        Checks run from cheapest to most expensive: handling mode, RS-LoRA adapter, presence and
        version of vLLM, quantization mode, CUDA devices (skipped on macOS), model architecture.

        :param hf_handling_mode: handling mode, must start with "TEXT_GENERATION_" for vLLM to be used
        :param model_name_or_path: model (or LoRA adapter) to check
        :param base_model_name_or_path: base model of the adapter, None if not an adapter model
        :param model_settings: settings mapping; "quantizationMode" is read from it, and any error
            while reading it (or during the later checks) is logged and reported as "not supported"
        :return: True if vLLM can be used for this model, False otherwise
        :raises ValueError: if vLLM is installed but older than the expected version
        """
        logger.info("Checking if VLLM is supported")

        # Pure string check, done first so that unsupported handling modes never trigger the
        # adapter config lookup below (which may have to reach the HuggingFace hub)
        if not hf_handling_mode.startswith("TEXT_GENERATION_"):
            logger.info("Handling mode is not TEXT_GENERATION, vLLM not supported")
            return False

        try:
            from peft import LoraConfig
            adapter_config = LoraConfig.from_pretrained(model_name_or_path)
        except ValueError:
            logger.info("Model is not an adapter model")
        else:
            # The loaded config may not expose 'use_rslora' (e.g. another adapter type or an older
            # peft release): a missing field is treated as "not RS-LoRA" instead of crashing the check
            use_rslora = getattr(adapter_config, "use_rslora", False)
            logger.info(f"Adapter config with rslora {use_rslora}")
            if use_rslora:
                # Bug-fix in VLLM - SC-219662
                logger.info("RS-LoRa is not supported in VLLM")
                return False

        try:
            from vllm.model_executor.models import ModelRegistry
            import vllm
            vllm_version = vllm.__version__
            logger.info("VLLM version: " + vllm_version)
        except ImportError:
            logger.info("VLLM is not installed")
            return False

        # An outdated vLLM is reported as an error rather than silently falling back
        expected_vllm_version = "0.9.0.1"
        if not package_is_at_least(vllm, expected_vllm_version):
            raise ValueError(f"Installed version of 'vllm' (version={vllm_version}) is too old, "
                             f"please upgrade to version {expected_vllm_version}")

        try:
            quantization_mode = QuantizationMode[model_settings["quantizationMode"]]
            if quantization_mode == QuantizationMode.Q_8BIT:
                logger.info(f"Quantization mode {quantization_mode} not supported by VLLM")
                return False

            # Local hardware checks come before the model config is fetched
            if not IS_MAC:
                if not torch.cuda.is_available():
                    # vllm does not support CPU inference yet: https://github.com/vllm-project/vllm/pull/1028
                    logger.info("CUDA is not available, vLLM not supported")
                    return False

                device_count = torch.cuda.device_count()
                if device_count == 0:
                    logger.info("No CUDA device found, vLLM not supported")
                    return False

                # Check compute capability: every visible device must have a major version of at least 7
                for device_index in range(device_count):
                    capability_level = torch.cuda.get_device_capability(device_index)
                    device_name = torch.cuda.get_device_name(device_index)
                    if capability_level[0] < 7:
                        logger.info(
                            f"CUDA device {device_index} ({device_name}) has compute capability {capability_level}: VLLM is not supported"
                        )
                        return False
                    logger.info(
                        f"CUDA device {device_index} ({device_name}) has compute capability {capability_level}: VLLM is supported"
                    )

            # For an adapter, the architecture is the one of the base model
            transformers_model_config = transformers.PretrainedConfig.from_pretrained(
                model_name_or_path if base_model_name_or_path is None else base_model_name_or_path
            )

            # A config without architectures cannot be matched against vLLM's registry
            architectures = transformers_model_config.architectures or []
            if not architectures:
                logger.info("Model config does not declare any architecture, vLLM not supported")
                return False

            architecture = architectures[0]
            if architecture not in ModelRegistry.get_supported_archs():
                logger.info(f"Model architecture {architecture} not supported by VLLM")
                return False

            if "MixtralForCausalLM" in architectures and quantization_mode != QuantizationMode.NONE:
                # vllm 0.8.4 fails with Mixtral for inflight bitsandbytes quantization
                logger.info("Inflight Bitsandbytes quantization with VLLM not supported for Mixtral models")
                return False
        except Exception:
            # Best effort: any unexpected failure in the checks above means "do not use vLLM"
            logger.exception(
                "Error while checking if VLLM is supported, assuming it is not supported"
            )
            return False

        logger.info("VLLM is supported")
        return True

    @staticmethod
    def _resolve_lora_path(lora_path):
        """
        Turn a LoRA adapter reference into a local path, downloading it from HuggingFace if needed.

        Resolution order:
        - an absolute path is returned unchanged (its existence is not checked)
        - a path starting with '~' gets the user home directory expanded
        - a relative path that exists is made absolute
        - anything else is treated as a HuggingFace repo id and downloaded

        Can be removed when the minimum vLLM requirement is >=0.5.3
        NOTE(review): supports_model in this file already requires a much more recent vLLM, so this
        helper may be removable - confirm that nothing else relies on the resolved local path first.

        Prior to vLLM 0.5.3, LoRA adapters were only supported if downloaded to the local machine (not HF model IDs).
        This method replicates the behaviour of vLLM 0.5.3, downloading the LoRA adapter from HF if needed
        https://github.com/vllm-project/vllm/pull/6234/files#diff-7c04dc096fc35387b6759bdc036f747fd9d8cf21bb7b9f2f69b2d57492b59ba1R114-R154
        """
        from huggingface_hub import snapshot_download

        if os.path.isabs(lora_path):
            resolved = lora_path
        elif lora_path.startswith('~'):
            resolved = os.path.expanduser(lora_path)
        elif os.path.exists(lora_path):
            resolved = os.path.abspath(lora_path)
        else:
            # Nothing matching locally: assume it's a Hugging Face repo.
            logger.info(f"Downloading LoRA adapter from HuggingFace, model id: {lora_path}")
            resolved = snapshot_download(repo_id=lora_path)
        return resolved

    async def run_single_async(
            self, request: ProcessSinglePromptCommand
    ) -> AsyncIterator[ProcessSinglePromptResponseText]:
        """
        Run one prompt through vLLM's OpenAI-style chat server and yield the resulting DSS responses.

        :param request: DSS prompt command; its "id" is only used in the log messages
        :return: async iterator over whatever oai_to_dss_response produces for the chat completion
        :raises FatalException: when the engine raises AsyncEngineDeadError
        """
        from vllm.engine.async_llm_engine import AsyncEngineDeadError
        from dataiku.huggingface.vllm_backend.oai_mappers import dss_to_oai_request, oai_to_dss_response

        prompt_id = request["id"]
        try:
            logger.info(f"Start prompt {prompt_id}")
            # The chat server is created by initialize_model
            assert self.openai_server

            chat_request = dss_to_oai_request(
                request,
                self.chat_template_renderer,
                self.enable_json_constraints_in_prompt,
                self.guided_decoding_backend,
                self.tools_supported,
            )
            completion = await self.openai_server.create_chat_completion(chat_request)
            async for dss_item in oai_to_dss_response(completion):
                yield dss_item
            logger.info(f"Done prompt {prompt_id}")
        except AsyncEngineDeadError as dead_engine_error:
            # A dead engine cannot serve further prompts: escalate as a fatal error
            raise FatalException(f"Fatal exception: {dead_engine_error!s}") from dead_engine_error
