import logging
from copy import deepcopy
from enum import Enum
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from typing import List, Optional, Union

try:
    # for type checking
    from vllm.config import VllmConfig
except Exception:
    pass

from dataiku.huggingface.types import ChatMessage

logger = logging.getLogger(__name__)


class SystemPromptSupport(Enum):
    """Level of system-prompt support detected for a model's chat template."""
    NONE = 1    # Model does not support system prompts
    SINGLE = 2  # Model supports system prompt(s) only at the start of the messages
    MULTI = 3   # Model supports more than one system prompt message, anywhere in the conversation


# Built-in fallback templates for the legacy TEXT_GENERATION_DOLLY / TEXT_GENERATION_FALCON / TEXT_GENERATION_MPT
# handling modes. They may not exactly match the prompt format these models were trained on, but this matters
# little because only old models come without a chat template. This is pure legacy code.
FALCON_JINJA_TEMPLATE = "{% for message in messages %}{% if message.role == 'system' %}>>CONTEXT<<{{ message.content }}\n{% elif message.role == 'user' %}>>QUESTION<<{{ message.content }}\n{% else %}>>ANSWER<<{{ message.content }}\n{% endif %}{% endfor %}"
MPT_DOLLY_JINJA_TEMPLATE = "{% for message in messages %}{% if message.role == 'system' %}{{ message.content }}\n{% elif message.role == 'user' %}### Instruction:\n{{ message.content }}\n{% else %}### Response:\n{{ message.content }}\n{% endif %}{% endfor %}"


class ChatTemplateRenderer(object):
    """
    This class is responsible for rendering the chat template prompt from a list of messages.

    Recent instruct/chat models generally contain a jinja2 chat template in their tokenizer config. This is the preferred way to render the prompt and it is
    compatible with both VLLM and transformers packages. It is also more future proof as it allows to change the prompt format without changing the code.

    However:
    - Older models may not contain a chat template, and so we need to continue to render the prompt ourselves (e.g. Falcon, Dolly)
    - Some chat templates do not support system prompts (e.g. Mistral instruct). If we detect this, we convert system messages into user messages
      (which may then be merged with adjacent user messages) as a workaround
    - Many chat templates do not support multiple messages of the same role next to each other. If we detect this, we merge adjacent messages of the same role

    The capabilities of the chat template are detected ("sniffed") once, at construction time, by rendering small probe conversations.
    """
    def __init__(self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], hf_handling_mode: str, supports_message_parts: bool = False, chat_template_override: str = "", vllm_config: Optional["VllmConfig"] = None):
        """
        Detect the capabilities of the chat template and log them.

        :param tokenizer: tokenizer used to render and decode the prompt (may also be a vLLM MistralTokenizer).
        :param hf_handling_mode: handling mode of the model. TEXT_GENERATION_DOLLY, TEXT_GENERATION_FALCON and
                                 TEXT_GENERATION_MPT select a built-in legacy template instead of the tokenizer's one.
        :param supports_message_parts: whether the consumer of the rearranged messages accepts multi-part messages.
                                       If False, text parts are collapsed into `content` and non-text parts are rejected.
        :param chat_template_override: user-provided jinja2 chat template, used instead of the HF one
                                       (ignored with a Mistral tokenizer).
        :param vllm_config: vLLM engine configuration, only set when running through the vLLM pipeline.
        :raise Exception: if a chat template is available but cannot render a single user message.
        """
        self.hf_handling_mode = hf_handling_mode
        self.tokenizer = tokenizer
        self.chat_template_override = chat_template_override
        self.vllm_config = vllm_config  # only for vllmPipeline
        self.supports_message_parts = supports_message_parts

        # Mistral-format models have no usable jinja template (see _get_chat_template) but their tokenizer can
        # still render chats, so their capabilities are sniffed as well.
        if self._get_chat_template() or (self.vllm_config is not None and self.vllm_config.load_config.load_format == "mistral"):
            self.supports_adjacent_same_role_messages = self._sniff_adjacent_messages_same_role_support()
            self.supports_system_prompt_with_image_input = self._sniff_system_prompt_with_image_input_support()
            self.system_prompt_support = self._sniff_system_prompt_support()
            self.supports_tool = self._sniff_tool_support()
        else:
            # No chat template: the legacy templates (or none at all) accept any message layout.
            self.supports_adjacent_same_role_messages = True
            self.system_prompt_support = SystemPromptSupport.MULTI
            self.supports_system_prompt_with_image_input = True
            self.supports_tool = False

        if self.system_prompt_support == SystemPromptSupport.NONE:
            logger.info("System prompt is not supported by the chat template. Converting any system messages into user messages.")
        elif self.system_prompt_support == SystemPromptSupport.SINGLE:
            logger.info("Single system prompt is supported by the chat template. Converting any subsequent system messages into user messages.")
        else:  # MULTI
            logger.info("Multiple system prompts are supported by the chat template.")

        if self.supports_adjacent_same_role_messages:
            logger.info("Adjacent messages of the same role are supported by the chat template.")
        else:
            logger.info("Adjacent messages of the same role are not supported by the chat template. Merging any adjacent messages of the same role.")

        if self.supports_system_prompt_with_image_input:
            logger.info("System prompt with image input is supported by the chat template.")
        else:
            logger.info("System prompt with image input is not supported by the chat template. If both are provided, all system messages will be converted into user messages.")

    def rearrange_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
        """
        Rearrange the messages to maximize compatibility with the model's chat template.

        Steps, in order:
        - if message parts are not supported, collapse text parts into `content` (raises on non-text parts);
        - convert system messages into user messages, depending on the detected system prompt support;
        - if the template cannot combine a system prompt with image input and both are present, convert all
          system messages into user messages;
        - if adjacent messages of the same role are not supported, merge them.

        The input list and its messages are not mutated.

        :param messages: the conversation to rearrange.
        :return: the rearranged conversation.
        """
        if not self.supports_message_parts:
            messages = self._collapse_message_parts(messages)

        if self.system_prompt_support == SystemPromptSupport.NONE:
            messages = self._convert_message_role(messages, "system", "user")
        elif self.system_prompt_support == SystemPromptSupport.SINGLE:
            messages = self._convert_extra_system_message_to_user(messages)

        if (
            _has_system_message(messages)
            and _has_message_with_image_part(messages)
            and not self.supports_system_prompt_with_image_input
        ):
            messages = self._convert_message_role(messages, "system", "user")

        if not self.supports_adjacent_same_role_messages:
            messages = self._collapse_adjacent_same_role_messages(messages)
            if not self.supports_message_parts:
                # Merging stores the combined text in `parts` and drops `content`. The consumer cannot read
                # parts, so fold the merged text back into `content`; otherwise the template would render nothing.
                messages = self._collapse_message_parts(messages)

        return messages

    def _get_chat_template(self, tools: Optional[List] = None) -> Optional[str]:
        """
        Return the chat template (jinja2 format) from HF or explicitly set by the user, without fallbacks.

        :param tools: tool definitions, forwarded to the template resolution (an empty list is treated as None).
        :return: the template, or None if there is none (always None with a Mistral tokenizer).
        """
        tools = tools or None  # ensure we replace [] with None for good behaviour of vllm/transformers method calls

        if self.chat_template_override:
            # When using a Mistral tokenizer, the chat template cannot be overridden. It used to be a warning
            # but it is an error since https://github.com/vllm-project/vllm/pull/26358.
            if isinstance(self.tokenizer, MistralTokenizer):
                logger.warning("Custom chat template cannot be used with mistral tokenizer")
                return None
            logger.info("Using user's custom chat template.")
            return self.chat_template_override

        hf_chat_template = None

        ## If it is a Mistral Tokenizer, the chat_template is never used by vllm internals cf https://docs.vllm.ai/en/latest/api/vllm/transformers_utils/tokenizers/mistral.html
        if isinstance(self.tokenizer, MistralTokenizer):
            return None

        if self.vllm_config is not None:
            # Resolve the template the same way vLLM does (it may pick a different template than the tokenizer's default).
            from vllm.entrypoints.chat_utils import resolve_hf_chat_template
            hf_chat_template = resolve_hf_chat_template(
                tokenizer=self.tokenizer,
                chat_template=None,
                tools=tools,
                model_config=self.vllm_config.model_config)

        if hf_chat_template is None and self.tokenizer.chat_template:
            hf_chat_template = self.tokenizer.get_chat_template(None, tools)

        if hf_chat_template:
            logger.info("Using chat template from HF")
        else:
            logger.info("No chat template found on HF or provided by user")

        return hf_chat_template

    def get_chat_template(self, tools: Optional[List] = None) -> Optional[str]:
        """
        Return the chat template (jinja2 format) used to render the prompt.

        Legacy handling modes (Dolly, Falcon, MPT) always use a built-in template. Otherwise this is the user
        override or the HF template, or None if neither exists.

        :param tools: tool definitions, used to resolve tool-aware templates.
        """

        # Fallbacks
        if self.hf_handling_mode == "TEXT_GENERATION_DOLLY":
            logger.info("Using Dolly template")
            return MPT_DOLLY_JINJA_TEMPLATE
        if self.hf_handling_mode == "TEXT_GENERATION_FALCON":
            logger.info("Using Falcon template")
            return FALCON_JINJA_TEMPLATE
        if self.hf_handling_mode == "TEXT_GENERATION_MPT":
            logger.info("Using MPT template")
            return MPT_DOLLY_JINJA_TEMPLATE

        return self._get_chat_template(tools)

    def render(self, messages: List[ChatMessage]) -> str:
        """
        Rearrange the messages and render them into a prompt string, with the generation prompt appended.

        This method should become legacy once we move to vLLM's OpenAI API.
        """
        return self.apply_chat_template(
            self.rearrange_messages(messages),
            chat_template=self.get_chat_template(),
        )

    def _sniff_adjacent_messages_same_role_support(self):
        """
        :return: True if the model supports having multiple messages of the same role next to each other.
                 Some models require a strict order of messages: system/user/assistant/user/assistant/...
        :rtype: Boolean
        :raise Exception: if the chat template cannot render a single user message (sanity check).
        """
        self._tokenizer_chat_template_sanity_check()

        try:
            prompt = self.apply_chat_template(
                [
                    {"role": "user", "content": "hello"},
                    {"role": "user", "content": "here is my question"},
                    {"role": "assistant", "content": "hi"},
                    {"role": "assistant", "content": "here is my answer"},
                ],
                chat_template=self.get_chat_template(),
            )
            # Templates that reject this layout either raise or silently drop some of the messages.
            return ("here is my question" in prompt) and ("here is my answer" in prompt) and ("hello" in prompt)
        except Exception:
            return False

    def _sniff_system_prompt_support(self):
        """
        :return: The type of system prompt supported by the model.
        :rtype: SystemPromptSupport
        """
        if self._sniff_multiple_system_prompt_support():
            return SystemPromptSupport.MULTI
        elif self._sniff_single_system_prompt_support():
            return SystemPromptSupport.SINGLE
        else:
            return SystemPromptSupport.NONE

    def _tokenizer_chat_template_sanity_check(self):
        """
        :raise Exception: If the model does not have a tokenizer with a chat template that we know how to use.
        """
        # Sanity check, should always succeed
        self.apply_chat_template(
            [{"role": "user", "content": "hello"}],
            chat_template=self.get_chat_template(),
        )

    def _sniff_single_system_prompt_support(self):
        """
        :return: True if the model supports a system prompt at the start of the messages.
        :raise Exception: if the chat template cannot render a single user message (sanity check).
        """
        self._tokenizer_chat_template_sanity_check()

        try:
            # Check if the chat template supports system prompt (e.g. llama ok, mistral NOK)
            return "be happy" in self.apply_chat_template(
                [
                    {"role": "system", "content": "be happy"},
                    {"role": "user", "content": "hello"},
                ],
                chat_template=self.get_chat_template(),
            )
        except Exception:
            return False

    def _sniff_multiple_system_prompt_support(self):
        """
        :return: True if the model supports more than one system prompt, interleaved with user messages.
        :raise Exception: if the chat template cannot render a single user message (sanity check).
        """
        self._tokenizer_chat_template_sanity_check()

        try:
            prompt = self.apply_chat_template(
                [
                    {"role": "system", "content": "be happy"},
                    {"role": "user", "content": "hello"},
                    {"role": "assistant", "content": "hi"},
                    {"role": "system", "content": "second one"},
                ],
                chat_template=self.get_chat_template(),
            )
            return ("be happy" in prompt) and ("second one" in prompt)
        except Exception:
            return False

    def _sniff_system_prompt_with_image_input_support(self):
        """
        :return: True if the model supports a system prompt with image input.
        :raise Exception: if the chat template cannot render a single user message (sanity check).
        """
        self._tokenizer_chat_template_sanity_check()

        try:
            prompt = self.apply_chat_template(
                [
                    {"role": "system", "content": "be happy"},
                    {"role": "user", "content": [
                        {"type": "text", "text": "describe the image"},
                        {"type": "image"},
                    ]},
                ],
                chat_template=self.get_chat_template(),
            )

            return "describe the image" in prompt and "be happy" in prompt
        except Exception:
            return False

    def _sniff_tool_support(self):
        """
        :return: True if the model's chat template renders tool definitions; False if it ignores them or fails on them.
        :raise Exception: if the chat template cannot render a single user message (sanity check).
        """
        self._tokenizer_chat_template_sanity_check()
        tool = {
            'type': 'function',
            'function': {
                'name': 'DKU_tool_sniffer',
                'description': 'A function that allows to know if tool calling is supported by the current model.',
                'parameters': {
                    'type': 'object',
                    'properties': {'x': {'type': 'integer', 'description': 'One number'}},
                    'required': ['x']
                }
            }
        }
        try:
            prompt = self.apply_chat_template(
                [
                    {"role": "user", "content": "hello world"},
                ],
                chat_template=self.get_chat_template(tools=[tool]),
                tools=[tool],
            )
        except Exception:
            # Like the other sniffers: a template that cannot render tools means "no tool support",
            # rather than making the whole renderer construction fail.
            logger.info("Chat template failed to render a tool definition; assuming tools are not supported.", exc_info=True)
            return False
        return "DKU_tool_sniffer" in prompt

    def apply_chat_template(self, *args, **kwargs) -> str:
        """
        Render a conversation into a prompt string with the tokenizer's `apply_chat_template`.

        `tokenize` is always forced to True and the result decoded, because some vLLM tokenizers
        (e.g. MistralTokenizer) ignore `tokenize=False`. `add_generation_prompt` is always forced to True.
        Any caller-provided values for these two arguments are overridden.
        """
        kwargs["tokenize"] = True
        kwargs["add_generation_prompt"] = True
        token_ids = self.tokenizer.apply_chat_template(*args, **kwargs)

        if isinstance(token_ids, str):
            # Defensive check:
            # 'tokenize=True' instructs the tokenizer to return a list of token ids
            # 'tokenize=False' instructs the tokenizer to return a string
            #
            # Transformers tokenizers are generally well-behaved and respect this contract. vLLM's tokenizers
            # are sometimes delegating to transformers tokenizers, but not always. In particular, some vLLM
            # tokenizers (e.g. MistralTokenizer) ignore the 'tokenize' argument and always return token ids.
            #
            # I have seen no case where 'tokenize=True' returns a string, but just in case, we handle this here.
            return token_ids

        return self.tokenizer.decode(token_ids)

    @staticmethod
    def _convert_message_role(messages: List[ChatMessage], from_role, to_role) -> List[ChatMessage]:
        """
        Return copies of the messages where every message with role `from_role` gets role `to_role`.

        Copy of src/main/java/com/dataiku/dip/llm/online/LLMChatMessageUtils.java::convertMessageRole
        """
        new_messages = []
        for message in messages:
            new_message = deepcopy(message)
            if new_message["role"] == from_role:
                new_message["role"] = to_role
            new_messages.append(new_message)
        return new_messages

    @staticmethod
    def _convert_extra_system_message_to_user(messages: List[ChatMessage]) -> List[ChatMessage]:
        """
        Return copies of the messages where system messages are kept only in the leading run of system
        messages; any system message after the first non-system message becomes a user message.

        Copy of src/main/java/com/dataiku/dip/llm/online/LLMChatMessageUtils.java::convertExtraSystemMessageToUser
        """
        new_messages = []
        in_leading_system_block = True
        for message in messages:
            new_message = deepcopy(message)

            if message["role"] != "system":
                in_leading_system_block = False
            if message["role"] == "system" and not in_leading_system_block:
                new_message["role"] = "user"

            new_messages.append(new_message)
        return new_messages

    @staticmethod
    def _collapse_adjacent_same_role_messages(messages: List[ChatMessage]) -> List[ChatMessage]:
        """
        Merge each run of adjacent messages sharing the same role into a single message.

        Adapted from src/main/java/com/dataiku/dip/llm/online/LLMChatMessageUtils.java::collapseAdjacentSameRoleMessages

        A merged message carries the combined content as a list of `parts` (text contents become
        `{"type": "TEXT", ...}` parts) and has no `content` key. Messages that are not merged are returned as copies.

        When a message has non-empty `parts`, they take precedence over its `content`. After `_collapse_message_parts`,
        `content` is only the joined text of the parts, so using both would duplicate the text. Using only `content`
        would drop non-text parts (e.g. a `content: None` message with image parts).
        """
        def _as_parts(msg):
            # Non-empty `parts` are authoritative; otherwise fall back to the (non-empty) text content.
            existing_parts = msg.get("parts")
            if existing_parts:
                return deepcopy(existing_parts)
            text = msg.get("content")
            return [{"type": "TEXT", "text": text}] if text else []

        new_messages = []
        current_message = None
        for message in messages:
            if current_message is None:
                current_message = deepcopy(message)
            elif current_message["role"] != message["role"]:
                new_messages.append(current_message)
                current_message = deepcopy(message)
            else:
                merged_parts = _as_parts(current_message) + _as_parts(message)
                current_message.pop("content", None)
                current_message["parts"] = merged_parts

        if current_message is not None:
            new_messages.append(current_message)

        return new_messages

    @staticmethod
    def _collapse_message_parts(messages: List[ChatMessage]) -> List[ChatMessage]:
        """
        Collapse text message parts into the message `content`.

        For each message with a non-None `parts` list, `content` is set to the newline-joined text of its parts
        (parts without text are skipped); `parts` is kept as is. Messages are copied, never mutated.

        @raise Exception if any non-text parts are in the messages.
        """
        PARTS_JOINING_SEPARATOR = "\n"
        new_messages = []
        for message in messages:
            new_message = deepcopy(message)
            if "parts" in message and message["parts"] is not None:
                text_only = all(part["type"] == "TEXT" for part in message["parts"])
                if not text_only:
                    raise Exception("Non-text message parts not supported for local HuggingFace models")
                new_message["content"] = PARTS_JOINING_SEPARATOR.join(part["text"] for part in message["parts"] if ("text" in part and part["text"] is not None))
            new_messages.append(new_message)
        return new_messages


def _has_system_message(messages: List[ChatMessage]) -> bool:
    return any(m["role"] == "system" for m in messages)


def _has_message_with_image_part(messages: List[ChatMessage]) -> bool:
    """Tell whether any message of the conversation carries at least one image part."""
    for msg in messages:
        if _has_image_part(msg):
            return True
    return False


def _has_image_part(message: ChatMessage) -> bool:
    parts = message.get("parts")
    return parts is not None and any(
        p["type"] in ["IMAGE_INLINE", "IMAGE_URI"]
        for p in parts
    )
