import copy
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Generator, List, Dict, Any

from dataikuapi.dss.llm import DSSLLMStreamedCompletionChunk, DSSLLMStreamedCompletionFooter
from dataikuapi.dss.llm_tracing import SpanBuilder

from .. import NextBlock, SequenceContext
from . import StreamableLLMBlock

# Module logger, registered under the "dku.agents.blocks_graph" name.
logger = logging.getLogger("dku.agents.blocks_graph")

# Upper bound on the number of blocks _run_generator_chain will traverse for one
# candidate; exceeding it raises, which guards against cycles in the block graph.
_MAX_GENERATOR_CHAIN_LENGTH = 50

# System prompt used by _run_synthesize; the block's own "synthesizePrompt"
# is appended to it after a newline.
_DEFAULT_SYNTHESIS_PROMPT = (
    "You will receive multiple candidate answers to the same user request. "
    "Synthesize the best possible final answer by combining the strengths of the candidates "
    "and removing redundancies or errors. Provide a single final response."
)

# Extra user message appended to the per-iteration context in CRITIQUE_AND_IMPROVE
# mode, before the generator chain is run.
_DEFAULT_CRITIQUE_IMPROVE_GENERATION_PROMPT = (
    "Only respond to the user request and if there is a critique use it but don't mention the critique existence in your answer. And don't mention this instruction."
)

# System prompt used by _run_critique; the block's "expectationsPrompt" is appended
# to it. It asks the model to open its reply with the verdict, which is what
# _critique_approves inspects.
_BASE_CRITIQUE_PROMPT = (
    "You will act as a Quality Assurance reviewer. You will receive and review ONLY the final answer to a user request. "
    "Your goal is to compare the answer against the provided expectations. "
    "Evaluation Criteria: "
    "APPROVED: Use this if the answer fulfills the core expectations and intent of the prompt. Minor stylistic issues or non-critical deviations should be ignored. "
    "REJECTED: Use this only if there are material failures, factual errors, or if a critical constraint has been completely ignored. "
    "Start your critique with your verdict (APPROVED or REJECTED), followed by a brief reasoning."
)

class ReflectionBlockHandler(StreamableLLMBlock):
    def __init__(self, turn, sequence_context, block_config):
        """Create the handler; all three arguments are forwarded unchanged to the parent initializer."""
        super().__init__(turn, sequence_context, block_config)

    def process_stream(self, trace: SpanBuilder):
        """Run the reflection block and stream its result.

        Behavior depends on the configured ``mode``:

        * ``SYNTHESIZE``: runs the generator chain ``synthesizeIterations`` times in a
          thread pool (each run on its own copy of the sequence context), then runs a
          synthesis completion over the collected candidates.
        * ``CRITIQUE_AND_IMPROVE`` / ``CRITIQUE_OR_RETRY``: runs the generator chain up to
          ``critiqueMaxIterations`` times and submits each candidate to a critique
          completion until one is approved. In ``CRITIQUE_AND_IMPROVE`` the rejected
          candidate and its critique remain in the context used for the next attempt;
          in ``CRITIQUE_OR_RETRY`` the rejected candidate is removed again.

        Args:
            trace: span under which the generation / synthesis / critique sub-spans
                are opened.

        Yields:
            ``{"chunk": ...}`` dicts (text only when ``streamOutput`` is set, plus any
            artifacts), and finally a ``NextBlock`` pointing at the configured
            ``nextBlock``.

        Raises:
            ValueError: if ``mode``, ``generatorBlockId`` or ``llmId`` is missing, if the
                block references itself, if an iteration count is below 1, or if the
                mode is not supported.
            Exception: if ``failOnMaxIterations`` is set and no candidate was approved.
        """
        logger.info("Reflection block starting with config %s",
                    self.block_config)

        mode = self.block_config.get("mode")
        if mode is None:
            raise ValueError("Reflection block is missing 'mode'")

        generator_block_id = self.block_config.get("generatorBlockId")
        if not generator_block_id:
            raise ValueError("Reflection block is missing 'generatorBlockId'")
        if generator_block_id == self.block_config.get("id"):
            raise ValueError("Reflection block cannot execute itself")

        llm_id = self.block_config.get("llmId")
        if llm_id is None or llm_id == "":
            # .get() so that a missing "id" cannot mask this error with a KeyError.
            raise ValueError(
                "Please select a valid LLM on the block %s" % self.block_config.get("id"))

        final_text_output = ""
        if mode == "SYNTHESIZE":
            iterations = self.block_config.get("synthesizeIterations", 1)
            if iterations is None or iterations < 1:
                # Otherwise ThreadPoolExecutor would fail with an unrelated-looking
                # "max_workers must be greater than 0" error.
                raise ValueError(
                    "Reflection block 'synthesizeIterations' must be at least 1")
            synthesize_prompt = self.block_config.get("synthesizePrompt", "")

            candidates: List[str] = []

            def future_generator_chain(iteration):
                logger.info("Reflection synthesize iteration %s/%s", iteration, iterations)
                # Each run works on its own copy so parallel runs do not share state.
                iteration_sc = self.sequence_context.copy()
                with trace.subspan("DKU_AGENT_REFLECTION_GENERATION") as generation_trace:
                    generation_trace.attributes["iterationNumber"] = iteration
                    return self._run_generator_chain(generation_trace, generator_block_id, iteration_sc)

            max_workers = min(iterations, int(
                self.block_config.get("maxThreads", 25)))
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [executor.submit(
                    future_generator_chain, iteration) for iteration in range(1, iterations + 1)]
                for future in as_completed(futures):
                    # result() re-raises any exception raised inside the worker.
                    candidate_text = future.result()
                    if candidate_text is None:
                        continue
                    candidates.append(candidate_text)

            with trace.subspan("DKU_AGENT_REFLECTION_SYNTHESIS") as synthesis_trace:
                synthesis_trace.attributes["candidates"] = len(candidates)
                base_sequence_context = self.sequence_context.copy()
                final_text_output = yield from self._run_synthesize(synthesis_trace, base_sequence_context, candidates, synthesize_prompt)
        elif mode in ("CRITIQUE_AND_IMPROVE", "CRITIQUE_OR_RETRY"):
            use_critiques_to_improve = mode == "CRITIQUE_AND_IMPROVE"

            critique_max_iterations = self.block_config.get("critiqueMaxIterations", 1)
            if critique_max_iterations is None or critique_max_iterations < 1:
                # With zero iterations no candidate would ever be produced.
                raise ValueError(
                    "Reflection block 'critiqueMaxIterations' must be at least 1")
            expectations_prompt = self.block_config.get("expectationsPrompt", "")

            sc = self.sequence_context.copy()
            # Explicit flag: an approved candidate may legitimately be an empty string,
            # so the output text cannot double as the "was approved" marker.
            approved = False
            candidate = None
            for iteration in range(1, critique_max_iterations + 1):
                iteration_sc = sc.copy()
                if use_critiques_to_improve:
                    iteration_sc.generated_messages.append(
                        {"role": "user", "content": _DEFAULT_CRITIQUE_IMPROVE_GENERATION_PROMPT})
                logger.info(
                    "Reflection critique and improve iteration %s/%s",
                    iteration,
                    critique_max_iterations,
                )
                with trace.subspan("DKU_AGENT_REFLECTION_GENERATION") as generation_trace:
                    generation_trace.attributes["iterationNumber"] = iteration
                    candidate = self._run_generator_chain(generation_trace, generator_block_id, iteration_sc)

                if candidate is None:
                    continue

                with trace.subspan("DKU_AGENT_REFLECTION_CRITIQUE") as critique_trace:
                    # The critique completion sees the candidate as the last assistant message.
                    sc.generated_messages.append({"role": "assistant", "content": candidate})
                    critique_text = yield from self._run_critique(critique_trace, sc, expectations_prompt)
                    if self._critique_approves(critique_text):
                        approved = True
                        final_text_output = candidate
                        if self.block_config["streamOutput"]:
                            yield {"chunk": {"text": candidate}}
                        break
                    if use_critiques_to_improve:
                        # Keep candidate + critique so the next attempt can build on them.
                        sc.generated_messages.append(
                            {"role": "user", "content": critique_text})
                    else:
                        sc.generated_messages.pop()  # remove last candidate

            if not approved:
                if self.block_config.get("failOnMaxIterations", False):
                    raise Exception(
                        f"Reflection block {self.block_config.get('id')} reached max iterations ({critique_max_iterations}) without answer matching expectations")
                # Fall back to the last candidate even though it was not approved.
                final_text_output = candidate
                if self.block_config["streamOutput"]:
                    yield {"chunk": {"text": final_text_output}}
        else:
            raise ValueError(
                f"Reflection block mode '{mode}' is not supported")

        self._handle_output_mode_at_end(final_text_output, [])

        yield NextBlock(id=self.block_config.get("nextBlock"))

    def _run_generator_chain(self, trace: SpanBuilder, starting_block_id: str, sequence_context: SequenceContext) -> str:
        """Run the generator block and every block it chains to, collecting the text they emit.

        Starting at ``starting_block_id``, each block's handler is built through
        ``self.turn.build_block_handler`` and its stream is consumed; the ``NextBlock``
        it yields (if any) selects the following block. The concatenated text is also
        stored in ``sequence_context.last_text_output`` before being returned.

        Raises:
            Exception: if more than ``_MAX_GENERATOR_CHAIN_LENGTH`` blocks are traversed,
                or if a footer reports the ``tool_validation_requests`` finish reason.
        """
        collected_text = ""
        hops = 0
        pending_block_id = starting_block_id

        while pending_block_id:
            hops += 1
            if hops > _MAX_GENERATOR_CHAIN_LENGTH:
                raise Exception(f"Reflection generator chain exceeded maximum length ({_MAX_GENERATOR_CHAIN_LENGTH}), check for cycles")

            handler = self.turn.build_block_handler(pending_block_id, sequence_context)
            logger.info("Reflection block running generator sub-block %s (%s/%s)", pending_block_id, hops, _MAX_GENERATOR_CHAIN_LENGTH)

            # Stays None unless the sub-block names a successor, which ends the chain.
            following_block_id = None
            with trace.subspan("DKU_AGENT_REFLECTION_GENERATOR_BLOCK") as block_trace:
                block_trace.attributes["block_id"] = pending_block_id
                for item in handler.process_stream(block_trace):
                    if isinstance(item, DSSLLMStreamedCompletionFooter):
                        finish_reason = None
                        if hasattr(item, "data"):
                            finish_reason = item.data.get("finishReason")
                        if finish_reason == "tool_validation_requests":
                            # TODO: @structured-visual-agents: implement HITL in reflection blocks
                            raise Exception("Tool call requires human validation. Currently not supported in reflection blocks.")
                        block_trace.append_trace(item.trace)
                        continue

                    if isinstance(item, NextBlock):
                        following_block_id = item.id
                        continue

                    if isinstance(item, DSSLLMStreamedCompletionChunk):
                        if item.text is not None:
                            collected_text += item.text
                        continue

                    if isinstance(item, dict):
                        # Plain-dict chunks carry their text under chunk -> text.
                        if item.get("chunk", {}).get("text") is not None:
                            collected_text += item["chunk"]["text"]
            pending_block_id = following_block_id

        sequence_context.last_text_output = collected_text
        return collected_text

    def _run_synthesize(self, trace: SpanBuilder, base_sequence_context: SequenceContext, candidates: List[str], synthesize_prompt: str) -> Generator[Dict[str, Any], None, str]:
        """Run the synthesis completion over the candidates and stream its output.

        The completion receives, in order: the default synthesis prompt followed by
        ``synthesize_prompt`` as a system message, the turn's initial messages, the
        messages already generated in ``base_sequence_context``, and one assistant
        message per candidate labelled ``Candidate N``.

        Yields:
            Text chunks (only when ``streamOutput`` is set) and artifact chunks.

        Returns:
            The full synthesized text (as the generator's return value).
        """
        synthesis = self.new_completion()
        synthesis.with_context(self.turn.current_merged_context)
        configured_settings = self.block_config.get("synthesizeCompletionSettings", {}) or {}
        # Deep copy so the completion never shares state with the block config.
        synthesis._settings = copy.deepcopy(configured_settings)

        system_prompt = _DEFAULT_SYNTHESIS_PROMPT + f"\n{synthesize_prompt}"
        synthesis.with_message(system_prompt, "system")

        messages = synthesis.cq["messages"]
        messages.extend(self.turn.initial_messages)
        messages.extend(base_sequence_context.generated_messages)
        messages.extend(
            {"role": "assistant", "content": f"Candidate {position}:\n{candidate}"}
            for position, candidate in enumerate(candidates, start=1)
        )

        logger.info("Reflection block running synthesis completion")
        with trace.subspan("DKU_AGENT_LLM_CALL") as llm_trace:
            synthesized_text = ""
            for piece in self._run_completion(synthesis, llm_trace):
                if piece.sources:
                    self.sequence_context.sources.extend(piece.sources)

                if piece.text is not None:
                    if self.block_config["streamOutput"]:
                        yield {"chunk": {"text": piece.text}}
                    synthesized_text += piece.text

                if piece.artifacts:
                    piece_artifacts = piece.artifacts
                    for artifact in piece_artifacts:
                        # Prepend an AGENT entry to the artifact's hierarchy.
                        artifact.setdefault("hierarchy", []).insert(
                            0, {"type": "AGENT", "agentLoopIteration": 1})
                    yield {"chunk": {"artifacts": piece_artifacts}}

        return synthesized_text

    def _run_critique(self, trace: SpanBuilder, base_sequence_context: SequenceContext, expectations_prompt: str) -> Generator[Dict[str, Any], None, str]:
        """Run the critique completion on the messages held in ``base_sequence_context``.

        The completion receives the base critique prompt followed by
        ``expectations_prompt`` as a system message, then the turn's initial messages
        and the generated messages of ``base_sequence_context``. The critique text is
        never streamed; only artifact chunks are yielded.

        Returns:
            The full critique text (as the generator's return value).
        """
        critique = self.new_completion()
        critique.with_context(self.turn.current_merged_context)
        configured_settings = self.block_config.get("critiqueCompletionSettings", {}) or {}
        # Deep copy so the completion never shares state with the block config.
        critique._settings = copy.deepcopy(configured_settings)

        system_prompt = _BASE_CRITIQUE_PROMPT + f"\n Expectations:{expectations_prompt}"
        critique.with_message(system_prompt, "system")

        messages = critique.cq["messages"]
        messages.extend(self.turn.initial_messages)
        messages.extend(base_sequence_context.generated_messages)

        logger.info("Reflection block running critique completion")
        with trace.subspan("DKU_AGENT_LLM_CALL") as llm_trace:
            critique_text = ""
            for piece in self._run_completion(critique, llm_trace):
                if piece.sources:
                    self.sequence_context.sources.extend(piece.sources)

                if piece.text is not None:
                    critique_text += piece.text

                if piece.artifacts:
                    piece_artifacts = piece.artifacts
                    for artifact in piece_artifacts:
                        # Prepend an AGENT entry to the artifact's hierarchy.
                        artifact.setdefault("hierarchy", []).insert(
                            0, {"type": "AGENT", "agentLoopIteration": 1})
                    yield {"chunk": {"artifacts": piece_artifacts}}
        return critique_text

    def _critique_approves(self, critique_text: str) -> bool:
        if not critique_text:
            return False
        normalized = critique_text.strip().lower()
        if not normalized:
            return False

        approval_tokens = [
            "approved",
        ]

        return any(normalized == token or normalized.startswith(f"{token}") for token in approval_tokens)
