import copy
import logging
from typing import List

from dataiku.llm.python.blocks_graph import NextBlock
from dataiku.llm.python.types import ChatMessage
from . import StreamableLLMBlock

logger = logging.getLogger("dku.agents.blocks_graph")

class LLMRequestBlockHandler(StreamableLLMBlock):
    """Blocks-graph block that sends one completion request to an LLM.

    The request is assembled from the block config:

    - an optional system prompt placed before the conversation history;
    - the conversation history, included unless "passConversationHistory" is false;
    - an optional system prompt placed after the history;
    - an optional additional user message;
    - the completion settings.

    The LLM output is optionally streamed back to the caller. Afterwards, control
    passes to the block named by the "nextBlock" config key.
    """

    def __init__(self, turn, sequence_context, block_config):
        super().__init__(turn, sequence_context, block_config)
        # Value reported as "agentLoopIteration" in the hierarchy entry prepended to
        # artifacts and sources. This class never changes it.
        self.iteration_number = 1

    def _prepend_agent_hierarchy(self, items):
        """Insert an "AGENT" entry at the front of each item's "hierarchy" list, in place.

        Items that have no "hierarchy" key get an empty list created first.
        """
        for item in items:
            hierarchy: List = item.setdefault("hierarchy", [])
            # TODO @lavish-agents: possibly also record "agentId" / "agentName" here
            hierarchy.insert(0, {"type": "AGENT", "agentLoopIteration": self.iteration_number})

    def process_stream(self, trace):
        """Run the LLM request and yield its streamed output, then the next block.

        Block config keys read:
            systemPromptBeforeHistory, systemPromptAfterHistory, additionalUserMessage:
                optional messages, interpolated through the CEL engine before being added.
            passConversationHistory (default True): include the turn's initial messages
                and the messages generated so far in this sequence.
            completionSettings (default {}): deep-copied into the completion's settings.
            streamOutput (default False): forward text chunks to the caller as they arrive.
            nextBlock (default None): id of the block to run afterwards.

        Agent config keys read:
            newLineAfterBlockOutput (default True): when streaming, emit a trailing
                "\\n" chunk after the block's output.

        Args:
            trace: span object whose ``subspan(...)`` context manager wraps the LLM call.

        Yields:
            ``{"chunk": ...}`` dicts, in this order:

            - an AGENT_THINKING event;
            - text chunks, only when streaming;
            - artifact chunks (artifacts carry a prepended AGENT hierarchy entry);
            - an optional trailing newline chunk.

            Finally yields a ``NextBlock``.

        Side effects:
            Sources returned by the LLM get an AGENT hierarchy entry prepended, and are
            appended to ``self.sequence_context.sources``. The accumulated text and any
            memory fragments are handed to ``self._handle_output_mode_at_end``.
        """
        # Deferred %-formatting: the config is only rendered if INFO is enabled.
        logger.info("LLM Request block starting with config %s", self.block_config)

        with trace.subspan("DKU_AGENT_LLM_CALL") as llm_trace:
            yield {"chunk": {"type": "event", "eventKind": "AGENT_THINKING", "eventData": {}}}

            completion = self.new_completion()
            cel_engine = self.standard_cel_engine()

            completion.with_context(self.turn.current_merged_context)

            self._interpolate_and_add_message(completion, "system", cel_engine,
                                              self.block_config.get("systemPromptBeforeHistory"))

            if self.block_config.get("passConversationHistory", True):
                completion.cq["messages"].extend(self.turn.initial_messages)
                completion.cq["messages"].extend(self.sequence_context.generated_messages)
            else:
                logger.debug("Not passing conversation history, the LLM will only act on interpolated prompts from the block's config")

            self._interpolate_and_add_message(completion, "system", cel_engine,
                                              self.block_config.get("systemPromptAfterHistory"))
            self._interpolate_and_add_message(completion, "user", cel_engine,
                                              self.block_config.get("additionalUserMessage"))

            # Deep copy so that later mutation of the completion's settings cannot
            # leak back into the shared block config.
            completion._settings = copy.deepcopy(self.block_config.get("completionSettings", {}))

            # Read once, with the same default as the trailing-newline check below.
            # Indexing ["streamOutput"] directly raised KeyError when the key was absent.
            stream_output = self.block_config.get("streamOutput", False)

            text_parts = []
            accumulated_memory_fragments = []
            for ichunk in self._run_completion(completion, llm_trace):
                if ichunk.text is not None:
                    if stream_output:
                        yield {"chunk":  {"text": ichunk.text}}
                    text_parts.append(ichunk.text)

                if ichunk.memory_fragment:
                    memory_fragment_msg: ChatMessage = {
                        "role": "memoryFragment",
                        "memoryFragment": ichunk.memory_fragment
                    }
                    accumulated_memory_fragments.append(memory_fragment_msg)

                if ichunk.artifacts:
                    artifacts = ichunk.artifacts
                    self._prepend_agent_hierarchy(artifacts)
                    yield {"chunk": {"artifacts": artifacts}}

                if ichunk.sources:
                    sources = ichunk.sources
                    self._prepend_agent_hierarchy(sources)
                    # Sources are collected on the sequence context rather than streamed.
                    self.sequence_context.sources.extend(sources)

            if stream_output and self.agent.config.get("newLineAfterBlockOutput", True):
                yield {"chunk":  {"text": "\n"}}

            self._handle_output_mode_at_end("".join(text_parts), accumulated_memory_fragments)

        yield NextBlock(id=self.block_config.get("nextBlock", None))