
import logging
import textwrap

from transformers import PretrainedConfig

logger = logging.getLogger(__name__)


def _read_config_field(section, field: str):
    """Return ``field`` from a config section, or ``None`` when it is absent.

    The section is read attribute-style first (config objects) and then
    through ``.get`` (dict-like sections), so both shapes are supported.
    """
    value = getattr(section, field, None)
    if value is None:
        getter = getattr(section, "get", None)
        if callable(getter):
            value = getter(field)
    return value


def _is_unquantized_expected_model_or_fine_tuned(
        model_config: PretrainedConfig,
        expected_model_name: str,
        expected_model_architecture: str,
        expected_max_position_embeddings: int,
        expected_num_hidden_layers: int,
        expected_num_attention_heads: int,
        expected_hidden_size: int,
        expected_intermediate_size: int
) -> bool:
    """Tell whether ``model_config`` matches an expected un-quantized model.

    The config matches when its first declared architecture equals
    ``expected_model_architecture``, it carries no ``quantization_config``,
    and its five size attributes all equal the expected values. The size
    attributes are read from ``text_config`` when that attribute is set,
    otherwise from the top level of the config.

    Args:
        model_config: Configuration of the model being inspected.
        expected_model_name: Human-readable name, used in log messages only.
        expected_model_architecture: Architecture name the model must declare.
        expected_max_position_embeddings: Expected ``max_position_embeddings``.
        expected_num_hidden_layers: Expected ``num_hidden_layers``.
        expected_num_attention_heads: Expected ``num_attention_heads``.
        expected_hidden_size: Expected ``hidden_size``.
        expected_intermediate_size: Expected ``intermediate_size``.

    Returns:
        ``True`` when every check passes, ``False`` otherwise (including when
        the config declares no architecture or lacks a size attribute).
    """
    logger.info(f"Checking if running un-quantized {expected_model_name} or fine-tuned variant")

    # `architectures` may be missing, None or empty; treat that as "no match"
    # instead of failing on the [0] lookup.
    architectures = getattr(model_config, "architectures", None)
    if not architectures:
        logger.info("Model config does not declare any architecture")
        return False

    architecture = architectures[0]
    logger.info(f"Model architecture: {architecture}")
    if architecture != expected_model_architecture:
        logger.info(f"Model does not use architecture {expected_model_architecture}")
        return False

    # A quantized model is never considered a match by this helper.
    if getattr(model_config, "quantization_config", None) is not None:
        logger.info("Model is quantized")
        return False

    field_names = (
        "max_position_embeddings",
        "num_hidden_layers",
        "num_attention_heads",
        "hidden_size",
        "intermediate_size",
    )
    expected = {
        "max_position_embeddings": expected_max_position_embeddings,
        "num_hidden_layers": expected_num_hidden_layers,
        "num_attention_heads": expected_num_attention_heads,
        "hidden_size": expected_hidden_size,
        "intermediate_size": expected_intermediate_size,
    }

    # Pick the section of the config that holds the size attributes.
    text_config = getattr(model_config, "text_config", None)
    if text_config is not None:
        logger.info("Model has a dedicated config for text generation (likely a text/vision model)")
        # text_config can be a dict or a nested config object; support both.
        actual = {name: _read_config_field(text_config, name) for name in field_names}
    elif all(hasattr(model_config, name) for name in field_names):
        logger.info("Model as a top level config")
        actual = {name: getattr(model_config, name) for name in field_names}
    else:
        actual = dict.fromkeys(field_names)

    # Attributes extraction done, let's match them
    if any(actual[name] is None for name in field_names):
        logger.info(f"Model is probably not a {expected_model_name} or alike because of missing attributes in the config")
        return False

    max_position_embeddings = actual["max_position_embeddings"]
    num_hidden_layers = actual["num_hidden_layers"]
    num_attention_heads = actual["num_attention_heads"]
    hidden_size = actual["hidden_size"]
    intermediate_size = actual["intermediate_size"]

    logger.info(textwrap.dedent(
        f"""
        Max position embeddings: {max_position_embeddings}
        Num hidden layers: {num_hidden_layers}
        Num attention heads: {num_attention_heads}
        Hidden size: {hidden_size}
        Intermediate size: {intermediate_size}
        """
    )
    )

    if all(actual[name] == expected[name] for name in field_names):
        logger.info(f"Model is {expected_model_name} or alike (un-quantized or fine-tuned variant)")
        return True

    logger.info(f"Model is not a {expected_model_name} or alike")
    return False


def _is_architecture(model_config: PretrainedConfig, expected_architecture: str) -> bool:
    """Tell whether the model's first declared architecture is the expected one.

    Args:
        model_config: Configuration of the model being inspected.
        expected_architecture: Architecture name to compare against.

    Returns:
        ``True`` when ``model_config.architectures[0]`` equals
        ``expected_architecture``; ``False`` otherwise, including when the
        config declares no architecture at all.
    """
    logger.info(f"Checking if running model with architecture {expected_architecture}")

    # `architectures` may be missing, None or empty; treat that as "no match"
    # instead of failing on the [0] lookup.
    architectures = getattr(model_config, "architectures", None)
    if not architectures:
        logger.info("Model config does not declare any architecture")
        return False

    architecture = architectures[0]
    logger.info(f"Model architecture: {architecture}")
    if architecture != expected_architecture:
        logger.info(f"Model does not use {expected_architecture} architecture")
        return False
    return True
