import asyncio
import json
import logging
import random
import os

from abc import ABC, abstractmethod
from enum import Enum

from typing import AsyncIterator
from typing import Generic
from typing_extensions import TypedDict, NotRequired

from dataiku.huggingface.types import ProcessSinglePromptCommand
from dataiku.huggingface.types import ProcessSinglePromptResponseText
from dataiku.huggingface.types import ProcessSinglePromptResponseTextFull
from dataiku.huggingface.types import ProcessSinglePromptResponseZeroShotClassification
from dataiku.huggingface.types import ProcessSingleEmbeddingCommand
from dataiku.huggingface.types import ProcessSingleEmbeddingResponse
from dataiku.huggingface.types import SingleCommand
from dataiku.huggingface.types import SingleResponse

# Do not import huggingface or transformers at module level here: doing so would break hf_transfer and
# transformers offline mode, which must be configured before those libraries are imported.
# See https://app.shortcut.com/dataiku/story/190851/make-sure-transformers-offline-is-set-early-enough


# Module-level logger named after this module; used by the pipelines and factories below.
logger = logging.getLogger(__name__)


# Keep in sync with java/com/dataiku/dip/connections/HuggingFaceLocalConnection::QuantizationMode
class QuantizationMode(Enum):
    """Quantization to apply to model weights at load time.

    In this module, members are looked up by *name* from the ``quantizationMode`` model
    setting (``QuantizationMode[...]``) for text generation models, so member names must
    stay identical to the Java enum they mirror.
    """
    # Values appear to be the weight bit-width, with 0 meaning "not quantized".
    Q_4BIT = 4
    Q_8BIT = 8
    NONE = 0

class ModelTrackingData(TypedDict):
    """Optional, non-sensitive tracking data a pipeline may report about the loaded model.

    Every key is optional. Values are plain strings; never put sensitive information here
    (not even model names).
    """
    # Task (text gen, text embedding, etc)
    task: NotRequired[str]
    # Model architecture
    model_architecture: NotRequired[str]
    # Whether a LoRA adapter is used (stored as a string)
    adapter: NotRequired[str]
    # Engine configured in HF connection properties (e.g. "AUTO", "VLLM", "TRANSFORMERS")
    requested_engine: NotRequired[str]
    # Actual engine used at runtime ("mock" for mock pipelines)
    used_engine: NotRequired[str]
    # Quantization method of model weights
    weights_quantization: NotRequired[str]
    # Quantization method configured in HF connection UI (a QuantizationMode member name)
    requested_quantization: NotRequired[str]
    # Model source (e.g. HF Hub, DSS model cache)
    model_source: NotRequired[str]

class ModelPipeline(ABC, Generic[SingleCommand, SingleResponse]):
    """Abstract base class implemented by every pipeline, real or mock.

    Subclasses load their model in ``__init__``, perform any asynchronous warm-up
    (e.g. running a test prompt) in ``initialize_model``, and serve requests through
    ``run_single_async``.
    """

    def __init__(self, *args, **kwargs):
        """Load the model into memory, and perform any required setup.

        The base implementation ignores its arguments; it only creates the tracking data.
        """
        # Optional tracking data produced by the pipeline (collected via WT1).
        # Never store sensitive information here -- not even model names.
        self.model_tracking_data: ModelTrackingData = {}

    @abstractmethod
    async def run_single_async(self, request: SingleCommand) -> AsyncIterator[SingleResponse]:
        """Process a single request, yielding one or more responses."""
        raise NotImplementedError
        # Unreachable, but the mere presence of `yield` makes this an async generator
        # function, i.e. the same kind of callable subclasses implement.
        yield

    @abstractmethod
    async def initialize_model(self):
        """Asynchronously prepare the loaded model before it serves requests."""


class MockModelPipelineTextGenerationOrSummarization(ModelPipeline[ProcessSinglePromptCommand, ProcessSinglePromptResponseText]):
    """Mock pipeline for text generation or text summarization.

    No model is loaded: each request gets a canned answer after random delays. The answer
    contains the index of the request (in arrival order) and the first 20 characters of the
    prompt, so callers can match responses to requests.
    """

    def __init__(self):
        logger.info("Instantiating mock pipeline for text generation or summarization")
        super().__init__()
        # Number of requests received so far; the next request gets this value as its index.
        self.request_counter = 0
        self.model_tracking_data["used_engine"] = "mock"

    async def initialize_model(self):
        logger.info("Running test prompt for mocked text generation or summarization pipeline")
        await asyncio.sleep(random.uniform(0, 2))
        logger.info("Test prompt for mocked text generation or summarization pipeline ran successfully")

    async def run_single_async(self, request: ProcessSinglePromptCommand) -> AsyncIterator[ProcessSinglePromptResponseText]:
        """Answer one prompt, either streamed in chunks or as a single full response.

        When ``request["stream"]`` is truthy, yields several ``{"chunk": {"text": ...}}`` items
        followed by one ``{"footer": usage}`` item; otherwise yields a single
        ``{"text": ..., "usage": usage}`` item.
        """
        # Take the index before the first await so that concurrently running requests
        # on the same event loop always get distinct indices.
        request_index = self.request_counter
        self.request_counter += 1

        await asyncio.sleep(random.uniform(0, 2))
        prompt_start = request["prompt"][:20]
        if request.get("stream"):
            yield {"chunk": {"text": "hello! "}}
            await asyncio.sleep(random.uniform(0, 1))
            yield {"chunk": {"text": "it's a me"}}
            await asyncio.sleep(random.uniform(0, 2))
            yield {"chunk": {"text": " from the other side"}}
            await asyncio.sleep(random.uniform(0, 2))
            yield {"chunk": {"text": f" (request #{request_index})"}}
            await asyncio.sleep(random.uniform(0, 2))
            usage = {"promptTokens": 13, "completionTokens": 22}
            yield {"chunk": {"text": f" [{prompt_start}]"}}
            # The streamed usage is reported once, in the final footer item.
            yield {"footer": usage}
        else:
            usage = {"promptTokens": 7, "completionTokens": 19}
            await asyncio.sleep(random.uniform(0, 2))
            yield {"text": f"salut c'est moi de l'autre cote (request #{request_index}) [{prompt_start}]",
                   "usage": usage}


class MockModelPipelineTextClassification(ModelPipeline[ProcessSinglePromptCommand, ProcessSinglePromptResponseTextFull]):
    """Mock pipeline for model-provided text classification.

    After a random delay, returns random scores over a fixed set of three labels,
    serialized as JSON in the ``text`` field of the response.
    """

    def __init__(self):
        logger.info("Instantiating mock pipeline for text classification")
        super().__init__()
        self.model_tracking_data["used_engine"] = "mock"

    async def initialize_model(self):
        logger.info("Running test prompt for mocked text classification pipeline")
        await asyncio.sleep(random.uniform(0, 2))
        logger.info("Test prompt for mocked text classification pipeline ran successfully")

    async def run_single_async(self, request: ProcessSinglePromptCommand) -> AsyncIterator[ProcessSinglePromptResponseTextFull]:
        await asyncio.sleep(random.uniform(0, 2))
        candidate_labels = ["coucou", "hola", "hello"]
        output_mode = request["settings"]["textClassificationOutputMode"]

        if output_mode != "ALL":
            # A single randomly chosen label, with a random score.
            picked_label = random.choice(candidate_labels)
            picked_score = random.random()
            yield {"text": json.dumps({"label": picked_label, "score": picked_score})}
            return

        # "ALL" mode: one random score per label, in label order.
        scores_by_label = {}
        for candidate in candidate_labels:
            scores_by_label[candidate] = random.random()
        yield {"text": json.dumps(scores_by_label)}

class MockModelPipelineZeroShotClassification(ModelPipeline[ProcessSinglePromptCommand, ProcessSinglePromptResponseZeroShotClassification]):
    """Mock pipeline for zero-shot classification over user-provided labels.

    Each label from ``request["settings"]["classLabels"]`` gets a random score; labels and
    scores are returned as parallel lists ordered by increasing (score, label).
    """

    def __init__(self):
        logger.info("Instantiating mock pipeline for zero shot classification")
        super().__init__()
        self.model_tracking_data["used_engine"] = "mock"

    async def initialize_model(self):
        logger.info("Running test prompt for mocked zero shot classification pipeline")
        await asyncio.sleep(random.uniform(0, 2))
        logger.info("Test prompt for mocked zero shot classification pipeline ran successfully")

    async def run_single_async(self, request: ProcessSinglePromptCommand) -> AsyncIterator[ProcessSinglePromptResponseZeroShotClassification]:
        await asyncio.sleep(random.uniform(0, 2))
        class_labels = request["settings"]["classLabels"]
        random_scores = [random.random() for _ in range(len(class_labels))]
        # Sorting (score, label) pairs orders by score, ties broken by label.
        ranked_pairs = sorted(zip(random_scores, class_labels))
        yield {
            "classification": {
                "labels": [lbl for _score, lbl in ranked_pairs],
                "scores": [score for score, _lbl in ranked_pairs],
            }
        }

class MockModelPipelineEmbedding(ModelPipeline[ProcessSingleEmbeddingCommand, ProcessSingleEmbeddingResponse]):
    """Mock pipeline for text and image embedding.

    Ignores the request content and returns a random 10-dimensional vector after a random delay.
    """

    def __init__(self):
        logger.info("Instantiating mock pipeline for embedding")
        super().__init__()
        self.model_tracking_data["used_engine"] = "mock"

    async def initialize_model(self):
        logger.info("Running test prompt for mocked embedding pipeline")
        await asyncio.sleep(random.uniform(0, 2))
        logger.info("Test prompt for mocked embedding pipeline ran successfully")

    async def run_single_async(self, request: ProcessSingleEmbeddingCommand) -> AsyncIterator[ProcessSingleEmbeddingResponse]:
        await asyncio.sleep(random.uniform(0, 2))
        dimension = 10
        vector = []
        for _ in range(dimension):
            vector.append(random.random())
        yield {"embedding": vector}


def create_mock_pipeline(hf_handling_mode):
    """Return the mock pipeline matching ``hf_handling_mode``; no model is loaded.

    :param hf_handling_mode: Hugging Face handling mode string.
    :raises Exception: if the handling mode has no mock implementation.
    """
    text_classification_modes = (
        "TEXT_CLASSIFICATION_SENTIMENT",
        "TEXT_CLASSIFICATION_EMOTIONS",
        "TEXT_CLASSIFICATION_TOXICITY",
        "TEXT_CLASSIFICATION_PROMPT_INJECTION",
        "TEXT_CLASSIFICATION_OTHER",
    )
    summarization_modes = ("SUMMARIZATION_GENERIC", "SUMMARIZATION_ROBERTA", "T5")
    embedding_modes = ("TEXT_EMBEDDING", "IMAGE_EMBEDDING")

    if hf_handling_mode == "ZSC_GENERIC":
        return MockModelPipelineZeroShotClassification()
    if hf_handling_mode in text_classification_modes:
        return MockModelPipelineTextClassification()
    if hf_handling_mode in summarization_modes or hf_handling_mode.startswith("TEXT_GENERATION_"):
        return MockModelPipelineTextGenerationOrSummarization()
    if hf_handling_mode in embedding_modes:
        return MockModelPipelineEmbedding()
    raise Exception(f"Unknown handling mode for mock pipeline (hf_handling_mode={hf_handling_mode})")


def _resolve_safetensors_model_kwargs(model_name_or_path, base_model_name_or_path, use_dss_model_cache, untrusted_model):
    """Check weight files on disk and return the initial ``model_kwargs`` used to load the model.

    Untrusted models, like the ones created using a Python fine-tuning recipe, must use safetensors weights.

    :return: ``{"use_safetensors": True}`` for an untrusted fully fine-tuned model;
        ``{"use_safetensors": False}`` for a DSS-cached model containing PyTorch ``.bin`` weights;
        ``None`` otherwise.
    :raises Exception: if an untrusted model (or adapter) has no safetensors weights.
    """
    model_kwargs = None
    if untrusted_model:
        files = os.listdir(model_name_or_path)
        if base_model_name_or_path is not None:
            # We are working with an adapter saved model. However, PEFT doesn't enforce safetensors and falls back to .bin.
            # See: https://github.com/huggingface/peft/blob/e6cd24c907565040ee1766a5735afe3d13a71164/src/peft/utils/save_and_load.py#L441
            # Therefore, we make sure that safetensors are there and therefore loaded by default with `PeftModel.from_pretrained`.
            if "adapter_model.safetensors" not in files:
                raise Exception("Failed to use model: adapter weights stored in the safetensors format were not found")
        else:
            # We are working with a fully fine-tuned saved model (python recipe). In that case, we always want to use safetensors.
            if "model.safetensors" not in files:
                raise Exception("Failed to use model: model weights stored in the safetensors format were not found")
            model_kwargs = {"use_safetensors": True}

    # If model_kwargs is already set, safetensors are being enforced for an untrusted model:
    # never downgrade to .bin weights in that case. Otherwise, a cached model containing
    # PyTorch .bin weights is loaded with use_safetensors=False.
    if use_dss_model_cache and model_kwargs is None:
        # For an adapter model, the weights to inspect are the base model's.
        weights_dir = base_model_name_or_path or model_name_or_path
        bin_weight_files = ("pytorch_model.bin", "pytorch_model.bin.index.json")
        if any(fname in bin_weight_files for fname in os.listdir(weights_dir)):
            model_kwargs = {"use_safetensors": False}
    return model_kwargs


def _record_requested_settings(pipeline, requested_engine, requested_quantization):
    """Store the requested engine and quantization in the pipeline's tracking data, then return the pipeline."""
    pipeline.model_tracking_data["requested_engine"] = requested_engine
    pipeline.model_tracking_data["requested_quantization"] = requested_quantization
    return pipeline


def create_model_pipeline(hf_handling_mode, hf_model_name, model_name_or_path, base_model_name_or_path, use_dss_model_cache, model_settings, batch_size, untrusted_model, supports_image_inputs):
    """Instantiate the appropriate pipeline class, for the given Hugging Face handling mode.

    Pipeline modules are imported lazily, only for the selected mode.

    :param hf_handling_mode: handling mode string (e.g. ``"TEXT_EMBEDDING"``, ``"TEXT_GENERATION_MISTRAL"``).
    :param hf_model_name: model name; used by the image embedding pipeline and in error messages.
    :param model_name_or_path: model (or adapter) to load.
    :param base_model_name_or_path: base model when ``model_name_or_path`` is an adapter, else None.
    :param use_dss_model_cache: whether the model is loaded from the DSS model cache.
    :param model_settings: settings dict; reads ``engine``, ``quantizationMode`` and ``trustRemoteCode``.
    :param batch_size: batch size passed to the pipeline.
    :param untrusted_model: if true, only safetensors weights may be loaded.
    :param supports_image_inputs: forwarded to the vLLM text generation pipeline.
    :return: the pipeline object (``initialize_model`` is not called here).
    :raises Exception: if the handling mode is unknown, or an untrusted model lacks safetensors weights.
    """
    model_kwargs = _resolve_safetensors_model_kwargs(model_name_or_path, base_model_name_or_path, use_dss_model_cache, untrusted_model)

    requested_engine = "AUTO"
    requested_quantization = "NONE"
    # Set quantization params and trust_remote_code if needed and handle text engine
    if hf_handling_mode.startswith("TEXT_GENERATION_"):
        from dataiku.huggingface.pipeline_text_gen import ModelPipelineTextGeneration
        requested_quantization = model_settings["quantizationMode"]
        model_kwargs = ModelPipelineTextGeneration.set_dtype_and_quantization_kwargs(model_kwargs, QuantizationMode[requested_quantization])
        if model_settings.get("trustRemoteCode"):
            model_kwargs['trust_remote_code'] = True

        from dataiku.huggingface.pipeline_text_gen_vllm import ModelPipelineTextGenerationVLLM
        can_use_vllm = ModelPipelineTextGenerationVLLM.supports_model(hf_handling_mode, model_name_or_path, base_model_name_or_path, model_settings)
        requested_engine = model_settings['engine']
        engine = requested_engine
        logger.info("Requested engine: %s", engine)
        if engine == "AUTO":
            engine = "VLLM" if can_use_vllm else "TRANSFORMERS"
        logger.info("Selected engine: %s", engine)
        if engine == "VLLM":
            pipeline = ModelPipelineTextGenerationVLLM(hf_handling_mode, model_name_or_path, base_model_name_or_path, model_settings, supports_image_inputs)
            return _record_requested_settings(pipeline, requested_engine, requested_quantization)

    # Settings of non-text-generation modes may not define an engine at all.
    if model_settings.get("engine") == "VLLM":
        logger.info("Requested engine 'VLLM' not supported for %s. Using 'TRANSFORMERS' instead", hf_handling_mode)

    if base_model_name_or_path is not None:
        from dataiku.huggingface.pipeline_text_gen_lora import ModelPipelineTextGenerationLora
        pipeline = ModelPipelineTextGenerationLora(model_name_or_path, model_kwargs, batch_size, hf_handling_mode, base_model_name_or_path)
        return _record_requested_settings(pipeline, requested_engine, requested_quantization)

    if hf_handling_mode == "ZSC_GENERIC":
        from dataiku.huggingface.pipeline_zero_shot_classif import ModelPipelineZeroShotClassif
        return ModelPipelineZeroShotClassif(model_name_or_path, model_kwargs, batch_size)
    elif hf_handling_mode in ["TEXT_CLASSIFICATION_SENTIMENT", "TEXT_CLASSIFICATION_EMOTIONS", "TEXT_CLASSIFICATION_TOXICITY", "TEXT_CLASSIFICATION_PROMPT_INJECTION", "TEXT_CLASSIFICATION_OTHER"]:
        from dataiku.huggingface.pipeline_text_classif import ModelPipelineTextClassif
        return ModelPipelineTextClassif(model_name_or_path, model_kwargs, batch_size)
    elif hf_handling_mode == "SUMMARIZATION_GENERIC":
        from dataiku.huggingface.pipeline_summarization_generic import ModelPipelineSummarizationGeneric
        return ModelPipelineSummarizationGeneric(model_name_or_path, model_kwargs, batch_size)
    elif hf_handling_mode == "SUMMARIZATION_ROBERTA":
        from dataiku.huggingface.pipeline_summarization_roberta import ModelPipelineSummarizationRoberta
        return ModelPipelineSummarizationRoberta(model_name_or_path, model_kwargs, batch_size)
    elif hf_handling_mode in ["TEXT_GENERATION_FALCON", "TEXT_GENERATION_LLAMA_2", "TEXT_GENERATION_LLAMA_GUARD", "TEXT_GENERATION_MISTRAL", "TEXT_GENERATION_ZEPHYR", "TEXT_GENERATION_DOLLY", "TEXT_GENERATION_GEMMA", "TEXT_GENERATION_PHI_3", "TEXT_GENERATION_AUTO"]:
        from dataiku.huggingface.pipeline_text_gen import ModelPipelineTextGeneration
        if hf_handling_mode == "TEXT_GENERATION_DOLLY":
            logger.warning("Batching is not yet supported for Dolly, handling each prompt separately.")
            batch_size = 1
        # This is needed until a new version of transformers fixes the following issue,
        # see : https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/discussions/82
        if hf_handling_mode == "TEXT_GENERATION_PHI_3":
            model_kwargs['trust_remote_code'] = True
        pipeline = ModelPipelineTextGeneration(model_name_or_path, model_kwargs, batch_size, hf_handling_mode)
        return _record_requested_settings(pipeline, requested_engine, requested_quantization)
    elif hf_handling_mode == "TEXT_GENERATION_MPT":
        from dataiku.huggingface.pipeline_text_gen_mpt import ModelPipelineTextGenerationMPT
        pipeline = ModelPipelineTextGenerationMPT(model_name_or_path, model_kwargs, batch_size, model_settings, not untrusted_model)
        return _record_requested_settings(pipeline, requested_engine, requested_quantization)
    elif hf_handling_mode == "TEXT_GENERATION_GENERIC":
        from dataiku.huggingface.pipeline_text_gen_generic import ModelPipelineTextGenerationGeneric
        pipeline = ModelPipelineTextGenerationGeneric(model_name_or_path, model_kwargs, batch_size)
        return _record_requested_settings(pipeline, requested_engine, requested_quantization)
    elif hf_handling_mode == "T5":
        from dataiku.huggingface.pipeline_text_gen_t5 import ModelPipelineTextGenerationT5
        return ModelPipelineTextGenerationT5(model_name_or_path, model_kwargs, batch_size)
    elif hf_handling_mode == "TEXT_EMBEDDING":
        from dataiku.huggingface.pipeline_text_embed_extraction import ModelPipelineTextEmbeddingExtraction
        return ModelPipelineTextEmbeddingExtraction(model_name_or_path, batch_size)
    elif hf_handling_mode == "IMAGE_EMBEDDING":
        from dataiku.huggingface.pipeline_image_embed_extraction import ModelPipelineImageEmbeddingExtraction
        return ModelPipelineImageEmbeddingExtraction(model_name_or_path, hf_model_name, use_dss_model_cache, batch_size)
    elif hf_handling_mode == "IMAGE_GENERATION_DIFFUSION":
        from dataiku.huggingface.pipeline_image_gen_diffusion import ModelPipelineImageGenerationDiffusion
        return ModelPipelineImageGenerationDiffusion.create_pipeline(model_name_or_path, model_settings, use_dss_model_cache, batch_size)

    raise Exception(f"Unknown handling mode for model (hf_handling_mode={hf_handling_mode}, hf_model_name={hf_model_name})")
