import logging

from typing import AsyncIterator

from vllm.entrypoints.openai.serving_models import OpenAIServingModels

from dataiku.base.utils import package_is_at_least
from dataiku.huggingface.vllm_backend.utils import get_model_config
from dataiku.base.async_link import FatalException
from dataiku.huggingface.pipeline_vllm import ModelPipelineVLLM
from dataiku.huggingface.types import ProcessSingleEmbeddingCommand
from dataiku.huggingface.types import ProcessSingleEmbeddingResponse
from dataiku.huggingface.types import ModelSettings
from dataiku.huggingface.vllm_backend.model_params import get_model_params_loader_for_embedding

logger = logging.getLogger(__name__)


class ModelPipelineTextEmbeddingVLLM(ModelPipelineVLLM[ProcessSingleEmbeddingCommand, ProcessSingleEmbeddingResponse]):
    """Text-embedding pipeline backed by a vLLM engine.

    The model is loaded with vLLM's ``"pooling"`` runner, and requests are served
    through vLLM's OpenAI-compatible embedding server (``OpenAIServingEmbedding``).
    DSS commands are converted to and from the OpenAI format around each call.
    """

    def __init__(
            self,
            model_name_or_path,
            model_settings: ModelSettings,
    ):
        """Load the embedding model into vLLM.

        :param model_name_or_path: model identifier or local path, forwarded to the base pipeline.
        :param model_settings: DSS model settings used to build the vLLM load parameters.
        """
        super().__init__(model_name_or_path, None, model_settings)

        self.model_tracking_data["task"] = "text-embedding"

        model_params_loader = get_model_params_loader_for_embedding(self.transformers_model_config, model_settings)
        load_params = {
            "disable_log_stats": False,  # Enable logging of performance metrics
            "model": self.model_to_load,
            # Run the model with vLLM's pooling runner (embeddings) rather than a generation runner.
            "runner": "pooling",
            # Unpacked last: loader-provided params override the defaults above on key collision.
            **model_params_loader.build_params()
        }
        self.load_model(load_params)

    async def build_openai_server(self, openai_serving_models: OpenAIServingModels):
        """Create vLLM's OpenAI embedding server and store it in ``self.openai_server``.

        Handles API differences between supported vLLM versions:

        - from 0.12.0, ``OpenAIServingEmbedding`` is imported from
          ``vllm.entrypoints.pooling.embed.serving`` instead of
          ``vllm.entrypoints.openai.serving_embedding``;
        - before 0.11.1, the constructor is also given the engine's ``model_config``.

        :param openai_serving_models: the ``OpenAIServingModels`` registry passed to the server.
        """
        import vllm

        if package_is_at_least(vllm, "0.12.0"):
            # `vllm.entrypoints.openai.serving_embedding` was renamed in https://github.com/vllm-project/vllm/pull/29634
            from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
        else:
            from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding

        server_kwargs = {
            "engine_client": self.engine_client,
            "models": openai_serving_models,
            "request_logger": None,
            "chat_template": None,
            "chat_template_content_format": "auto",
        }
        if not package_is_at_least(vllm, "0.11.1"):
            # Only older vLLM versions take an explicit model_config. The config is
            # fetched here, on that path only, so newer versions skip the extra await.
            server_kwargs["model_config"] = await get_model_config(self.engine_client)

        # Create the OpenAI Embedding API server of vLLM
        self.openai_server = OpenAIServingEmbedding(**server_kwargs)

    async def run_test_query(self):
        """Embed a fixed prompt to check that the engine works at initialization.

        The response stream is fully drained and only the last response is checked.

        :raises FatalException: if no response was produced, or if the last one has a
            missing or empty ``"embedding"``.
        """
        from dataiku.huggingface.vllm_backend.oai_mappers import generate_dummy_dss_embedding_request

        prompt = "Explain in simple terms what Generative AI is and why prompts matter for it"
        logger.info("Testing embedding engine with basic prompt: %s", prompt)
        request = generate_dummy_dss_embedding_request(prompt)

        # Drain the async stream, keeping only the last response.
        result = None
        async for result in self.run_single_async(request=request):
            pass

        if not result or not result.get("embedding"):
            raise FatalException(
                "Something went wrong at initialization. Embedding engine did not return any result for basic prompt"
            )
        logger.info("Embedding engine test executed successfully: %s", result["embedding"])

    async def run_query(
        self, request: ProcessSingleEmbeddingCommand
    ) -> AsyncIterator[ProcessSingleEmbeddingResponse]:
        """Embed a single DSS request through the OpenAI embedding server.

        Converts the DSS command to an OpenAI embedding request, sends it to
        ``self.openai_server``, and yields exactly one converted DSS response.

        :param request: the DSS embedding command; its ``"id"`` is used for logging.
        :raises RuntimeError: if the OpenAI embedding server has not been built yet.
        """
        from dataiku.huggingface.vllm_backend.oai_mappers import (
            dss_to_oai_embedding_request,
            oai_to_dss_embedding_response,
        )
        request_id = request["id"]

        logger.info("Start embedding %s", request_id)
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if not self.openai_server:
            raise RuntimeError(
                "OpenAI embedding server is not initialized; build_openai_server() must be called first"
            )

        oai_request = dss_to_oai_embedding_request(request)
        # NOTE(review): vLLM's create_embedding may return an ErrorResponse instead of
        # raising; presumably oai_to_dss_embedding_response handles that case — confirm.
        response = await self.openai_server.create_embedding(oai_request)
        yield oai_to_dss_embedding_response(response)
        logger.info("Done embedding %s", request_id)
