from dataclasses import dataclass, field
from typing import Dict, List, Union, Any
import copy


@dataclass
class NextBlock:
    """
    Wrapper around an optional block id.

    ``id`` may be a string or None. Presumably None means "no next block";
    confirm against the callers that consume this type.
    """

    id: Union[None, str]

@dataclass
class SequenceContext(object):
    """
    The sequence context is attached to a sequence of blocks being executed for a given turn.

    It contains a scratchpad that can be used to store intermediate values between blocks. Unlike the state,
    the scratchpad is not persisted between turns.

    It also contains the messages that are generated during this sequence of blocks. In the top-level
    sequence, these are the messages generated during the turn that will be sent back to the caller
    as a memoryFragment.

    Importantly, new sequence contexts are created for sub-sequences of blocks, especially for those that may
    repeat, such as:
       - ForEachBlock
       - ParallelBlock
       - ReflectionBlock

    Whether these sub-sequences "inherit" the parent sequence context or create a new one is up to these repeating
    block handlers (but in any case, it's a copy, not a reference).
    """

    # Intermediate values shared between the blocks of this sequence (not persisted between turns).
    scratchpad: Dict = field(default_factory=dict)

    # Messages generated while running this sequence of blocks.
    generated_messages: List = field(default_factory=list)

    # Most recent text output; None until something assigns it.
    last_text_output: Union[str, None] = None

    # NOTE(review): presumably sources collected while running the sequence -- confirm with callers.
    sources: List = field(default_factory=list)

    # User-defined variables, injected into evaluation contexts by merge_custom_variables().
    custom_variables: Dict = field(default_factory=dict)

    def copy(self) -> "SequenceContext":
        """
        Return an independent deep copy of this context.

        Every dataclass field of SequenceContext is deep-copied individually, so a
        field added later is copied automatically instead of silently falling back
        to its default value. The result is always a plain ``SequenceContext``.
        """
        from dataclasses import fields

        # One deepcopy call per field (separate memo per field), as before.
        return SequenceContext(
            **{f.name: copy.deepcopy(getattr(self, f.name)) for f in fields(SequenceContext)}
        )

    def set_custom_variable(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in the custom variables, overwriting any previous value."""
        # no validation for now, just accept anything
        self.custom_variables[key] = value

    def merge_custom_variables(self, variables: dict, raise_on_conflict=True) -> None:
        """
        Merge, in place, the current custom variables into ``variables``.

        Useful for building the context of CEL expressions, Python blocks and Generate artifact blocks.

        Args:
            variables: the dict to extend; it is mutated in place.
            raise_on_conflict: when True (the default), refuse to overwrite a key that already
                exists in ``variables`` (e.g. built-ins such as context, state, scratchpad).
                When False, custom variables override existing keys.

        Raises:
            ValueError: if ``raise_on_conflict`` is True and a custom variable name is already a
                key of ``variables``. Nothing is merged in that case.
        """
        if not self.custom_variables:
            return
        # prevent accidental override of built-in variables like turn, context, state, scratchpad
        if raise_on_conflict:
            for key in self.custom_variables:
                if key in variables:
                    raise ValueError(f"key '{key}' is already defined in the context")
        variables.update(self.custom_variables)


class BlockHandler(object):
    """
    Binds a block's configuration to the turn and the sequence context it runs in,
    and builds the variables exposed to user code in each templating/execution context.
    """

    # TODO: currently we can't type turn because of circular imports
    turn: Any
    sequence_context: SequenceContext
    agent: Any
    block_config: Dict

    def __init__(self, turn, sequence_context: SequenceContext, block_config: Dict):
        self.turn = turn
        self.sequence_context = sequence_context
        self.agent = turn.agent
        self.block_config = block_config

    # The next 3 methods represent what is available to the user in different templating/execution
    # contexts. Variables common to all of them live in _shared_variables() so they cannot drift apart;
    # each method only adds its context-specific extras.

    def _shared_variables(self) -> Dict[str, Any]:
        """
        Build the variables common to every templating/execution context.

        A fresh dict is returned on each call, so callers may extend it freely.
        ``turn.state()`` is invoked and ``all_messages`` is concatenated anew on every call.
        """
        seq = self.sequence_context
        return {
            "context": self.turn.current_merged_context,
            "state": self.turn.state(),
            "scratchpad": seq.scratchpad,
            "last_output": seq.last_text_output,

            "initial_messages": self.turn.initial_messages,
            "generated_messages": seq.generated_messages,
            "all_messages": self.turn.initial_messages + seq.generated_messages,
        }

    def standard_cel_engine(self):
        """
        Build a MicroCelEngine over the shared variables, the hidden ``_sequence_context``
        entry and the sequence's custom variables.

        Raises:
            ValueError: if a custom variable name collides with an existing variable.
        """
        from .microcel import MicroCelEngine

        variables = self._shared_variables()
        # Hidden, just in case
        variables["_sequence_context"] = self.sequence_context
        self.sequence_context.merge_custom_variables(variables)
        return MicroCelEngine(variables)

    def jinja_template_context(self):
        """
        Return the dict used to render Jinja templates: the shared variables, the hidden
        ``_sequence_context`` entry and the sequence's custom variables.

        Raises:
            ValueError: if a custom variable name collides with an existing variable.
        """
        template_context = self._shared_variables()
        # Hidden, just in case
        template_context["_sequence_context"] = self.sequence_context
        self.sequence_context.merge_custom_variables(template_context)
        return template_context

    def python_context(self):
        """
        Return the local variables for Python execution: the shared variables plus direct
        access to ``turn``, ``agent`` and ``sequence_context``, then the custom variables.

        Raises:
            ValueError: if a custom variable name collides with an existing variable.
        """
        local_vars = self._shared_variables()
        local_vars["turn"] = self.turn
        local_vars["agent"] = self.turn.agent
        local_vars["sequence_context"] = self.sequence_context
        self.sequence_context.merge_custom_variables(local_vars)
        return local_vars
