import asyncio
import inspect
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncIterator, Callable, Coroutine, Type

from dataiku.base.async_generator_wrapper import generate_in_executor
from dataiku.llm.python import BaseLLM
from dataiku.llm.python.processing import _NotImplementedError
from dataiku.llm.python.processing.base_processor import BaseStreamProcessor
from dataiku.llm.python.types import CompletionResponse, CompletionSettings, ProcessSinglePromptCommand, SimpleCompletionResponse, SingleCompletionQuery, StreamResponseChunkOrFooter
from dataiku.llm.python.utils import parse_single_completion_response, parse_stream_chunk_or_footer, single_response_from_stream, stream_from_single_response
from dataiku.llm.tracing import new_trace, SpanBuilder


# Module-level logger named after this module; used for trajectory-extraction warnings and stream diagnostics.
logger = logging.getLogger(__name__)


class CompletionProcessor(BaseStreamProcessor[BaseLLM, ProcessSinglePromptCommand, SimpleCompletionResponse, ProcessSinglePromptCommand, StreamResponseChunkOrFooter]):
    """Run single-prompt completion commands against a user-provided ``BaseLLM`` subclass.

    The user class may implement any one of ``aprocess`` (async, non-stream), ``process`` (sync,
    non-stream), ``aprocess_stream`` (async, stream) or ``process_stream`` (sync, stream).
    ``_aprocess`` and ``_aprocess_stream`` try them in turn and adapt whichever one is implemented
    to the requested mode (a single response, or a stream of chunks ending with a footer).
    """

    # Names of user methods that raised _NotImplementedError; they are skipped on later calls.
    # NOTE(review): only annotated here, presumably initialized by BaseStreamProcessor -- confirm.
    _not_implemented: set

    def __init__(self, clazz: Type[BaseLLM], executor: ThreadPoolExecutor, config: dict, pluginConfig: dict, trace_name: str):
        super().__init__(clazz=clazz, executor=executor, config=config, pluginConfig=pluginConfig, trace_name=trace_name)

    def get_inference_params(self, command: ProcessSinglePromptCommand) -> dict:
        """Extract from *command* the keyword arguments given to the user LLM methods.

        :returns: ``{"query": ..., "settings": ...}``
        :raises Exception: if ``query`` or ``settings`` is missing (or ``None``) in *command*.
        """
        query = command.get("query", None)
        if query is None:
            raise Exception(f"'query' missing from command {command}")
        settings = command.get("settings", None)
        if settings is None:
            raise Exception(f"'settings' missing from command {command}")
        return {"query": query, "settings": settings}

    def get_async_inference_func(self) -> Callable[..., Coroutine]:
        """Return the user instance's bound ``aprocess`` method."""
        return self._instance.aprocess

    def get_sync_inference_func(self) -> Callable:
        """Return the user instance's bound ``process`` method."""
        return self._instance.process

    def parse_raw_response(self, raw_response: CompletionResponse) -> SimpleCompletionResponse:
        """Normalize a raw user response with ``parse_single_completion_response``."""
        return parse_single_completion_response(response=raw_response)

    async def process_query(self, command: ProcessSinglePromptCommand) -> SimpleCompletionResponse:
        """Run a non-streamed completion and, when requested, attach the agent trajectory.

        If ``settings["outputTrajectory"]`` is truthy, the tool calls extracted from the response
        trace are stored under ``result["additionalInformation"]["trajectory"]``.
        """
        result = await super().process_query(command)
        settings = self.get_inference_params(command)["settings"]
        if settings.get('outputTrajectory', False):
            # Only read the trace when it is needed; a missing trace yields an empty trajectory.
            trajectory = CompletionProcessor.to_trajectory(result.get("trace") or {})
            result.setdefault('additionalInformation', {})['trajectory'] = trajectory
        return result

    @staticmethod
    def to_trajectory(trace: dict) -> list[dict]:
        """Return the list of tool calls found in a root ``DKU_AGENT_CALL`` trace ([] if none recognized)."""
        # We match exactly the sequence of spans from the root level. Not very flexible, but safer
        # TODO: we might want to be more lax and look for any tool call in the tree, possibly restricting ourselves to the first level. see SC-256561 #NOSONAR
        if trace.get('name', "") != 'DKU_AGENT_CALL':
            return []
        if CompletionProcessor._get_first_child(trace, 'LangGraph'):
            # looks like a langgraph create_react_agent or ToolsUsingAgent (dku visual agent) call
            return CompletionProcessor.extract_visual_agent_trajectory(trace)
        if CompletionProcessor._get_first_child(trace, 'DKU_AGENT_ITERATION'):
            # looks like a Dataiku visual agent v2 call
            return CompletionProcessor.extract_visual_agent_v2_trajectory(trace)
        # An 'AgentExecutor' child looks like a code agent with langchain's create_openai_tools_agent call
        # TODO: handle CodeAgent/LangChain see SC-256561 #NOSONAR
        return []

    @staticmethod
    def _get_first_child(root: dict, child_name: str) -> dict:
        """Return the first direct child span of *root* named *child_name*, or ``{}`` if there is none."""
        return next(
            (child for child in (root.get('children') or []) if child.get('name') == child_name),
            {},
        )

    @staticmethod
    def _to_tool_call(span: dict, call_name: str, attributes: dict) -> dict:
        """Build one trajectory record from a tool-call span (timing, inputs and outputs default to '')."""
        return {
            'call_name': call_name,  # tool_name + _ + random_id
            'begin': span.get('begin', ''),
            'end': span.get('end', ''),
            'duration': span.get('duration', ''),
            'inputs': span.get('inputs', ''),
            'outputs': span.get('outputs', ''),
            'attributes': attributes,
        }

    @staticmethod
    def _get_requested_tool_calls(runnable_sequence: dict) -> list:
        """Return (as a fresh list) the tool calls requested in a 'RunnableSequence' span's output.

        Reads ``outputs.output.tool_calls`` first and falls back to
        ``outputs.output.additional_kwargs.tool_calls``; ``None`` values are treated as absent.
        """
        output = (runnable_sequence.get('outputs') or {}).get('output') or {}
        requested_tools = list(output.get('tool_calls') or [])
        if not requested_tools:  # second best option, we get it from additional_kwargs
            requested_tools = list((output.get('additional_kwargs') or {}).get('tool_calls') or [])
        return requested_tools

    @staticmethod
    def _pop_requested_tool(requested_tools: list, tool_name: str) -> bool:
        """Remove the first pending request for *tool_name* from *requested_tools*; return whether one was found."""
        for index, tool in enumerate(requested_tools):
            # second syntax is for the additional_kwargs version
            if tool.get('name', tool.get('function', {}).get('name', '')) == tool_name:
                del requested_tools[index]
                return True
        return False

    @staticmethod
    def extract_visual_agent_trajectory(dku_agent_call: dict) -> list[dict]:
        """Extract tool calls from a LangGraph-based agent trace.

        Walks the children of the first ``LangGraph`` span. ``agent`` spans set the pending list of
        requested tool calls; children of ``tools`` spans are reported only when they match (by
        name) a pending request, which is then consumed.
        """
        trajectory = []
        langgraph = CompletionProcessor._get_first_child(dku_agent_call, 'LangGraph')
        requested_tools = []
        for span in langgraph.get('children') or []:
            span_name = span.get('name', '')
            if span_name == 'agent':  # we expect alternation of 'agent' (requesting tool calls) and 'tools' (doing the calls)
                if requested_tools:
                    logger.warning("Agent requested some tools that were not called : %s", requested_tools)
                runnable_sequence = CompletionProcessor._get_first_child(span, 'RunnableSequence')
                requested_tools = CompletionProcessor._get_requested_tool_calls(runnable_sequence)
            elif span_name == 'tools':
                # actually not always a tool call. DKU starts with a dummy tu_toolcall, LangGraph sometimes adds a dummy _write, etc...
                for tool_span in span.get('children') or []:
                    tool_name = tool_span.get('name', '')
                    if not tool_name:
                        logger.warning("The called tool do not have a 'name' property, this should not happen: %s", tool_span)
                        continue
                    if not CompletionProcessor._pop_requested_tool(requested_tools, tool_name):
                        # not a requested tool call. TODO: Do we _know_ we can only have tools and tu_toolcall, _write ?
                        continue
                    dku_structured_tool = CompletionProcessor._get_first_child(tool_span, 'DKUStructuredTool')
                    dku_managed_tool_call = CompletionProcessor._get_first_child(dku_structured_tool, 'DKU_MANAGED_TOOL_CALL')
                    trajectory.append(CompletionProcessor._to_tool_call(tool_span, tool_name, dku_managed_tool_call.get('attributes', {})))
        return trajectory

    @staticmethod
    def extract_visual_agent_v2_trajectory(dku_agent_call: dict) -> list[dict]:
        """Extract tool calls from a Dataiku visual agent v2 trace.

        Every child of every direct child of *dku_agent_call* is inspected; each child of a
        ``DKU_AGENT_TOOL_CALLS`` span becomes one trajectory record.
        """
        trajectory = []

        # TODO @new-agentic-loop Parse trajectory for the new trace format, when we've settled on a new trace format - the below is just a placeholder

        for iteration_span in dku_agent_call.get('children') or []:
            for span in iteration_span.get('children') or []:
                # we expect alternation of 'DKU_AGENT_LLM_CALL' (requesting tool calls) and 'DKU_AGENT_TOOL_CALLS' (doing the calls)
                span_name = span.get('name', '')
                if span_name == 'DKU_AGENT_TOOL_CALLS':
                    trajectory.extend(
                        CompletionProcessor._to_tool_call(call, call.get('name', ''), call.get('attributes', {}))
                        for call in span.get('children') or []
                    )
                elif span_name != 'DKU_AGENT_LLM_CALL':
                    logger.warning("Unexpected trace node found when creating visual agent trajectory: %s", span_name)

        return trajectory

    async def process_query_stream(self, process_command: ProcessSinglePromptCommand) -> AsyncIterator[StreamResponseChunkOrFooter]:
        """Stream the completion for *process_command*, making sure a footer carries the trace.

        A new trace span (named ``self._trace_name``, with the user class name as ``class``
        attribute) wraps the user call. User footers without a ``trace`` get this span's trace; if
        the user emits no footer, a trailing ``{"footer": {"trace": ...}}`` is yielded. On error, a
        trace footer is yielded (unless one was already emitted) and the exception is re-raised.
        The span is exited exactly once.
        """
        query = process_command.get("query")
        settings = process_command.get("settings")

        trace = new_trace(self._trace_name)
        trace.attributes["class"] = type(self._instance).__name__
        trace.__enter__()

        footer_emitted = False
        trace_closed = False  # guards against exiting the span more than once
        emitted_chunks = 0

        try:
            async for resp in self._aprocess_stream(query, settings, trace):
                emitted_chunks += 1

                footer = resp.get("footer")
                if footer:
                    footer_emitted = True
                    if "trace" not in footer:
                        if not trace_closed:
                            trace.__exit__(None, None, None)
                            trace_closed = True
                        footer["trace"] = trace.to_dict()

                yield resp
        except Exception:
            if not trace_closed:
                trace.__exit__(None, None, None)
                trace_closed = True
            if not footer_emitted:
                yield {"footer": {"trace": trace.to_dict()}}
            raise

        logger.info("PQS: user's code function completed, footer_emitted=%s emitted_chunks=%s", footer_emitted, emitted_chunks)
        if not trace_closed:
            # Also covers a user footer that already carried its own trace: the span must still be closed.
            trace.__exit__(None, None, None)
        if not footer_emitted:
            yield {"footer": {"trace": trace.to_dict()}}

    def _check_process_is_sync(self) -> None:
        """Raise TypeError unless the user's ``process`` is a plain (sync, non-generator) callable."""
        process = self._instance.process
        if (
            not callable(process)
            or inspect.iscoroutinefunction(process)
            or inspect.isgeneratorfunction(process)
            or inspect.isasyncgenfunction(process)
        ):
            raise TypeError("'process' should be a sync function")

    async def _aprocess(self, inference_params: dict, trace: SpanBuilder) -> SimpleCompletionResponse:
        """Produce a single (non-streamed) response using whichever user method is implemented.

        Tries, in order: ``aprocess``, ``process`` (run in the executor), ``aprocess_stream`` and
        ``process_stream`` (both aggregated into a single response). A method raising
        ``_NotImplementedError`` is added to ``self._not_implemented`` and skipped from then on.
        Note that *inference_params* is updated in place with a ``trace`` entry.

        :raises TypeError: if a method that gets tried has the wrong kind (e.g. ``aprocess`` not async).
        :raises Exception: if none of the four methods is implemented.
        """
        inference_params["trace"] = trace
        if "aprocess" not in self._not_implemented:
            if not inspect.iscoroutinefunction(self._instance.aprocess):
                raise TypeError("'aprocess' should be a coroutine function")
            try:
                raw_response = await self._instance.aprocess(**inference_params)
                return self.parse_raw_response(raw_response)
            except _NotImplementedError:
                self._not_implemented.add("aprocess")

        if "process" not in self._not_implemented:
            self._check_process_is_sync()
            try:
                raw_response = await asyncio.get_running_loop().run_in_executor(self._executor, lambda: self._instance.process(**inference_params))
                return self.parse_raw_response(raw_response)
            except _NotImplementedError:
                self._not_implemented.add("process")

        if "aprocess_stream" not in self._not_implemented:
            if not inspect.isasyncgenfunction(self._instance.aprocess_stream):
                raise TypeError("'aprocess_stream' should be an async generator function")
            try:
                raw_stream = self._instance.aprocess_stream(**inference_params)
                return await single_response_from_stream(raw_stream)
            except _NotImplementedError:
                self._not_implemented.add("aprocess_stream")

        if "process_stream" not in self._not_implemented:
            if not inspect.isgeneratorfunction(self._instance.process_stream):
                raise TypeError("'process_stream' should be a generator function")
            try:
                raw_stream = generate_in_executor(self._instance.process_stream(**inference_params), self._executor)
                return await single_response_from_stream(raw_stream)
            except _NotImplementedError:
                self._not_implemented.add("process_stream")

        raise Exception("The LLM class should implement at least one of 'aprocess' (async, non-stream), 'process' (sync, non-stream), 'aprocess_stream' (async, stream), 'process_stream' (sync, stream)")

    async def _aprocess_stream(self, query: SingleCompletionQuery, settings: CompletionSettings, trace: SpanBuilder) -> AsyncIterator[StreamResponseChunkOrFooter]:
        """Stream chunks/footers using whichever user method is implemented.

        Tries, in order: ``aprocess_stream``, ``process_stream`` (iterated in the executor),
        ``aprocess`` and ``process`` (single responses converted to a stream). A method raising
        ``_NotImplementedError`` is added to ``self._not_implemented`` and skipped from then on.

        NOTE(review): if a user stream raises _NotImplementedError after yielding chunks, the next
        fallback streams too; presumably it is only raised before the first chunk -- confirm.

        :raises TypeError: if a method that gets tried has the wrong kind.
        :raises Exception: if none of the four methods is implemented.
        """
        if "aprocess_stream" not in self._not_implemented:
            if not inspect.isasyncgenfunction(self._instance.aprocess_stream):
                raise TypeError("'aprocess_stream' should be an async generator function")
            try:
                async for chunk_or_footer in self._instance.aprocess_stream(query, settings, trace):
                    yield parse_stream_chunk_or_footer(chunk_or_footer)
                return
            except _NotImplementedError:
                self._not_implemented.add("aprocess_stream")

        if "process_stream" not in self._not_implemented:
            if not inspect.isgeneratorfunction(self._instance.process_stream):
                raise TypeError("'process_stream' should be a generator function")
            try:
                async for chunk_or_footer in generate_in_executor(self._instance.process_stream(query, settings, trace), self._executor):
                    yield parse_stream_chunk_or_footer(chunk_or_footer)
                return
            except _NotImplementedError:
                self._not_implemented.add("process_stream")

        if "aprocess" not in self._not_implemented:
            if not inspect.iscoroutinefunction(self._instance.aprocess):
                raise TypeError("'aprocess' should be a coroutine function")
            try:
                raw_response = await self._instance.aprocess(query, settings, trace)
                async for chunk_or_footer in stream_from_single_response(raw_response):
                    yield chunk_or_footer
                return
            except _NotImplementedError:
                self._not_implemented.add("aprocess")

        if "process" not in self._not_implemented:
            self._check_process_is_sync()
            try:
                raw_response = await asyncio.get_running_loop().run_in_executor(self._executor, self._instance.process, query, settings, trace)
                async for chunk_or_footer in stream_from_single_response(raw_response):
                    yield chunk_or_footer
                return
            except _NotImplementedError:
                self._not_implemented.add("process")

        raise Exception("The LLM class should implement at least one of 'aprocess_stream' (async, stream), 'process_stream' (sync, stream), 'aprocess' (async, non-stream), 'process' (sync, non-stream)")
