import logging
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from dataiku.huggingface.pipeline_text_gen import ModelPipelineTextGeneration
from dataiku.huggingface.chat_template import ChatTemplateRenderer
from dataiku.huggingface.env_collector import extract_model_architecture

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ModelPipelineTextGenerationLora(ModelPipelineTextGeneration):
    """Text-generation pipeline serving a base causal LM with a PEFT (LoRA) adapter applied.

    Only ``_initialize_pipeline`` is overridden here; all other behavior is inherited
    from ``ModelPipelineTextGeneration``.
    """

    def _initialize_pipeline(self, model_name_or_path, model_kwargs, hf_handling_mode, base_model_name_or_path):
        """
        Load the base model, apply the adapter to it, and build a ``text-generation`` pipeline.

        Sets ``self.model``, ``self.tokenizer``, ``self.task`` and ``self.chat_template_renderer``,
        and records engine/adapter/architecture info in ``self.model_tracking_data``.

        :param model_name_or_path: adapter location, passed to ``PeftModel.from_pretrained``
        :type model_name_or_path: str
        :param model_kwargs: extra keyword arguments for ``AutoModelForCausalLM.from_pretrained``,
            or ``None``. The caller's dict is never mutated. A ``"device_map"`` entry, if present,
            replaces the default ``"auto"`` for the base model, the adapter and the pipeline.
        :type model_kwargs: Optional[typing.Dict[str, typing.Any]]
        :param hf_handling_mode: forwarded to ``ChatTemplateRenderer``
        :type hf_handling_mode: str
        :param base_model_name_or_path: base model to load; the tokenizer is also loaded from here
        :type base_model_name_or_path: str
        """
        from peft import PeftModel  # Import inside method, so it still works even if `peft` is not in the code-env

        logger.info("CUDA available: %s", torch.cuda.is_available())

        # Work on a private copy: `transformers.pipeline()` mutates the `model_kwargs` dict it
        # receives (it inserts "device_map" and pops "use_auth_token"). Without a copy, those
        # changes would leak into the caller's dict and break any later reuse of it.
        kwargs = dict(model_kwargs) if model_kwargs else {}
        # Honour a caller-supplied device_map rather than failing with
        # "got multiple values for keyword argument 'device_map'" below.
        device_map = kwargs.pop("device_map", "auto")

        logger.info("Loading base model")
        base_model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, device_map=device_map, **kwargs)

        logger.info("Loading adapter model")
        self.model = PeftModel.from_pretrained(base_model, model_name_or_path, device_map=device_map)
        # The tokenizer is taken from the base model path, not from the adapter path.
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_model_name_or_path,
            trust_remote_code=kwargs.get("trust_remote_code") or False,
        )

        # `kwargs` is our own copy, so any mutation done by `pipeline()` stays local.
        self.task = pipeline(
            'text-generation',
            model=self.model,
            tokenizer=self.tokenizer,
            device_map=device_map,
            model_kwargs=kwargs,
        )

        self._fixup_tokenizer_if_needed(self.task)
        self.chat_template_renderer = ChatTemplateRenderer(self.task.tokenizer, hf_handling_mode)

        self.model_tracking_data["task"] = "text-generation"
        self.model_tracking_data["used_engine"] = "transformers"
        self.model_tracking_data["adapter"] = "lora"
        # Architecture is read from the base model's config, not from the adapter wrapper.
        self.model_tracking_data["model_architecture"] = extract_model_architecture(base_model.config)
