from transformers import PretrainedConfig
from typing import Optional, Type

from dataiku.huggingface.types import ModelSettings
from .model_sniffing_utils import _is_architecture
from .model_sniffing_utils import _is_unquantized_expected_model_or_fine_tuned
from .model_sniffing_utils import  _has_mistral_format
from .model_params_loader import ModelParamsLoader
from .model_params_loader import Llama318BModelParamsLoader
from .model_params_loader import Llama323BModelParamsLoader
from .model_params_loader import Llama3211BVisionModelParamsLoader
from .model_params_loader import Llama3290BVisionModelParamsLoader
from .model_params_loader import Llava16Mistral7BVisionModelParamsLoader
from .model_params_loader import Mistral7BModelParamsLoader
from .model_params_loader import MistralNemo12BModelParamsLoader
from .model_params_loader import MistralFormatParamsLoader
from .model_params_loader import Phi35VisionModelParamsLoader


def get_model_params_loader(
    model_to_load: str,
    model_config: PretrainedConfig,
    model_settings: ModelSettings,
    with_image_input: bool,
    use_dss_model_cache: bool,
    lora_path: Optional[str],
) -> ModelParamsLoader:
    """Build the parameter loader suited to the model being loaded.

    The loader class is chosen by ``_get_model_params_loader_class`` from
    ``model_to_load``, ``model_config`` and ``use_dss_model_cache``. It is then
    instantiated with ``model_config``, ``model_settings``,
    ``with_image_input`` and ``lora_path``.

    :param model_to_load: identifier of the model to load; only used for class
        selection here.
    :param model_config: Hugging Face config of the model.
    :param model_settings: settings forwarded to the loader constructor.
    :param with_image_input: forwarded to the loader constructor.
    :param use_dss_model_cache: only used for class selection here.
    :param lora_path: optional LoRA path forwarded to the loader constructor.
    :return: an instance of the selected ``ModelParamsLoader`` (sub)class.
    """
    loader_cls = _get_model_params_loader_class(
        model_to_load, model_config, use_dss_model_cache
    )
    loader = loader_cls(model_config, model_settings, with_image_input, lora_path)
    return loader


def _get_model_params_loader_class(
    model_to_load: str,
    model_config: PretrainedConfig,
    use_dss_model_cache: bool,
) -> Type[ModelParamsLoader]:
    """Return the ``ModelParamsLoader`` class to use for the given model.

    Candidates are checked in a fixed order and the first match wins:

    1. Known unquantized models, or fine-tunes of them. Each is tested with
       ``_is_unquantized_expected_model_or_fine_tuned`` using a display name,
       an architecture name and five expected integer config values.
    2. Any config whose architecture is ``LlavaNextForConditionalGeneration``.
    3. Models for which ``_has_mistral_format`` reports the Mistral format.
    4. Otherwise, the generic ``ModelParamsLoader``.

    :param model_to_load: identifier of the model; only used for the
        Mistral-format check.
    :param model_config: Hugging Face config of the model.
    :param use_dss_model_cache: forwarded to the Mistral-format check.
    :return: the selected loader class (not an instance).
    """
    # Ordered table: (display name, architecture, expected ints, loader class).
    # NOTE(review): the five integers look like max position embeddings,
    # hidden layers, attention heads, hidden size and intermediate size —
    # confirm against `_is_unquantized_expected_model_or_fine_tuned`.
    # Order matters: entries are probed top to bottom, as in the original chain.
    expected_models = (
        ("Mistral 7B", "MistralForCausalLM",
         (32768, 32, 32, 4096, 14336), Mistral7BModelParamsLoader),
        ("Llama 3.1 8B", "LlamaForCausalLM",
         (131072, 32, 32, 4096, 14336), Llama318BModelParamsLoader),
        ("Llama 3.2 3B", "LlamaForCausalLM",
         (131072, 28, 24, 3072, 8192), Llama323BModelParamsLoader),
        ("Llama 3.2 11B Vision", "MllamaForConditionalGeneration",
         (131072, 40, 32, 4096, 14336), Llama3211BVisionModelParamsLoader),
        ("Llama 3.2 90B Vision", "MllamaForConditionalGeneration",
         (131072, 100, 64, 8192, 28672), Llama3290BVisionModelParamsLoader),
        ("Mistral Nemo 12B", "MistralForCausalLM",
         (131072, 40, 32, 5120, 14336), MistralNemo12BModelParamsLoader),
        ("Phi 3.5 vision", "Phi3VForCausalLM",
         (131072, 32, 32, 3072, 8192), Phi35VisionModelParamsLoader),
    )

    for display_name, architecture, expected_values, loader_cls in expected_models:
        # Same positional arguments as the explicit per-model calls.
        if _is_unquantized_expected_model_or_fine_tuned(
            model_config, display_name, architecture, *expected_values
        ):
            return loader_cls

    # Architecture-only match, with no size check.
    if _is_architecture(model_config, "LlavaNextForConditionalGeneration"):
        return Llava16Mistral7BVisionModelParamsLoader

    if _has_mistral_format(model_to_load, use_dss_model_cache):
        return MistralFormatParamsLoader

    return ModelParamsLoader
