# Utility functions to convert between DSS data model <-> vLLM's OpenAI API data model
import json
import logging
import uuid
from asyncio import CancelledError
from typing import AsyncIterator, List, Union, Optional, Dict, Any

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    StreamOptions,
    ErrorResponse,
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
    ChatCompletionLogProbs,
    UsageInfo,
    ResponseFormat,
    JsonSchemaResponseFormat,
    ChatCompletionResponseChoice,
    ToolCall
)
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatCompletionContentPartParam,
    ChatCompletionMessageToolCallParam, # type: ignore
    ChatCompletionToolMessageParam, # type: ignore
)

from dataiku.huggingface.chat_template import ChatTemplateRenderer
from dataiku.huggingface.types import ProcessSinglePromptCommand, StreamResponseChunkOrFooter
from dataiku.huggingface.types import ProcessSinglePromptResponseText
from dataiku.huggingface.types import StreamedCompletionResponseFooter
from dataiku.huggingface.types import DetailedLogProb
from dataiku.huggingface.types import UsageData
from dataiku.huggingface.types import FinishReason
from dataiku.huggingface.types import ProcessSinglePromptResponseTextFull
from dataiku.huggingface.types import ChatMessage
from dataiku.huggingface.types import CompletionSettings
from dataiku.huggingface.types import ToolChoice
from dataiku.huggingface.types import FunctionToolCall
from dataiku.llm.python.types import SingleCompletionQuery

def _create_json_instruction_message(schema:Optional[Dict[str, Any]]) -> ChatMessage:
    """Build a system message asking the model to answer in JSON.

    Args:
        schema: optional JSON schema the answer must follow. When given (even
            an empty dict), it is embedded in the instruction.

    Returns:
        A ``system`` chat message carrying the instruction.
    """
    if schema is None:
        return {
            "role": "system",
            "content": "Answer in JSON format.",
        }
    # Serialize as real JSON: "%s" % dict would emit the Python repr (single
    # quotes, True/False/None), which is not valid JSON and misleads the model.
    # A schema already provided as text is embedded unchanged.
    schema_text = schema if isinstance(schema, str) else json.dumps(schema)
    return {
        "role": "system",
        "content": "Answer in JSON format and follow this JSON schema: %s" % schema_text,
    }

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

def dss_to_oai_request(
    request: ProcessSinglePromptCommand,
    chat_template_renderer: ChatTemplateRenderer,
    json_constraints_in_prompt: bool,
    guided_decoding_backend: str,
    tools_supported: bool
) -> ChatCompletionRequest:
    """Convert a DSS completion command into a vLLM ``ChatCompletionRequest``.

    Args:
        request: DSS command. Its ``settings``, ``query["messages"]`` and
            ``stream`` entries are read; the command itself is not modified.
        chat_template_renderer: rearranges the messages before conversion and,
            when tools are given, provides the chat template to use.
        json_constraints_in_prompt: when a JSON response format is requested,
            also append a system message describing the expected JSON.
        guided_decoding_backend: forwarded unchanged to vLLM.
        tools_supported: whether the served model supports tool calling.

    Returns:
        The equivalent vLLM request (always with ``n=1``).

    Raises:
        TypeError: if tools are requested but ``tools_supported`` is False.
    """
    params = request["settings"]
    # Work on a copy: the JSON instruction appended below must not leak into the
    # caller's request (re-converting the same request would duplicate it).
    messages = list(request["query"]["messages"])

    response_format: Optional[ResponseFormat] = None
    dss_response_format = params.get("responseFormat")
    if dss_response_format and dss_response_format["type"] == "json":
        # FIXME: performance bug in vLLM https://github.com/vllm-project/vllm/issues/8383
        schema = dss_response_format.get("schema")
        if json_constraints_in_prompt:
            messages.append(_create_json_instruction_message(schema))
        if schema:
            response_format = ResponseFormat(
                type="json_schema",
                json_schema=JsonSchemaResponseFormat(
                    name="json",  # vLLM ignores the name currently
                    schema=schema,
                    strict=dss_response_format.get("strict", True),  # vLLM ignores strictness currently
                ),
            )
        else:
            # No (or empty) schema: request generic JSON output.
            response_format = ResponseFormat(type="json_object")

    tools = None
    tool_choice = None
    if params.get("tools"):
        if not tools_supported:
            raise TypeError("Tools not supported with this model.")
        tools = params["tools"]
        if params.get("toolChoice") is None:
            logger.warning("Tools present in completion settings with no 'toolChoice' specified. Defaulting 'toolChoice' to 'auto'.")
            tool_choice = "auto"
        else:
            tool_choice = dss_to_oai_tool_choice(params["toolChoice"])

    rearranged_messages = chat_template_renderer.rearrange_messages(messages)
    oai_messages = [dss_to_oai_message(message) for message in rearranged_messages]

    top_p = params.get("topP", 1.0)
    top_k = params.get("topK", 50)
    if top_k is None:
        # An explicit None in the settings also falls back to the default.
        top_k = 50
    temperature = params.get("temperature", 1.0)
    presence_penalty = params.get("presencePenalty", 0.0)
    frequency_penalty = params.get("frequencyPenalty", 0.0)
    # Token ids are sent as string keys.
    logit_bias = {str(k): v for k, v in (params.get("logitBias") or {}).items()}
    logprobs = params.get("logProbs")
    top_logprobs = params.get("topLogProbs", 0 if logprobs else None)
    max_tokens = params.get("maxOutputTokens")
    stream = request["stream"]
    # When streaming, ask for usage once at the end of the stream.
    stream_options = (
        StreamOptions(include_usage=True, continuous_usage_stats=False)
        if stream
        else None
    )
    stop = params.get("stopSequences")
    return ChatCompletionRequest(
        messages=oai_messages,
        model="model",
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        logit_bias=logit_bias,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
        max_tokens=max_tokens,
        stream=stream,
        n=1,
        stream_options=stream_options,
        stop=stop,
        seed=None,
        response_format=response_format,
        guided_decoding_backend=guided_decoding_backend,
        # Tool-related fields are only passed when tools are actually requested.
        **({
            'tools': tools,
            'tool_choice': tool_choice,
            'chat_template': chat_template_renderer.get_chat_template(tools)
        } if tools else {})
    )


def _dss_part_to_oai_content_part(part: Dict[str, Any]) -> ChatCompletionContentPartParam:
    """Convert one DSS multimodal message part into an OpenAI content part.

    Raises:
        ValueError: if the part is malformed or of an unsupported type/URI.
    """
    part_type = part["type"]
    if part_type == "TEXT":
        if part.get("text") is None:
            raise ValueError("'TEXT' part should contain a 'text' field")
        return {"type": "text", "text": part["text"]}
    if part_type == "IMAGE_INLINE":
        if part.get("inlineImage") is None:
            raise ValueError("'IMAGE_INLINE' part should contain an 'inlineImage' field")
        # vLLM does not care about the mime type
        image_url = f'data:image/anytype;base64,{part["inlineImage"]}'
        return {"type": "image_url", "image_url": {"url": image_url}}
    if part_type == "IMAGE_URI":
        image_uri = part.get("imageUrl")
        if image_uri is None:
            raise ValueError("'IMAGE_URI' part should contain an 'imageUrl' field")
        # Remote URLs are rejected; only inline data URLs are accepted.
        if image_uri.startswith(("http://", "https://")):
            raise ValueError("Multimodal image part type 'IMAGE_URI' does not support non data url for local huggingface models.\nUse data url or switch to 'IMAGE_INLINE'")
        if image_uri.startswith("data:image/"):
            return {"type": "image_url", "image_url": {"url": image_uri}}
        raise ValueError("Unsupported image URI")
    raise ValueError(f"Unsupported part type: {part_type}")


def _dss_tool_call_to_oai(tool_call: Dict[str, Any]) -> ChatCompletionMessageToolCallParam:
    """Convert one DSS tool call into an OpenAI assistant tool call.

    Missing ``id``, ``name`` or ``arguments`` entries default to empty strings.

    Raises:
        ValueError: if the tool call has no ``function`` entry.
    """
    if "function" not in tool_call:
        raise ValueError("Tool call must contain a 'function' key")
    function = tool_call["function"]
    return {
        "type": "function",
        "id": tool_call.get("id", ""),
        "function": {
            "arguments": function.get("arguments", ""),
            "name": function.get("name", ""),
        },
    }


def dss_to_oai_message(message: ChatMessage) -> ChatCompletionMessageParam:
    """Convert a single DSS chat message into an OpenAI chat message.

    The first non-None entry among ``content``, ``parts``, ``toolCalls`` and
    ``toolOutputs`` (checked in that order) determines the converted form.

    Raises:
        ValueError: if none of those entries is present, if a part or tool
            call is malformed, or if ``toolOutputs`` is empty.
    """
    if message.get("content") is not None:
        return {
            "role": message["role"],
            "content": message["content"],
        }
    if message.get("parts") is not None:
        oai_parts = [_dss_part_to_oai_content_part(part) for part in message["parts"]]
        return {
            "role": message["role"],
            "content": oai_parts,
        }
    if message.get("toolCalls") is not None:
        oai_tool_calls = [_dss_tool_call_to_oai(tool_call) for tool_call in message["toolCalls"]]
        return {
            "role": message["role"],
            "tool_calls": oai_tool_calls,
        }
    if message.get("toolOutputs") is not None:
        dss_tool_outputs = message["toolOutputs"]
        if not dss_tool_outputs:
            raise ValueError("'toolOutputs' must contain at least one tool output")
        # One DSS message maps to exactly one OpenAI message, so only the first
        # output can be represented; make the loss of the others visible.
        if len(dss_tool_outputs) > 1:
            logging.getLogger(__name__).warning(
                "Message contains %d tool outputs; only the first one is converted, the others are ignored.",
                len(dss_tool_outputs),
            )
        first_output = dss_tool_outputs[0]
        oai_tool_output: ChatCompletionToolMessageParam = {
            "role": "tool",
            "content": first_output["output"],
            "tool_call_id": first_output["callId"],
        }
        return oai_tool_output
    raise ValueError("Message must have a key in ['content', 'parts', 'toolCalls', 'toolOutputs']")


def oai_to_dss_logprobs(
    logprobs: Optional[ChatCompletionLogProbs],
) -> Optional[List[DetailedLogProb]]:
    """Translate vLLM per-token log-probabilities into DSS ``DetailedLogProb``s.

    Returns None when there are no log-probabilities (``logprobs`` or its
    ``content`` is None). Missing tokens become empty strings, and a missing
    ``top_logprobs`` becomes an empty list.
    """
    content = None if logprobs is None else logprobs.content
    if content is None:
        return None

    converted: List[DetailedLogProb] = []
    for token_logprob in content:
        alternatives = [
            {"token": alternative.token or "", "logProb": alternative.logprob}
            for alternative in token_logprob.top_logprobs or []
        ]
        converted.append({
            "token": token_logprob.token or "",
            "logProb": token_logprob.logprob,
            "topLogProbs": alternatives,
        })
    return converted


def oai_to_dss_usage(usage: UsageInfo) -> UsageData:
    """Map vLLM usage info to DSS usage data (a falsy completion count becomes 0)."""
    completion_tokens = usage.completion_tokens
    if not completion_tokens:
        completion_tokens = 0
    return {"promptTokens": usage.prompt_tokens, "completionTokens": completion_tokens}


def oai_to_dss_finish_reason(oai_finish_reason: Optional[str]) -> FinishReason:
    """Pass through finish reasons DSS understands; map anything else to "unknown"."""
    if oai_finish_reason in ("length", "stop", "tool_calls"):
        return oai_finish_reason
    return "unknown"

def dss_to_oai_tool_choice(dss_tool_choice: ToolChoice) -> Union[str, Dict[str, Any]]:
    """Convert a DSS tool choice into the OpenAI ``tool_choice`` value.

    Returns:
        "auto", "none" or "required" unchanged, or a named-function choice
        for the "tool_name" type.

    Raises:
        ValueError: for an unsupported tool choice type. Previously such input
            silently produced None, which was then forwarded as ``tool_choice``.
    """
    choice_type = dss_tool_choice["type"]
    if choice_type in ("auto", "none", "required"):
        return choice_type
    if choice_type == "tool_name":
        return {
            "type": "function",
            "function": {"name": dss_tool_choice["name"]}
        }
    raise ValueError(f"Unsupported tool choice type: {choice_type}")

def _single_oai_tool_call_to_dss_tool_call(oai_tool_call: ToolCall) -> FunctionToolCall:
    """Copy an OpenAI function tool call into the DSS ``FunctionToolCall`` shape."""
    fn = oai_tool_call.function
    dss_function = {"name": fn.name, "arguments": fn.arguments}
    return {"id": oai_tool_call.id, "type": "function", "function": dss_function}

def oai_to_dss_tool_calls(choice: ChatCompletionResponseChoice) -> List[FunctionToolCall]:
    """Convert every tool call of a non-streamed choice's message to DSS form."""
    return list(map(_single_oai_tool_call_to_dss_tool_call, choice.message.tool_calls))



async def oai_to_dss_response(
    response: Union[ChatCompletionResponse, AsyncIterator[str], ErrorResponse],
) -> AsyncIterator[ProcessSinglePromptResponseText]:
    """Adapt any vLLM chat-completion result into DSS responses.

    A full response yields a single DSS response; a stream is converted chunk
    by chunk.

    Raises:
        CancelledError: for an error response whose message mentions
            "Client disconnected".
        Exception: for any other error response.
    """
    if isinstance(response, ErrorResponse):
        error_message = response.message
        if "Client disconnected" in error_message:
            raise CancelledError(error_message)
        raise Exception(error_message)

    if not isinstance(response, ChatCompletionResponse):
        # Anything else is treated as an SSE text stream.
        async for dss_chunk in oai_to_dss_stream_response(response):
            yield dss_chunk
        return

    yield oai_to_dss_full_response(response)


def oai_to_dss_full_response(
    response: ChatCompletionResponse,
) -> ProcessSinglePromptResponseTextFull:
    """Build the DSS full response from the first choice of a vLLM response.

    The result holds the text (empty string if absent), log-probabilities when
    present, usage, the mapped finish reason, and the converted tool calls.
    """
    first_choice = response.choices[0]
    result: ProcessSinglePromptResponseTextFull = {"text": first_choice.message.content or ""}

    converted_logprobs = oai_to_dss_logprobs(first_choice.logprobs)
    if converted_logprobs is not None:
        result["logProbs"] = converted_logprobs

    result.update({
        "usage": oai_to_dss_usage(response.usage),
        "finishReason": oai_to_dss_finish_reason(first_choice.finish_reason),
        "toolCalls": oai_to_dss_tool_calls(first_choice),
    })
    return result


def _oai_delta_tool_call_to_dss(delta_tool_call: Any) -> FunctionToolCall:
    """Convert one streamed tool-call delta (with a non-None ``function``) to DSS form.

    ``id`` and ``index`` are copied as-is (``id`` may be None). A missing name
    or arguments fragment becomes an empty string.
    """
    return {
        "type": "function",
        "id": delta_tool_call.id,
        "index": delta_tool_call.index,
        "function": {
            "name": delta_tool_call.function.name or "",
            "arguments": delta_tool_call.function.arguments or "",
        }
    }


async def oai_to_dss_stream_response(
    response: AsyncIterator[str],
) -> AsyncIterator[StreamResponseChunkOrFooter]:
    """Convert vLLM's OpenAI-compatible SSE text stream into DSS stream items.

    Each ``data: <json>`` event that carries at least one choice yields a DSS
    chunk built from its first choice: text, optional log-probabilities and
    tool-call deltas. Parsing stops at ``data: [DONE]``. A footer item is then
    always yielded. It carries the last non-empty finish reason seen and, when
    a usage block was received, the token counts.
    """
    # Parse OpenAI output of vLLM (somewhat inefficient...)
    prefix = "data: "
    oai_usage = None
    oai_finish_reason = None
    async for chunk in response:
        # Ignore anything that is not a non-empty SSE data line.
        if not chunk.startswith(prefix) or len(chunk) <= len(prefix):
            continue
        chunk_data = chunk[len(prefix):]
        if chunk_data.startswith("[DONE]"):
            break
        parsed_chunk = ChatCompletionStreamResponse.model_validate_json(chunk_data)
        if parsed_chunk.usage:
            oai_usage = parsed_chunk.usage
        # Usage-only events have no choices and produce no DSS chunk.
        if not parsed_chunk.choices:
            continue

        choice = parsed_chunk.choices[0]
        if choice.finish_reason:
            oai_finish_reason = choice.finish_reason

        delta = choice.delta
        dss_chunk = {
            "chunk": {
                "text": (delta.content if delta is not None else None) or "",
            },
            "footer": None,
        }

        dss_logprobs = oai_to_dss_logprobs(choice.logprobs)
        if dss_logprobs is not None:
            dss_chunk["chunk"]["logProbs"] = dss_logprobs

        if delta is not None and delta.tool_calls is not None:
            dss_chunk["chunk"]["toolCalls"] = [
                _oai_delta_tool_call_to_dss(delta_tool_call)
                for delta_tool_call in delta.tool_calls
                if delta_tool_call.function is not None
            ]

        yield dss_chunk

    dss_footer: StreamedCompletionResponseFooter = {
        "type": "footer",
        "finishReason": oai_to_dss_finish_reason(oai_finish_reason),
    } # type: ignore

    if oai_usage is not None:
        usage: UsageData = oai_to_dss_usage(oai_usage)
        dss_footer["promptTokens"] = usage["promptTokens"]
        dss_footer["completionTokens"] = usage["completionTokens"]
        dss_footer["totalTokens"] = usage["promptTokens"] + usage["completionTokens"]
    yield { "footer": dss_footer }


def generate_dummy_dss_request(
    prompt: str, max_tokens=64
) -> ProcessSinglePromptCommand:
    """Build a minimal, non-streaming DSS completion command for ``prompt``.

    The command holds a single user message, an empty context, and settings
    containing only ``maxOutputTokens``. Its ``id`` is a fresh random UUID4.
    """
    user_message = {"role": "user", "content": prompt}
    return {
        "id": str(uuid.uuid4()),
        "query": {"messages": [user_message], "context": {}},
        "settings": {"maxOutputTokens": max_tokens},
        "stream": False,
        "prompt": "",
        "type": "process-completion-query",
    }
