import asyncio

from concurrent.futures import ThreadPoolExecutor
from typing import Type, Union

from .custom.base_llm import BaseLLM
from dataiku.llm.python.processing.completion_processor import CompletionProcessor
from dataiku.llm.python.types import CompletionSettings, ProcessSinglePromptCommand, SingleCompletionQuery, SimpleCompletionResponse

def _get_processor(agent_class: Type[BaseLLM], config: Union[dict, None] = None, pluginConfig: Union[dict, None] = None) -> CompletionProcessor:
    """Build a CompletionProcessor around ``agent_class``.

    Args:
        agent_class: The class passed to the processor as ``clazz``
            (annotated as a ``BaseLLM`` subclass).
        config: Configuration dict forwarded to the processor. ``None``
            (the default) is replaced by a fresh empty dict.
        pluginConfig: Plugin configuration dict forwarded to the processor.
            ``None`` (the default) is replaced by a fresh empty dict.

    Returns:
        A CompletionProcessor given its own single-worker thread pool and
        the trace name ``"DKU_TEST_AGENT_CALL"``.
    """
    # Create the dicts per call: a `{}` default is a single shared object, so
    # any mutation of it downstream would leak into every later defaulted call.
    if config is None:
        config = {}
    if pluginConfig is None:
        pluginConfig = {}

    # The executor is handed over to the processor; it is not shut down here.
    return CompletionProcessor(
        clazz=agent_class,
        executor=ThreadPoolExecutor(1),
        config=config,
        pluginConfig=pluginConfig,
        trace_name="DKU_TEST_AGENT_CALL",
    )

def _get_prompt_command(query: SingleCompletionQuery, settings: CompletionSettings) -> ProcessSinglePromptCommand:
    command: ProcessSinglePromptCommand = {
        "query": query,
        "settings": settings,
        "stream": False,
        "type": "process-completion-query",
        "prompt": ""
    }

    return command

def run_completion_query(llm_class: Type[BaseLLM], query: Union[str, SingleCompletionQuery], settings: Union[CompletionSettings, None] = None, config: Union[dict, None] = None, pluginConfig: Union[dict, None] = None) -> SimpleCompletionResponse:
    """Run a single non-streaming completion query against ``llm_class``.

    Args:
        llm_class: The class handed to the completion processor.
        query: Either a full query dict, or a plain string. A string is
            wrapped as a query with one ``"user"`` message whose content is
            that string.
        settings: Completion settings placed in the command. ``None`` (the
            default) is replaced by a fresh empty dict.
        config: Configuration dict forwarded to the processor. ``None``
            (the default) is replaced by a fresh empty dict.
        pluginConfig: Plugin configuration dict forwarded to the processor.
            ``None`` (the default) is replaced by a fresh empty dict.

    Returns:
        The result of the processor's ``process_query`` coroutine.

    Raises:
        Any exception raised while the coroutine runs is re-raised here by
        ``Future.result()``.
    """
    # Fresh dicts per call: `{}` defaults would be shared objects, so a
    # mutation made while processing one query would leak into the next.
    if settings is None:
        settings = {}
    if config is None:
        config = {}
    if pluginConfig is None:
        pluginConfig = {}

    if isinstance(query, str):
        # Convenience form: treat a bare string as a single user message.
        query = {
            "messages": [
                {
                    "role": "user",
                    "content": query,
                },
            ],
        }

    completion_processor = _get_processor(llm_class, config, pluginConfig)
    command = _get_prompt_command(query, settings)

    # asyncio.run() cannot be called from a thread that already has a running
    # event loop, so the coroutine is driven from a dedicated worker thread.
    # NOTE(review): presumably this is so the helper also works from callers
    # that already run a loop (e.g. notebooks) -- confirm.
    with ThreadPoolExecutor(1) as executor:
        future = executor.submit(asyncio.run, completion_processor.process_query(command))
        return future.result()