import asyncio
import inspect
import logging

from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncIterator, Callable, Coroutine, Generic, NoReturn, Type, TypeVar, TypedDict, cast

from dataiku.llm.python.custom.base_model import BaseModel
from dataiku.llm.python.processing import _NotImplementedError
from dataiku.llm.python.types import LLMMeshTraceSpan
from dataiku.llm.tracing import new_trace, SpanBuilder

class OutputTypeWithTrace(TypedDict, total=True):
    """Minimal shape of a processor result: a dict carrying a ``trace`` entry.

    ``BaseProcessor.process_query`` stores the serialized trace of the query under this key.
    """

    trace: LLMMeshTraceSpan

logger = logging.getLogger(__name__)

# Type parameters shared by the processor classes in this module.
ModelType = TypeVar("ModelType", bound=BaseModel)  # model class type, bounded by BaseModel
InputType = TypeVar("InputType")  # command accepted by process_query
OutputType = TypeVar("OutputType", bound=OutputTypeWithTrace)  # result dict; carries a "trace" entry

class _NotSupportedError(_NotImplementedError):
    """Signal that an operation is not supported by a processor.

    Raised, for instance, by ``BaseProcessor.process_query_stream``. Because it derives from
    ``_NotImplementedError``, any handler catching that base class also catches this one.
    """

def raise_not_implemeted_error() -> NoReturn:
    """Raise ``_NotImplementedError``; this function never returns.

    Being a function, it can be used in expression position where a ``raise`` statement is
    not allowed (e.g. ``yield raise_not_implemeted_error()`` in ``BaseStreamProcessor``).
    The explicit ``NoReturn`` annotation lets type checkers treat such expressions as
    unreachable without relying on inference.
    NOTE: "implemeted" is misspelled, but the name is referenced as-is, so it is kept.
    """
    raise _NotImplementedError

def raise_not_supported_error() -> NoReturn:
    """Raise ``_NotSupportedError``; this function never returns.

    Expression-position counterpart of ``raise _NotSupportedError``, annotated ``NoReturn``
    so type checkers know control never continues past the call.
    """
    raise _NotSupportedError

class BaseProcessor(ABC, Generic[ModelType, InputType, OutputType]):
    """Run non-streaming inference queries against an instance of a custom model class.

    A concrete subclass chooses which callables serve as the async and sync inference entry
    points (``get_async_inference_func`` / ``get_sync_inference_func``), how a command becomes
    keyword arguments for them (``get_inference_params``), and how the raw return value is
    turned into the output dict (``parse_raw_response``). The async entry point is tried
    first; if it raises ``_NotImplementedError`` the sync one is run in ``executor``. Each
    query runs inside a new trace whose ``to_dict()`` form is stored as ``result["trace"]``.
    """

    # Labels ("async_inference", "sync_inference") of inference paths that raised
    # _NotImplementedError; those paths are skipped on every later query.
    _not_implemented: set
    # The model object, created once by calling ``clazz()``.
    _instance: ModelType
    # Thread pool in which the sync inference function is run.
    _executor: ThreadPoolExecutor
    # Name given to the trace opened for each query.
    _trace_name: str
    _sync_inference_func: Callable
    _async_inference_func: Callable[..., Coroutine]

    def __init__(self, clazz: Type[ModelType], executor: ThreadPoolExecutor, config: dict, pluginConfig: dict, trace_name: str):
        """Instantiate the model and validate the inference callables.

        :param clazz: model class; instantiated with no arguments.
        :param executor: thread pool used to run the sync inference function.
        :param config: first argument passed to the model's ``set_config``.
        :param pluginConfig: second argument passed to the model's ``set_config``.
        :param trace_name: name of the trace opened for each query.
        :raises TypeError: if the async inference callable is not a coroutine function, or the
            sync inference callable is not callable or is a coroutine/generator function.
        """
        self._not_implemented = set()
        self._instance = clazz()
        self._executor = executor
        self._trace_name = trace_name
        try:
            self._instance.set_config(config, pluginConfig)
        except _NotImplementedError:
            # Models are allowed not to implement set_config.
            pass

        self._async_inference_func = self.get_async_inference_func()
        if not inspect.iscoroutinefunction(self._async_inference_func):
            raise TypeError(f"'{self._describe_func(self._async_inference_func)}' should be a coroutine function")

        self._sync_inference_func = self.get_sync_inference_func()
        if (
            (not callable(self._sync_inference_func))
            or inspect.iscoroutinefunction(self._sync_inference_func)
            or inspect.isgeneratorfunction(self._sync_inference_func)
            or inspect.isasyncgenfunction(self._sync_inference_func)
        ):
            raise TypeError(f"'{self._describe_func(self._sync_inference_func)}' should be a sync function")

    @staticmethod
    def _describe_func(func: Any) -> str:
        """Return ``func.__name__`` for messages, or ``repr(func)`` when it has no name.

        Objects such as ``functools.partial`` instances are callable but lack ``__name__``;
        reading it directly would replace the intended error with an ``AttributeError``.
        """
        return getattr(func, "__name__", repr(func))

    @abstractmethod
    def get_inference_params(self, command: InputType) -> dict:
        """Convert ``command`` into keyword arguments for the inference callables."""
        raise _NotImplementedError

    @abstractmethod
    def get_sync_inference_func(self) -> Callable:
        """Return the plain (non-coroutine, non-generator) inference callable."""
        raise _NotImplementedError

    @abstractmethod
    def get_async_inference_func(self) -> Callable[..., Coroutine]:
        """Return the coroutine-function inference callable."""
        raise _NotImplementedError

    @abstractmethod
    def parse_raw_response(self, raw_response: Any) -> OutputType:
        """Convert the value returned by an inference callable into the output dict."""
        raise _NotImplementedError

    async def process_query(self, command: InputType) -> OutputType:
        """Run one non-streaming query inside a new trace and attach that trace to the result.

        :param command: query, converted to keyword arguments by ``get_inference_params``.
        :returns: the parsed output with ``"trace"`` set to the serialized trace.
        :raises Exception: whatever inference or parsing raises; the trace is closed first.
        """
        inference_params = self.get_inference_params(command)

        trace = new_trace(self._trace_name)
        trace.__enter__()
        try:
            result = await self._aprocess(inference_params, trace)
            logger.info("Returning result of type: %s", type(result))
        except BaseException as exc:
            # Close the span on failure too (including cancellation), forwarding the exception
            # details as a `with` statement would. Re-raise unconditionally so the error is
            # never swallowed, whatever __exit__ returns.
            trace.__exit__(type(exc), exc, exc.__traceback__)
            raise
        trace.__exit__(None, None, None)

        result["trace"] = cast(LLMMeshTraceSpan, trace.to_dict())
        return result

    async def process_query_stream(self, process_command: Any) -> AsyncIterator[Any]:
        """Streaming is not supported here: raises ``_NotSupportedError`` on first iteration."""
        raise _NotSupportedError
        # small trick to stop the type checker from annoying us, the yield statement is mandatory
        # see https://github.com/microsoft/pyright/issues/9949 for explanation
        yield

    async def _aprocess(self, inference_params: dict, trace: SpanBuilder) -> OutputType:
        """Call the async inference function, falling back to the sync one, and parse the result.

        ``inference_params`` is mutated: ``trace`` is added under the ``"trace"`` key.
        A path that raises ``_NotImplementedError`` is remembered in ``_not_implemented`` and
        skipped on later calls.

        :raises Exception: if neither inference path is implemented.
        """
        inference_params["trace"] = trace

        if "async_inference" not in self._not_implemented:
            try:
                raw_response = await self._async_inference_func(**inference_params)
            except _NotImplementedError:
                self._not_implemented.add("async_inference")
            else:
                # Parse outside the try: a _NotImplementedError raised while parsing must not
                # be mistaken for a missing inference implementation, which would re-run the
                # inference through the sync path.
                return self.parse_raw_response(raw_response)

        if "sync_inference" not in self._not_implemented:
            loop = asyncio.get_running_loop()
            try:
                # run_in_executor takes no keyword arguments, hence the lambda.
                raw_response = await loop.run_in_executor(
                    self._executor, lambda: self._sync_inference_func(**inference_params)
                )
            except _NotImplementedError:
                self._not_implemented.add("sync_inference")
            else:
                return self.parse_raw_response(raw_response)

        raise Exception(f"The {self._instance.__class__.__name__} class should implement at least one of '{self._describe_func(self._async_inference_func)}' (async, non-stream), '{self._describe_func(self._sync_inference_func)}' (sync, non-stream).")



StreamInputType = TypeVar("StreamInputType")  # command accepted by a streaming query
StreamOutputType = TypeVar("StreamOutputType")  # type of each item yielded while streaming

class BaseStreamProcessor(BaseProcessor[ModelType, InputType, OutputType], Generic[ModelType, InputType, OutputType, StreamInputType, StreamOutputType]):
    """Processor that also exposes a streaming query entry point.

    Adds two type parameters for the streaming path: ``StreamInputType`` (the command passed
    to ``process_query_stream``) and ``StreamOutputType`` (the items it yields). Concrete
    subclasses must override ``process_query_stream``.
    """

    @abstractmethod
    async def process_query_stream(self, process_command: StreamInputType) -> AsyncIterator[StreamOutputType]:
        """Stream the response to ``process_command`` as an async iterator.

        The default body raises ``_NotImplementedError`` as soon as iteration starts.
        """
        # small trick to stop the type checker from annoying us, the yield statement is mandatory
        # see https://github.com/microsoft/pyright/issues/9949 for explanation
        # (the yield makes this an async generator; raise_not_implemeted_error() always raises,
        # so no value is ever actually yielded)
        yield raise_not_implemeted_error()
