import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor

import torch

from typing import List
from typing import Dict
from typing import AsyncIterator

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

from dataiku.huggingface.chat_template import ChatTemplateRenderer
from dataiku.base.async_link import FatalException
from dataiku.huggingface.pipeline import QuantizationMode
from dataiku.huggingface.pipeline_batching import ModelPipelineBatching
from dataiku.huggingface.torch_utils import best_supported_dtype
from dataiku.huggingface.types import ProcessSinglePromptCommand, ProcessSinglePromptResponseTextFull, UsageData
from dataiku.huggingface.types import ProcessSinglePromptResponseText
from dataiku.huggingface.env_collector import extract_model_architecture, extract_weight_quantization_method

logger = logging.getLogger(__name__)


class ModelPipelineTextGeneration(ModelPipelineBatching[ProcessSinglePromptCommand, ProcessSinglePromptResponseText]):
    """Base class for the text generation pipelines to implement."""

    def __init__(self, model_path, model_kwargs, batch_size, *args, **kwargs):
        """
        Initialize the parent batching class, then build the pipeline.

        :param model_path: forwarded as-is to ``_initialize_pipeline``
        :param model_kwargs: forwarded as-is to ``_initialize_pipeline``
        :param batch_size: passed to the parent class constructor
        :param args: extra positional arguments forwarded to ``_initialize_pipeline``
        :param kwargs: extra keyword arguments forwarded to ``_initialize_pipeline``
        """
        # The parent constructor runs first, before the pipeline is built.
        super().__init__(batch_size=batch_size)
        self._initialize_pipeline(model_path, model_kwargs, *args, **kwargs)

    # default implementation, can be overridden in children classes
    def _initialize_pipeline(self, model_path, model_kwargs, hf_handling_mode):
        """
        Load the tokenizer and causal LM found at ``model_path`` and wrap them in a
        transformers "text-generation" pipeline stored in ``self.task``.

        Also creates ``self.chat_template_renderer``, sets ``self.used_engine`` and fills
        ``self.model_tracking_data`` with the task, adapter, architecture and weight
        quantization entries.

        :param model_path: passed to both ``from_pretrained`` calls
        :param model_kwargs: dict of extra keyword arguments for the model loader, or None
        :param hf_handling_mode: passed to ``ChatTemplateRenderer``
        """
        logger.info("CUDA available: {cuda_available}".format(cuda_available=torch.cuda.is_available()))
        load_kwargs = model_kwargs if model_kwargs is not None else {}

        # The tokenizer only receives the trust_remote_code flag (a missing/falsy value
        # becomes False); the model loader receives every entry of load_kwargs.
        trust_remote_code = load_kwargs.get("trust_remote_code") or False
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", **load_kwargs)

        self.task = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
        self._fixup_tokenizer_if_needed(self.task)
        self.chat_template_renderer = ChatTemplateRenderer(self.task.tokenizer, hf_handling_mode)

        self.used_engine = "transformers"
        model_config = model.config
        self.model_tracking_data["task"] = "text-generation"
        self.model_tracking_data["adapter"] = "none"
        self.model_tracking_data["model_architecture"] = extract_model_architecture(model_config)
        self.model_tracking_data["weights_quantization"] = extract_weight_quantization_method(model_config)

    async def initialize_model(self):
        prompt = "Explain in simple terms what Generative AI is and why prompts matter for it"
        request = {'query': {'messages': [{'role': 'user', 'content': prompt}]},
                   'settings': {'topP': 1.0, 'topK': 50, 'temperature': 1.0, 'maxOutputTokens': 256}}
        logger.info("Testing engine with basic prompt: {prompt}".format(prompt=prompt))
        inputs = self._get_inputs([request])
        params = self._get_params(request)
        params['batch_size'] = len(inputs)
        with ThreadPoolExecutor(1) as executor:
            result = await asyncio.get_running_loop().run_in_executor(executor, self._run_inference, inputs, params)
            if len(result) == 0:
                raise FatalException("Something went wrong at initialization. Engine did not return any result for basic prompt: {prompt}".format(prompt=prompt))
            logger.info("Test prompt executed successfully: {result}".format(result=result))

    # default implementation, can be overridden in children classes
    # valid only when chat_template_renderer is defined
    def _get_inputs(self, requests: List[ProcessSinglePromptCommand]) -> List[str]:
        input_texts = [self.chat_template_renderer.render(request["query"]["messages"]) for request in requests]
        if logging.DEBUG >= logging.root.level:
            logging.debug("Reformatted inputs to: {input_texts}".format(input_texts=input_texts))
        return input_texts

    # default implementation for transformers models, can be overridden in children classes
    # valid only when task is defined
    def _run_inference(self, input_texts: List[str], tf_params: Dict) -> List:
        responses = self.task(input_texts, **tf_params)

        if logging.DEBUG >= logging.root.level:
            logging.debug("Response received from HF model: \n {}".format(responses))

        return responses

    def _get_params(self, request: ProcessSinglePromptCommand) -> Dict:
        params = request.get("settings", {})
        kwargs = {
            "return_full_text": False,
            "max_new_tokens": params.get("maxOutputTokens", 256)
        }
        if "temperature" in params:
            kwargs["temperature"] = params["temperature"]
        if "topK" in params:
            kwargs["top_k"] = params["topK"]
        if "topP" in params:
            kwargs["top_p"] = params["topP"]

        # Decide whether the provided params require sampling to be forced greedy
        if ("temperature" in kwargs and kwargs["temperature"] == 0.0) or \
                ("top_k" in kwargs and kwargs["top_k"] < 1) or \
                ("top_p" in kwargs and kwargs["top_p"] == 0.0):
            logging.info("Forcing greedy decoding")
            kwargs["do_sample"] = False  # temperature, topK, topP are ignored, and we force greedy decoding
            kwargs.pop("temperature", None)  # remove temperature, because some temperature values trigger exceptions even though temperature shouldn't affect greedy decoding
            kwargs.pop("top_k", None)  # remove top_k and top_p, because HF gives a warning when they are set with do_sample=False
            kwargs.pop("top_p", None)
        else:
            kwargs["do_sample"] = True
        return kwargs

    async def run_single_async(self, request: ProcessSinglePromptCommand) -> AsyncIterator[ProcessSinglePromptResponseText]:
        """
        Re-yield the parent implementation's responses, reshaped when streaming is requested.

        When ``request["stream"]`` is truthy, each parent response is emitted as two items:
        a ``chunk`` carrying the text, then a ``footer`` carrying a copy of the usage data.
        Otherwise the parent response is yielded unchanged.

        :param request: command forwarded to the parent implementation
        """
        async for resp in super().run_single_async(request):
            if not request.get("stream"):
                yield resp
                continue

            text_chunk = {"chunk": {"text": resp["text"]}}
            yield text_chunk
            usage_footer = {"footer": {"usage": {**resp["usage"]}}}
            yield usage_footer

    def _parse_response(self, response, request: ProcessSinglePromptCommand) -> ProcessSinglePromptResponseTextFull:
        # HF doc:
        # Returns one of the following dictionaries (cannot return a combination of both generated_text and generated_token_ids):
        # generated_text (str, present when return_text=True) — The generated text.

        output_text = response[0]["generated_text"]
        input_text = self.chat_template_renderer.render(request["query"]["messages"])

        usage: UsageData = {
            "promptTokens": len(self.task.tokenizer.encode(input_text)),
            "completionTokens": len(self.task.tokenizer.encode(output_text)),
        }

        return { "text": output_text, "usage": usage }

    @staticmethod
    def _fixup_tokenizer_if_needed(pipeline):
        # Fix an issue in transformers: "Pipeline with tokenizer without pad_token cannot do batching"
        # pad_token property is used to set the pad_token_id used during batching see HF codebase:
        # https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/tokenization_utils_base.py#L1210

        if pipeline.tokenizer.pad_token_id is None:
            try:  # [sc-206858] fetch the EOS token id from the tokenizer config
                eos_token_id = pipeline.tokenizer.eos_token_id
            except Exception as e:
                logger.warning(f"Accessing tokenizer.eos_token_id failed with error: {e}")
                eos_token_id = None

            if eos_token_id is None:  # last chance, fetch from the model config
                eos_token_id = pipeline.model.config.eos_token_id

            # when fetched from the model config, can be a list
            # take the first that comes - has no impact for inference
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]

            pipeline.tokenizer.pad_token_id = eos_token_id

    @staticmethod
    def _get_quantization_config(quantization_mode):
        """
        Build the quantization config matching the requested mode.

        Returns None both for ``QuantizationMode.NONE`` and for an unrecognized mode
        (the latter also logs a warning).

        :param quantization_mode: QuantizationMode
        :rtype: BitsAndBytesConfig
        """
        if quantization_mode == QuantizationMode.Q_8BIT:
            logger.info("Using 8 bit quantization")
            return BitsAndBytesConfig(load_in_8bit=True)

        if quantization_mode == QuantizationMode.Q_4BIT:
            logger.info("Using 4 bit quantization")
            compute_dtype = best_supported_dtype()
            return BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=compute_dtype)

        # Remaining modes never produce a config; they only differ in what gets logged.
        if quantization_mode == QuantizationMode.NONE:
            logger.info("Using no quantization")
        else:
            logger.warning("Unknown quantization mode: '{quantization_mode}'. Using no quantization.".format(quantization_mode=quantization_mode))
        return None

    @staticmethod
    def set_dtype_and_quantization_kwargs(model_kwargs, quantization_mode):
        """
        Set either `quantization_config` or `torch_dtype` in `model_kwargs`.

        When the mode yields a quantization config, it is stored under `quantization_config`;
        otherwise `torch_dtype` is set to the result of `best_supported_dtype()`.
        A dict passed in is modified in place; the (possibly newly created) dict is also
        returned, which covers the case where `model_kwargs` was None.

        :param model_kwargs: dict to update, or None to start from an empty dict
        :param quantization_mode: QuantizationMode
        :return: the updated kwargs dict
        """
        updated_kwargs = {} if model_kwargs is None else model_kwargs
        quantization_config = ModelPipelineTextGeneration._get_quantization_config(quantization_mode)
        if quantization_config is None:
            updated_kwargs["torch_dtype"] = best_supported_dtype()
        else:
            updated_kwargs["quantization_config"] = quantization_config
        return updated_kwargs
