import copy
from typing import AsyncIterator

from dataiku.core import dkujson
from dataiku.llm.python.types import CompletionResponse, SimpleCompletionResponse, StreamCompletionResponse, StreamResponseChunkOrFooter
from dataikuapi.dss.llm import DSSLLMStreamedCompletionChunk, DSSLLMStreamedCompletionFooter, DSSLLMCompletionResponse


def parse_single_completion_response(response: CompletionResponse) -> SimpleCompletionResponse:
    """Normalize what a completion implementation returned into a SimpleCompletionResponse dict.

    Accepted inputs:
      - ``str``: wrapped as ``{"text": response}``.
      - ``DSSLLMCompletionResponse``: its raw dict (``_raw``) is returned.
      - ``dict``: returned unchanged when ``"text"`` is missing, ``None`` or a string;
        otherwise a shallow copy is returned in which ``"text"`` has been serialized
        with ``dkujson.dumps``. The caller's dict is never modified.

    :raises Exception: if ``response`` is of any other type, or if a non-string
        ``"text"`` value cannot be serialized (the serializer error is chained as cause).
    """
    # Try to eat anything the user can throw at us

    if isinstance(response, str):
        return {"text": response}

    if isinstance(response, DSSLLMCompletionResponse):
        return response._raw

    if isinstance(response, dict):
        text = response.get("text")
        if text is None or isinstance(text, str):
            return response

        # Convert other types to str, on a copy so the caller's object is left untouched.
        normalized = dict(response)
        try:
            normalized["text"] = dkujson.dumps(text)
        except (TypeError, ValueError) as exc:
            # TypeError: value of an unserializable type. ValueError: JSON-style encoders
            # raise it e.g. for circular references (assumed to apply to dkujson too).
            raise Exception("Invalid type for the text field of the response (%s), it must be a string or an object that can be converted to JSON (%s)" % (type(text), exc)) from exc
        return normalized

    raise Exception("Unrecognized response type (%s), it must be a string, dict or DSSLLMCompletionResponse" % type(response))


def parse_stream_chunk_or_footer(response: StreamCompletionResponse) -> StreamResponseChunkOrFooter:
    """Wrap one streamed item into a ``{"chunk": ...}`` or ``{"footer": ...}`` envelope.

    Accepted inputs:
      - ``str``: becomes a content chunk carrying that text.
      - ``DSSLLMStreamedCompletionChunk`` / ``DSSLLMStreamedCompletionFooter``: their
        ``data`` dict is wrapped; its ``"type"`` defaults (in place) to ``"content"`` /
        ``"footer"`` respectively.
      - ``dict``: returned as-is if already wrapped (has ``"chunk"`` or ``"footer"``);
        wrapped according to a ``"type"`` of ``"content"`` or ``"footer"``; otherwise
        treated as a chunk (``"type"`` defaulted in place) if it has ``"text"`` or ``"toolCalls"``.

    :raises Exception: for any input that matches none of the above.
    """
    # Try to eat anything the user can throw at us

    # Plain strings are fragments of generated text.
    if isinstance(response, str):
        return {"chunk": {"type": "content", "text": response}}

    if isinstance(response, DSSLLMStreamedCompletionChunk):
        payload = response.data
        payload["type"] = payload.get("type") or "content"
        return {"chunk": payload}

    if isinstance(response, DSSLLMStreamedCompletionFooter):
        payload = response.data
        payload["type"] = payload.get("type") or "footer"
        return {"footer": payload}

    if isinstance(response, dict):
        # Already enveloped: pass through untouched.
        if "chunk" in response or "footer" in response:
            return response

        declared_type = response.get("type")
        if declared_type == "content":
            return {"chunk": response}
        if declared_type == "footer":
            return {"footer": response}

        # Payload carrying content without a recognized type: treat it as a chunk.
        if "text" in response or "toolCalls" in response:
            response["type"] = declared_type or "content"
            return {"chunk": response}

    raise Exception("Unrecognized stream response type (%s)" % type(response))


_footer_attributes = [
    # usage metadata
    "promptTokens",
    "completionTokens",
    "totalTokens",
    "tokenCountsAreEstimated",
    "estimatedCost",
    # specific to SimpleCompletionResponse
    "finishReason",
    "additionalInformation",
    "trace",
]


def _merge_streamed_tool_call(tool_call, chunk_tool_call):
    """Fold one streamed tool-call fragment into the tool call accumulated so far for its index.

    ``tool_call`` is updated in place; ``chunk_tool_call`` is only read. ``type``, ``id``
    and ``function.name`` are filled only if not already set, while ``function.arguments``
    fragments are concatenated in arrival order.
    """
    if not tool_call.get("type") and chunk_tool_call.get("type"):
        tool_call["type"] = chunk_tool_call["type"]
    if tool_call.get("id") is None and chunk_tool_call.get("id") is not None:
        tool_call["id"] = chunk_tool_call["id"]

    chunk_function_info = chunk_tool_call.get("function")
    if not chunk_function_info:
        return

    function_info = tool_call.get("function")
    if not function_info:
        # First function fragment for this call: store a private copy, since later
        # fragments append to its "arguments" and must not alter the caller's object.
        tool_call["function"] = copy.deepcopy(chunk_function_info)
        return

    if not function_info.get("name") and chunk_function_info.get("name"):
        function_info["name"] = chunk_function_info["name"]

    chunk_arguments = chunk_function_info.get("arguments")
    if chunk_arguments:
        if function_info.get("arguments") is None:
            function_info["arguments"] = ""
        function_info["arguments"] += chunk_arguments


async def single_response_from_stream(stream: AsyncIterator[StreamCompletionResponse]) -> SimpleCompletionResponse:
    """Consume a completion stream and aggregate it into a single SimpleCompletionResponse.

    Each item of ``stream`` is normalized with ``parse_stream_chunk_or_footer``. Then:
      - chunk ``text`` fragments are concatenated into ``"text"``;
      - chunk ``toolCalls`` are aggregated into ``"toolCalls"``: fragments sharing an
        ``index`` are merged (see ``_merge_streamed_tool_call``); fragments without an
        index are assumed complete and kept as-is, listed before the merged ones;
      - chunk ``logProbs`` and ``artifacts`` are concatenated into lists;
      - footer keys listed in ``_footer_attributes`` are copied when not ``None``
        (a later footer overrides an earlier one for the same key).

    Keys with nothing to report are omitted from the returned dict. Indexed tool-call
    fragments are deep-copied before merging, so concatenating their arguments does not
    alter the objects yielded by ``stream``.
    """
    response = {}
    text_chunks = []
    indexed_tool_calls = {}  # index -> accumulated tool call, in first-seen order
    tool_calls_list = []     # tool calls without an index: cannot be aggregated
    log_probs = []
    artifacts_list = []

    async for item in stream:
        parsed = parse_stream_chunk_or_footer(item)
        chunk = parsed.get("chunk")
        footer = parsed.get("footer")

        if chunk:
            if chunk.get("text"):
                text_chunks.append(chunk["text"])

            for chunk_tool_call in chunk.get("toolCalls") or []:
                index = chunk_tool_call.get("index")
                if index is None:
                    # No index: we won't be able to aggregate chunks, so assume it's full
                    tool_calls_list.append(chunk_tool_call)
                elif index not in indexed_tool_calls:
                    # Copy: this entry is mutated by later fragments with the same index.
                    indexed_tool_calls[index] = copy.deepcopy(chunk_tool_call)
                else:
                    _merge_streamed_tool_call(indexed_tool_calls[index], chunk_tool_call)

            if chunk.get("logProbs"):
                log_probs.extend(chunk["logProbs"])

            if chunk.get("artifacts"):
                artifacts_list.extend(chunk["artifacts"])

        if footer:
            for key in _footer_attributes:
                # Compare to None so legitimate falsy values (0 tokens, 0.0 cost,
                # tokenCountsAreEstimated=False) are not silently dropped.
                if footer.get(key) is not None:
                    response[key] = footer[key]

    if text_chunks:
        response["text"] = "".join(text_chunks)
    tool_calls_list.extend(indexed_tool_calls.values())
    if tool_calls_list:
        response["toolCalls"] = tool_calls_list
    if log_probs:
        response["logProbs"] = log_probs
    if artifacts_list:
        response["artifacts"] = artifacts_list

    return response


async def stream_from_single_response(response: CompletionResponse) -> AsyncIterator[StreamResponseChunkOrFooter]:
    """Replay a single (non-streamed) completion response as a stream of chunks and a footer.

    The response is first normalized with ``parse_single_completion_response``. Then the
    following are yielded in order, each only when it has something to carry:
      1. a content chunk with the non-empty ``text``, ``logProbs`` and ``toolCalls``;
      2. a content chunk with the non-empty ``artifacts``;
      3. a footer with every ``_footer_attributes`` key whose value is not ``None``.
    """
    response = parse_single_completion_response(response)

    chunk = {key: response[key] for key in ("text", "logProbs", "toolCalls") if response.get(key)}
    if chunk:
        chunk["type"] = "content"
        yield {"chunk": chunk}

    if response.get("artifacts"):
        yield {"chunk": {"type": "content", "artifacts": response["artifacts"]}}

    # Compare to None so legitimate falsy values (0 tokens, 0.0 cost,
    # tokenCountsAreEstimated=False) still reach the footer.
    footer = {key: response[key] for key in _footer_attributes if response.get(key) is not None}
    if footer:
        footer["type"] = "footer"
        yield {"footer": footer}
