import functools
from typing import Dict, Generator

import dataiku
from backend.models.events import EventKind


def add_history_to_completion(
    completion,
    messages,
):
    """Append every usable history message to ``completion`` and return it.

    Args:
        completion: Object exposing ``with_message(message=..., role=...)``.
        messages: Iterable of history entries. An entry is used when it is a
            dict (read through its ``"content"`` and ``"role"`` keys) or any
            object carrying both ``role`` and ``content`` attributes.

    Returns:
        The same ``completion`` object that was passed in.
    """
    # TODO: guarantee that messages are always dicts or MessageRead instances,
    # so that no entry has to be skipped silently below.
    for entry in messages:
        if isinstance(entry, dict):
            content, role = entry["content"], entry["role"]
        elif hasattr(entry, "role") and hasattr(entry, "content"):
            content, role = entry.content, entry.role
        else:
            # Neither a dict nor a role/content carrier: nothing to add.
            continue
        completion.with_message(message=content, role=role)
    return completion


@functools.lru_cache(maxsize=128)
def get_llm_friendly_name(llm_id: str, project_key: str) -> str:
    """Return the friendly name of the LLM ``llm_id`` within a project.

    The project's LLM list is fetched through the Dataiku API client and
    searched for the first entry whose ``"id"`` equals ``llm_id``.

    Args:
        llm_id: Identifier of the LLM to look up.
        project_key: Key of the project whose LLMs are listed.

    Returns:
        The entry's ``"friendlyName"``, or ``llm_id`` itself when the matching
        entry has no such key, or ``""`` when no entry matches.

    Note:
        Results are memoized per ``(llm_id, project_key)`` pair (up to 128
        pairs), which includes the ``""`` returned for an unknown id.
    """
    client = dataiku.api_client()
    available = client.get_project(project_key).list_llms()

    match = next((entry for entry in available if entry.get("id") == llm_id), None)
    if match is None:
        return ""
    return match.get("friendlyName", llm_id)


def get_tool_validation_requests_msg(tool_validation_requests):
    """Wrap tool validation requests in a ``toolValidationRequests`` message.

    Args:
        tool_validation_requests: Payload stored as-is (not copied) under the
            ``"toolValidationRequests"`` key.

    Returns:
        A new dict with the ``"role"`` and ``"toolValidationRequests"`` keys.
    """
    message = {"role": "toolValidationRequests"}
    message["toolValidationRequests"] = tool_validation_requests
    return message


def get_memory_fragment_msg(memory_fragment):
    """Wrap a memory fragment in a ``memoryFragment`` message.

    Args:
        memory_fragment: Payload stored as-is (not copied) under the
            ``"memoryFragment"`` key.

    Returns:
        A new dict with the ``"role"`` and ``"memoryFragment"`` keys.
    """
    message = {"role": "memoryFragment"}
    message["memoryFragment"] = memory_fragment
    return message


def get_tool_validation_responses_msg(tool_validation_responses):
    """Wrap tool validation responses in a ``toolValidationResponses`` message.

    Args:
        tool_validation_responses: Payload stored as-is (not copied) under the
            ``"toolValidationResponses"`` key.

    Returns:
        A new dict with the ``"role"`` and ``"toolValidationResponses"`` keys.
    """
    message = {"role": "toolValidationResponses"}
    message["toolValidationResponses"] = tool_validation_responses
    return message


def add_completion_msgs(comp, messages):
    """Feed a list of role-tagged message dicts into a completion.

    Each message is dispatched on its ``"role"`` key:

    * ``"system"`` / ``"user"`` / ``"assistant"``: ``with_message`` is called
      with the message's ``"content"`` and role.
    * ``"memoryFragment"``: ``with_memory_fragment`` is called with the
      ``"memoryFragment"`` payload.
    * ``"toolValidationResponses"``: ``with_tool_validation_response`` is
      called once per item of the ``"toolValidationResponses"`` payload.
    * ``"toolValidationRequests"``: ``with_tool_validation_requests`` is
      called with the ``"toolValidationRequests"`` payload.

    Messages carrying any other role are ignored.

    Args:
        comp: Completion object exposing the ``with_*`` methods listed above.
            The value returned by each call replaces ``comp`` for the next.
        messages: Iterable of dicts, each with at least a ``"role"`` key.

    Returns:
        The value returned by the last ``with_*`` call, or ``comp`` itself
        when no message was applied.
    """
    chat_roles = ("system", "user", "assistant")

    for msg in messages:
        role = msg["role"]
        if role in chat_roles:
            comp = comp.with_message(msg["content"], role=role)
        elif role == "memoryFragment":
            comp = comp.with_memory_fragment(msg["memoryFragment"])
        elif role == "toolValidationResponses":
            for item in msg["toolValidationResponses"]:
                # Arguments are always sent as None; an "arguments" entry on
                # the stored response is not forwarded.
                comp = comp.with_tool_validation_response(
                    validation_request_id=item["validationRequestId"],
                    validated=item["validated"],
                    arguments=None,
                )
        elif role == "toolValidationRequests":
            comp = comp.with_tool_validation_requests(msg["toolValidationRequests"])
    return comp