from typing import List

from dataiku.llm.python.blocks_graph import BlockHandler
import json, logging

from dataiku.llm.python.blocks_graph.microcel import MicroCelEngine
from dataiku.llm.python.tools_using_2 import InternalCompletionChunk
from dataiku.llm.python.utils import get_completion_query_safe_for_logging
from dataikuapi.dss.llm import DSSLLMStreamedCompletionFooter, DSSLLMCompletionQuery
from dataikuapi.dss.llm_tracing import SpanBuilder

# Module logger. It uses a fixed logger name rather than __name__, presumably so
# that logging configuration keyed on "dku.agents.blocks_graph" applies to this
# module. NOTE(review): confirm before renaming.
logger = logging.getLogger("dku.agents.blocks_graph")

class StreamableLLMBlock(BlockHandler):
    """Block handler with shared helpers for blocks that call a project LLM in streaming mode.

    Helpers cover:

    - building a completion query for the block's configured LLM;
    - streaming it into ``InternalCompletionChunk`` objects;
    - interpolating CEL templates into messages;
    - routing the final text output according to the block's ``outputMode``.

    Presumably subclassed by concrete LLM blocks. NOTE(review): confirm.
    """

    def __init__(self, turn, sequence_context, block_config):
        super().__init__(turn, sequence_context, block_config)

    def new_completion(self) -> DSSLLMCompletionQuery:
        """Create a new completion query on the project LLM named by the block's ``llmId``.

        :raises ValueError: if ``llmId`` is missing or empty in the block config.
        """
        llm_id = self.block_config.get("llmId")
        if not llm_id:
            # .get("id") so a config without an id still produces this ValueError
            # rather than a KeyError that hides the real problem.
            raise ValueError("Please select a valid LLM on the block %s" % self.block_config.get("id"))

        llm = self.agent.project.get_llm(llm_id)
        return llm.new_completion()

    def _run_completion(self, completion, trace: SpanBuilder):
        """Execute *completion* in streaming mode and yield one chunk per raw chunk.

        Each raw chunk is converted into an ``InternalCompletionChunk`` with these fields:

        - text, tool-call chunks, artifacts;
        - sources (from ``additionalInformation.sources``);
        - memory fragment.

        When the chunk is a ``DSSLLMStreamedCompletionFooter``, its trace is
        appended to *trace* before the chunk is yielded.
        """
        # Building the sanitized query representation may be costly, so only
        # do it when debug logging is actually enabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"About to run completion: {get_completion_query_safe_for_logging(completion.cq)}")
            logger.debug(f"With settings: {completion.settings}")

        for raw_chunk in completion.execute_streamed():
            # Lazy %-args: avoids formatting every streamed chunk when debug is off.
            logger.debug("Got raw chunk: %s", raw_chunk)
            data = raw_chunk.data
            # `or` defaults also cover keys that are present with an explicit null
            # value; a null "additionalInformation" would otherwise raise AttributeError.
            text = data.get("text") or ""
            tool_call_chunks = data.get("toolCalls") or []
            artifacts = data.get("artifacts") or []
            sources = (data.get("additionalInformation") or {}).get("sources") or []
            memory_fragment = data.get("memoryFragment")
            if isinstance(raw_chunk, DSSLLMStreamedCompletionFooter):
                trace.append_trace(raw_chunk.trace)
            yield InternalCompletionChunk(text, tool_call_chunks, artifacts, sources, memory_fragment)

        logger.debug("completion done")

    def _interpolate_and_add_message(self, completion, role: str, cel_engine: MicroCelEngine, message: str):
        """Interpolate CEL expressions in *message* and add it to *completion* with *role*.

        The message is skipped when ``default_if_blank`` maps it to ``None``.
        """
        # Local import, presumably to avoid an import cycle with the blocks_graph
        # package. NOTE(review): confirm before hoisting.
        from dataiku.llm.python.blocks_graph.utils import interpolate_cel, default_if_blank

        message = default_if_blank(message, None)
        if message is None:
            return
        message = interpolate_cel(cel_engine, message)
        logger.info(f"Interpolated {role} message: {message}")
        completion.with_message(message, role)

    @staticmethod
    def _parse_json_output(text):
        """Best-effort JSON parse of *text*.

        :returns: ``(True, value)`` when *text* is valid JSON, ``(False, None)`` otherwise.
        """
        try:
            return True, json.loads(text)
        except (TypeError, ValueError, RecursionError):
            # TypeError: not a str/bytes (e.g. None).
            # ValueError: includes json.JSONDecodeError.
            # RecursionError: pathologically nested model output.
            return False, None

    def _structured_output(self, accumulated_text_output):
        """Return the parsed JSON value of the output, or the raw text when it is not valid JSON.

        An explicit success flag is used instead of truthiness. That way, valid
        but falsy JSON values (``0``, ``false``, ``[]``, ``{}``, ``null``) are kept
        as parsed values instead of being replaced by their text.
        """
        is_json, value = self._parse_json_output(accumulated_text_output)
        return value if is_json else accumulated_text_output

    def _handle_output_mode_at_end(self, accumulated_text_output: str, block_turn_generated_messages: List):
        """Store the block's final text output according to ``block_config["outputMode"]``.

        The text is always recorded in ``sequence_context.last_text_output``. Then,
        per output mode:

        - ``NONE``: nothing else is stored.
        - ``ADD_TO_MESSAGES``: *block_turn_generated_messages* are appended to
          ``sequence_context.generated_messages``, followed by an assistant message
          holding the text (only if the text is non-empty).
        - ``SAVE_TO_STATE``: the output is stored in turn state under
          ``outputStateKey``. It is stored as the parsed JSON value when the text
          is valid JSON, otherwise as the raw text.
        - ``SAVE_TO_SCRATCHPAD``: same as ``SAVE_TO_STATE``, but stored in
          ``sequence_context.scratchpad`` under ``outputScratchpadKey``.

        :raises ValueError: for any other (or missing) output mode.
        """
        self.sequence_context.last_text_output = accumulated_text_output
        logger.info("Block response stored in last_text_output")

        output_mode = self.block_config.get("outputMode", None)
        if output_mode == "NONE":
            pass  # Do not store the output anywhere, just keep it in last_text_output
        elif output_mode == "ADD_TO_MESSAGES":
            self.sequence_context.generated_messages.extend(block_turn_generated_messages)
            if accumulated_text_output:
                assistant_message = {
                    "role": "assistant",
                    "content": accumulated_text_output
                }
                self.sequence_context.generated_messages.append(assistant_message)
            logger.info("Block response added to generated messages")
        elif output_mode == "SAVE_TO_STATE":
            output_key = self.block_config["outputStateKey"]
            self.turn.state_set(output_key, self._structured_output(accumulated_text_output))
            logger.info(f"Block response saved to state key '{output_key}'")
        elif output_mode == "SAVE_TO_SCRATCHPAD":
            output_key = self.block_config["outputScratchpadKey"]
            self.sequence_context.scratchpad[output_key] = self._structured_output(accumulated_text_output)
            logger.info(f"Block response saved to scratchpad key '{output_key}'")
        else:
            # .get("type") so a config without a type still raises this ValueError.
            raise ValueError(f"Unsupported output mode for {self.block_config.get('type')} block: {output_mode}")
