import logging

from dataiku.llm.python.blocks_graph import NextBlock, BlockHandler
from dataiku.llm.python.blocks_graph.utils import default_if_blank, interpolate_cel, tool_has_been_called
from dataiku.llm.python.utils import get_completion_query_safe_for_logging

# Module-level logger for this handler's routing decisions and LLM-call log lines.
logger = logging.getLogger(name="dku.agents.blocks_graph")

class RoutingBlockHandler(BlockHandler):
    """Block handler whose only job is to decide which block runs next.

    The routing strategy is selected by ``block_config["routingMode"]``:

    * ``"CLAUSES"``: ``clausesBasedDecisions`` is scanned in order and the
      ``nextBlock`` of the first entry whose ``clause`` matches is returned;
      otherwise ``defaultNextBlockIfNoClauseMatch`` is used.
    * ``"LLM_DISPATCH"``: the LLM configured in ``llmId`` is asked, and its
      answer is mapped to a next block through ``resultDispatch``.
    * ``"EXPRESSION_DISPATCH"``: the CEL ``expression`` is evaluated and its
      result is used as the next block.
    """

    # Prompt used by LLM_BASED clauses whose systemPromptBeforeHistory is blank.
    _DEFAULT_LLM_CLAUSE_PROMPT = (
        "You need to decide on what the next step of the process is. "
        "Answer with YES to go to the next step."
    )

    def __init__(self, turn, sequence_context, block_config):
        super().__init__(turn, sequence_context, block_config)

    def process_stream(self, trace):
        """Compute the routing decision and yield it as a single ``NextBlock``.

        :param trace: tracing span; LLM-based routing opens sub-spans on it.
        """
        logger.info("Routing block starting with config %s", self.block_config)
        next_block = self.get_next_block(trace)
        logger.info("Routing block next block is %s", next_block)
        yield NextBlock(id=next_block)

    def _clause_matches(self, clause, trace) -> bool:
        """Return whether a single routing clause holds for the current turn.

        Supported ``clause["type"]`` values: ``STATE_HAS_KEYS``,
        ``SCRATCHPAD_HAS_KEYS``, ``TOOLS_CALLED``, ``EXPRESSION`` (CEL only),
        ``AND`` / ``OR`` (recursive over ``clause["clauses"]``) and
        ``LLM_BASED``.

        :raises NotImplementedError: for an expression language other than CEL.
        :raises Exception: for an unknown clause type.
        """
        clause_type = clause["type"]
        if clause_type == "STATE_HAS_KEYS":
            # A key only counts as present when its value is not None.
            return all(self.turn.state_get(key, None) is not None for key in clause["stateKeys"])
        elif clause_type == "SCRATCHPAD_HAS_KEYS":
            scratchpad = self.sequence_context.scratchpad
            return all(scratchpad.get(key, None) is not None for key in clause["scratchpadKeys"])
        elif clause_type == "TOOLS_CALLED":
            return self._any_tool_called(clause["toolRefs"])
        elif clause_type == "EXPRESSION":
            expression = clause["expression"]
            if expression["language"] == "CEL":
                engine = self.standard_cel_engine()
                return bool(engine.evaluate(expression["expression"]))
            raise NotImplementedError("Expression language %s is not implemented yet" % expression["language"])
        elif clause_type == "AND":
            return all(self._clause_matches(inner_clause, trace) for inner_clause in clause["clauses"])
        elif clause_type == "OR":
            return any(self._clause_matches(inner_clause, trace) for inner_clause in clause["clauses"])
        elif clause_type == "LLM_BASED":
            return self._llm_wants_to_go_to_next_block(clause, trace)
        else:
            raise Exception("Unknown routing block clause: %s" % clause)

    def _any_tool_called(self, tool_refs) -> bool:
        """Return whether any LLM tool of the referenced agent tools was already called in this sequence."""
        # Every referenced tool is loaded up front, so a bad reference surfaces
        # as an error even when an earlier tool would have matched.
        prepared_tools = [
            self.turn.agent.load_or_get_tool({"toolRef": tool_ref.get("toolRef")})
            for tool_ref in tool_refs
        ]
        # NOTE(review): only the part of the LLM tool name before the first "__"
        # is compared; presumably that prefix identifies the tool — confirm.
        tool_names = [
            llm_tool.llm_tool_name.split("__", 1)[0]
            for prepared_tool in prepared_tools
            for llm_tool in prepared_tool.llm_tools
        ]
        messages = self.sequence_context.generated_messages
        return any(tool_has_been_called(messages, tool_name) for tool_name in tool_names)

    def get_next_block(self, trace):
        """Return the next block according to ``block_config["routingMode"]``.

        :param trace: tracing span, used by LLM-based routing.
        :raises ValueError: when the configured routing cannot produce a next block.
        :raises Exception: for an unknown routing mode.
        """
        routing_mode = self.block_config["routingMode"]
        if routing_mode == "CLAUSES":
            return self._route_by_clauses(trace)
        elif routing_mode == "LLM_DISPATCH":
            return self._route_by_llm_dispatch(trace)
        elif routing_mode == "EXPRESSION_DISPATCH":
            return self._route_by_expression()
        else:
            raise Exception("Unknown routing mode %s" % routing_mode)

    def _route_by_clauses(self, trace):
        """CLAUSES mode: next block of the first matching clause, else the configured default."""
        for cbd in self.block_config["clausesBasedDecisions"]:
            if self._clause_matches(cbd["clause"], trace):
                return cbd.get("nextBlock")

        default_next_block = self.block_config.get("defaultNextBlockIfNoClauseMatch")
        if default_next_block:
            return default_next_block
        raise ValueError("Routing failed on the block %s, no clause matched and no default next block" % self.block_config.get("id"))

    def _route_by_llm_dispatch(self, trace):
        """LLM_DISPATCH mode: ask the LLM and map its answer through ``resultDispatch``.

        :raises ValueError: if no LLM is configured or no dispatch key matches the answer.
        """
        logger.info("LLM_Based dispatch")

        llm = self._get_routing_llm("Please select a valid LLM on the block %s")
        completion = llm.new_completion()
        completion.with_context(self.turn.current_merged_context)

        system_prompt = default_if_blank(self.block_config.get("prompt"), None)
        if system_prompt is not None:
            system_prompt = interpolate_cel(self.standard_cel_engine(), system_prompt)
            logger.info("Interpolated system prompt: %s", system_prompt)
            completion.with_message(system_prompt, "system")

        completion.cq["messages"].extend(self.turn.initial_messages)
        completion.cq["messages"].extend(self.sequence_context.generated_messages)

        text_output = self._run_routing_completion(completion, trace)

        dispatches = self.block_config.get("resultDispatch") or []
        # An exact match takes priority; only then retry with surrounding
        # whitespace removed, since LLM answers often carry a trailing newline.
        candidates = [text_output]
        if isinstance(text_output, str) and text_output.strip() != text_output:
            candidates.append(text_output.strip())
        for candidate in candidates:
            for dispatch in dispatches:
                if dispatch["key"] == candidate:
                    return dispatch["value"]
        raise ValueError("Routing failed, no result dispatch matched the LLM answer: %s" % text_output)

    def _route_by_expression(self):
        """EXPRESSION_DISPATCH mode: the CEL expression's result is the next block."""
        engine = self.standard_cel_engine()
        expression = self.block_config["expression"]
        eval_result = engine.evaluate(expression)
        logger.info("Expression %s evaluated to %s", expression, eval_result)
        return eval_result

    def _get_routing_llm(self, missing_llm_message):
        """Return the LLM configured in ``llmId``.

        :param missing_llm_message: %-format string taking the block id, used
            for the ValueError raised when no LLM is selected.
        """
        llm_id = self.block_config.get("llmId")
        if llm_id is None or llm_id == "":
            raise ValueError(missing_llm_message % self.block_config.get("id"))
        # Reach the agent through the turn, like the tool loading does.
        return self.turn.agent.project.get_llm(llm_id)

    def _run_routing_completion(self, completion, trace):
        """Execute ``completion`` in a DKU_AGENT_LLM_CALL sub-span and return the response text."""
        with trace.subspan("DKU_AGENT_LLM_CALL") as llm_trace:
            logger.info("About to run completion: %s", get_completion_query_safe_for_logging(completion.cq))
            response = completion.execute()
            if response.trace:
                llm_trace.append_trace(response.trace)
            return response.text

    def _llm_wants_to_go_to_next_block(self, clause, trace) -> bool:
        """Ask the block's LLM whether an LLM_BASED ``clause`` holds.

        Message order: system prompt before history, then (unless
        ``passConversationHistory`` is false) the conversation history, then
        the optional system prompt after history and additional user message.
        All prompts are CEL-interpolated.

        :returns: True when the answer, stripped and lower-cased, is "yes" or "true".
        :raises ValueError: if no LLM is configured on the block.
        """
        llm = self._get_routing_llm("Please select a valid LLM on the block %s in order to use an LLM clause")
        completion = llm.new_completion()
        completion.with_context(self.turn.current_merged_context)
        cel_engine = self.standard_cel_engine()

        system_prompt_before_history = default_if_blank(clause.get("systemPromptBeforeHistory"), self._DEFAULT_LLM_CLAUSE_PROMPT)
        system_prompt_before_history = interpolate_cel(cel_engine, system_prompt_before_history)
        logger.info("Interpolated system prompt before history: %s", system_prompt_before_history)
        completion.with_message(system_prompt_before_history, "system")

        if clause.get("passConversationHistory", True):
            completion.cq["messages"].extend(self.turn.initial_messages)
            completion.cq["messages"].extend(self.sequence_context.generated_messages)
        else:
            logger.debug("Not passing conversation history, the LLM will only act on interpolated prompts from the block's config")

        system_prompt_after_history = default_if_blank(clause.get("systemPromptAfterHistory"), None)
        if system_prompt_after_history is not None:
            system_prompt_after_history = interpolate_cel(cel_engine, system_prompt_after_history)
            logger.info("Interpolated system prompt after history: %s", system_prompt_after_history)
            completion.with_message(system_prompt_after_history, "system")

        additional_user_message = default_if_blank(clause.get("additionalUserMessage"), None)
        if additional_user_message is not None:
            additional_user_message = interpolate_cel(cel_engine, additional_user_message)
            logger.info("Interpolated additional user message: %s", additional_user_message)
            completion.with_message(additional_user_message, "user")

        answer = self._run_routing_completion(completion, trace)
        # A missing answer (None) is treated as "no" instead of crashing on .strip().
        return (answer or "").strip().lower() in ("yes", "true")
