import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, Generator, Literal

import dataiku
from graphrag.callbacks.llm_callbacks import BaseLLMCallback
from graphrag.query.llm.base import BaseLLM


class QueryDataikuChatLLM(BaseLLM):
    """GraphRAG query-time LLM backed by a Dataiku project LLM.

    Each request builds a Dataiku completion from the given messages, executes
    it synchronously and returns the response text. Token-level streaming is
    not used: the "streaming" entry points yield the whole response as a
    single chunk, and callbacks receive the whole response as a single token.
    """

    def __init__(self, llm_id: str):
        """Resolve the LLM handle for *llm_id* in the default Dataiku project.

        Args:
            llm_id: Identifier passed to ``project.get_llm``.
        """
        self.llm_id = llm_id
        self.client = dataiku.api_client()
        self.project = self.client.get_default_project()
        self.llm = self.project.get_llm(self.llm_id)
        self.logger = logging.getLogger(__name__)

    def _prepare_completion(self, messages: str | list[Any], **kwargs: Any):
        """Build a completion query for *messages* without executing it.

        Args:
            messages: Either a single prompt string (sent with role "user") or
                a list of mapping-like messages read through ``.get("role")``
                and ``.get("content")``; a missing role defaults to "user" and
                missing content to "".
            **kwargs: Only ``json`` is read here; when truthy the completion is
                configured for JSON output.

        Returns:
            The completion object returned by ``self.llm.new_completion()``.
        """
        completion = self.llm.new_completion()
        if isinstance(messages, str):
            completion.with_message(messages, role="user")
        else:
            for msg in messages:
                role = msg.get("role", "user")
                content = msg.get("content", "")
                completion.with_message(content, role=role)

        if kwargs.get("json", False):
            completion.with_json_output()

        return completion

    @staticmethod
    def _notify_callbacks(
        callbacks: list[BaseLLMCallback] | None,
        text: str | None,
    ) -> None:
        """Hand the complete response to every callback as a single token."""
        if not callbacks or text is None:
            return
        for callback in callbacks:
            callback.on_llm_new_token(text)

    def generate(
        self,
        messages: str | list[Any],
        streaming: bool = True,
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> str:
        """Execute one completion and return its text.

        Args:
            messages: Prompt string or list of role/content messages.
            streaming: When true, *callbacks* are notified with the full
                response as one token. The request itself is never streamed.
            callbacks: Optional callbacks to notify once the response arrives.
            **kwargs: Forwarded to ``_prepare_completion`` (``json`` flag).

        Returns:
            The ``text`` attribute of the executed completion's response.
        """
        # Lazy %-formatting: the prompt is only rendered if DEBUG is enabled.
        self.logger.debug("messages: %s, kwargs: %s", messages, kwargs)
        completion = self._prepare_completion(messages, **kwargs)
        resp = completion.execute()
        # NOTE(review): presumably `text` can be None when the call fails;
        # confirm against the Dataiku response API before relying on `-> str`.
        clean_response = resp.text
        if streaming:
            self._notify_callbacks(callbacks, clean_response)
        return clean_response

    def stream_generate(
        self,
        messages: str | list[Any],
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> Generator[str, None, None]:
        """Yield the full response from ``generate`` as a single chunk."""
        full_response = self.generate(messages, streaming=True, callbacks=callbacks, **kwargs)
        yield full_response

    async def agenerate(
        self,
        messages: str | list[Any],
        streaming: bool = True,
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> str:
        """Run ``generate`` in a worker thread and return its text.

        Arguments and return value match ``generate``. Callbacks are invoked
        from the coroutine after the worker thread has finished, not from the
        worker thread itself.
        """
        # asyncio.to_thread reuses the loop's default executor: no executor is
        # created per call, and a cancelled await does not block the event
        # loop waiting for an executor shutdown.
        result = await asyncio.to_thread(
            self.generate,
            messages=messages,
            streaming=False,  # callbacks are delivered below, on the loop thread
            callbacks=None,
            **kwargs,
        )
        if streaming:
            self._notify_callbacks(callbacks, result)
        return result

    async def astream_generate(
        self,
        messages: str | list[Any],
        callbacks: list[BaseLLMCallback] | None = None,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
        """Yield the full response from ``agenerate`` as a single chunk."""
        response = await self.agenerate(messages, streaming=True, callbacks=callbacks, **kwargs)
        yield response
