import copy
import hashlib
import json
import logging
import threading

from .. import NextBlock, BlockHandler

logger = logging.getLogger("dku.agents.blocks_graph.context_compression")

_COMPRESSION_CACHE = {
    "initial": {},
    "generated": {},
}
_COMPRESSION_CACHE_LOCK = threading.Lock()

DEFAULT_ACTIVE_BUFFER_SIZE = 5
DEFAULT_COMPRESSION_TRIGGER_CHARS = 1000

# System prompt sent to the summarization LLM; the messages to compress are
# sent afterwards as a separate user message.
DEFAULT_COMPRESSION_PROMPT = """
Role: You are an expert context manager. Your goal is to compress a conversation history while retaining 100% of the functional utility for a specialized AI agent.

Guidelines:

* Prioritize Intent: Clearly state what the user is trying to achieve.

* Extract Entities & Facts: Retain specific names, dates, values, and preferences (e.g., "User prefers Python over Java" or "Project deadline is Oct 12").

* Preserve State: Note what has been resolved and what is still pending.

* Prune Noise: Remove greetings, filler words, and repetitive acknowledgments.

* Format: Keep the summary in a dense, bulleted list or a highly concise paragraph.


You will receive as input the existing messages to summarize. Note that some of these messages may already be summaries.

Output: Provide only the updated summary, which must faithfully represent the entire discussion.
"""


def _get_cached_entry(conversation_id, bucket):
    """Return the cached summary entries of *conversation_id* in *bucket*, or None.

    *bucket* names the message list the entries belong to; this module uses
    "initial" and "generated", matching the sub-dicts declared in
    _COMPRESSION_CACHE. Conversation ids are keyed by their ``str()`` form.
    Callers in this module hold _COMPRESSION_CACHE_LOCK around this call.
    """
    return _COMPRESSION_CACHE.get(bucket, {}).get(str(conversation_id))

def _update_cached_entry(conversation_id, bucket, entry):
    """Store *entry* as the cached summary entries of *conversation_id* in *bucket*.

    Unknown buckets are created on demand. Callers in this module hold
    _COMPRESSION_CACHE_LOCK around this call.
    """
    # NOTE(review): no eviction is visible in this module, so the cache grows
    # with the number of conversations seen by the process.
    _COMPRESSION_CACHE.setdefault(bucket, {})[str(conversation_id)] = entry

def _message_signature(message):
    try:
        payload = json.dumps(message, sort_keys=True, default=str)
    except Exception:
        payload = str(message)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()

def _messages_signatures(messages):
    """Return the signature of every message in *messages*, preserving order."""
    return list(map(_message_signature, messages))


class ContextCompressionBlockHandler(BlockHandler):
    """Block that compresses a turn's conversation history with an LLM summary.

    When a message list holds more than ``activeBufferSize`` messages and its
    approximate size (sum of :meth:`compute_message_size`) exceeds
    ``compressionTriggerChars``, every message except the last
    ``activeBufferSize`` ones is replaced by a single summary message produced
    by the LLM configured as ``llmId``. Generated summaries are cached per
    conversation and per message list, keyed by the signatures of the messages
    they replaced, so later turns can reuse them without calling the LLM again.

    Block configuration keys read by this class:
      * ``activeBufferSize`` (default ``DEFAULT_ACTIVE_BUFFER_SIZE``): number of
        most recent messages always kept verbatim; ``<= 0`` disables new
        compression.
      * ``compressionTriggerChars`` (default ``DEFAULT_COMPRESSION_TRIGGER_CHARS``).
      * ``applyToInitial`` (default True) / ``applyToGenerated`` (default False):
        whether to process ``turn.initial_messages`` / ``turn.generated_messages``.
      * ``llmId`` and ``completionSettings``: LLM and settings used to summarize.
      * ``nextBlock``: id of the block to continue with.
    """

    def __init__(self, turn, sequence_context, block_config):
        super().__init__(turn, sequence_context, block_config)

        self.active_buffer_size = self.block_config.get("activeBufferSize", DEFAULT_ACTIVE_BUFFER_SIZE)
        self.compression_trigger_chars = self.block_config.get("compressionTriggerChars", DEFAULT_COMPRESSION_TRIGGER_CHARS)
        self.apply_to_initial = self.block_config.get("applyToInitial", True)
        self.apply_to_generated = self.block_config.get("applyToGenerated", False)

    def process_stream(self, trace):
        """Compress the configured message lists of the turn, then continue.

        Yields a single ``NextBlock`` targeting the configured ``nextBlock``.
        """
        conversation_id = self.turn.context_get("conversationId")

        if conversation_id is None:
            logger.warning("Conversation ID is not set in the turn context; context compression cannot be cached")

        if self.apply_to_initial:
            has_compressed, compressed_messages = self._compress_bucket(
                trace, conversation_id, "initial", "Initial", self.turn.initial_messages)
            if has_compressed:
                self.turn.initial_messages = compressed_messages

        if self.apply_to_generated:
            has_compressed, compressed_messages = self._compress_bucket(
                trace, conversation_id, "generated", "Generated", self.turn.generated_messages)
            if has_compressed:
                self.turn.generated_messages = compressed_messages

        yield NextBlock(id=self.block_config.get("nextBlock"))

    def _compress_bucket(self, trace, conversation_id, bucket, label, messages):
        """Compress one message list, reading and updating its per-conversation cache.

        ``bucket`` selects the cache namespace; ``label`` is only used in logs.
        Without a conversation id, compression still runs but nothing is cached.
        Returns ``(has_compressed, messages)`` exactly as :meth:`_compress_messages`.
        """
        cache_entries = None
        if conversation_id is not None:
            with _COMPRESSION_CACHE_LOCK:
                cache_entries = _get_cached_entry(conversation_id, bucket)
        if cache_entries is None:
            cache_entries = []

        has_compressed, compressed_messages = self._compress_messages(trace, messages, cache_entries)

        if has_compressed:
            logger.info("%s messages were compressed from %d to %d messages", label, len(messages), len(compressed_messages))
            if conversation_id is not None:
                # _compress_messages may have appended a new summary entry to cache_entries.
                with _COMPRESSION_CACHE_LOCK:
                    _update_cached_entry(conversation_id, bucket, cache_entries)

        return has_compressed, compressed_messages

    @staticmethod
    def _apply_cached_summaries(messages, cache_entries):
        """Repeatedly replace the longest cached prefix of *messages* by its summary.

        Each cache entry maps the signatures of a run of leading messages
        (``prefix_signatures``) to the ``summary_message`` that replaced them.
        Applying one summary can expose another cached prefix (a summary can be
        the first message of a later summarized prefix), hence the loop. Each
        entry is applied at most once, so the loop always terminates.

        Returns *messages* itself when nothing matched, otherwise a new list.
        """
        # Entries with an empty prefix can never match; drop them up front.
        pending = [entry for entry in cache_entries if entry.get("prefix_signatures")]
        while pending:
            signatures = _messages_signatures(messages)
            best_entry = None
            best_len = 0
            for entry in pending:
                prefix = entry["prefix_signatures"]
                # Keep the first entry with the strictly longest matching prefix.
                if len(prefix) <= best_len or len(prefix) > len(signatures):
                    continue
                if signatures[:len(prefix)] == prefix:
                    best_entry = entry
                    best_len = len(prefix)
            if best_entry is None:
                logger.info("Done looking for cached summaries")
                break
            logger.info("Applying cached summary for prefix of length %d", best_len)
            messages = [best_entry["summary_message"]] + messages[best_len:]
            # Never reuse an entry: avoids an endless loop if a summary happens to
            # have the same signature as the first message of the prefix it replaced.
            pending.remove(best_entry)
        return messages

    def _summarize_messages(self, messages, trace):
        """Ask the configured LLM to summarize *messages* and return the summary text.

        Raises:
            ValueError: if no ``llmId`` is configured on the block.
        """
        llm_id = self.block_config.get("llmId")
        if not llm_id:
            # .get("id") so a missing id does not mask this error with a KeyError.
            raise ValueError("Please select a valid LLM on the block %s" % self.block_config.get("id"))

        with trace.subspan("DKU_AGENT_CONTEXT_COMPRESSION_LLM_CALL") as llm_trace:

            llm = self.agent.project.get_llm(llm_id)
            completion = llm.new_completion()
            completion.with_context(self.turn.current_merged_context)
            completion._settings = copy.deepcopy(self.block_config.get("completionSettings", {}))

            # One JSON document per message, separated by blank lines. Non-ASCII text
            # is kept as-is (readable for the LLM, no \uXXXX inflation) and values json
            # cannot encode are stringified, as in _message_signature.
            messages_text = "\n\n".join(
                json.dumps(message, ensure_ascii=False, default=str) for message in messages
            )

            completion.with_message(DEFAULT_COMPRESSION_PROMPT, role="system")
            completion.with_message(messages_text, role="user")

            resp = completion.execute()
            summary_text = resp.text.strip()

            llm_trace.append_trace(resp.trace)

        logger.info("Summarization done. Input chars: %d, Summary chars: %d", len(messages_text), len(summary_text))

        return summary_text

    @staticmethod
    def compute_message_size(message):
        """Return the approximate size of *message*: the length of its JSON encoding."""
        # TODO: Replace by a more accurate size computation
        # default=str mirrors _message_signature so non-JSON values do not raise here.
        return len(json.dumps(message, default=str))

    def _compress_messages(self, trace, messages, cache_entries):
        """Apply cached summaries, then summarize the oldest messages if still too large.

        Args:
            trace: tracing span; any LLM call runs in a child subspan of it.
            messages: message list to compress (not mutated).
            cache_entries: cached summary entries for this conversation and message
                list; a new entry is appended when a summary is generated.

        Returns:
            ``(has_compressed, messages)``. ``has_compressed`` is True whenever the
            returned list differs from the input, i.e. cached summaries were applied
            and/or a new summary was generated. Callers must use the returned list
            in that case, otherwise reused summaries would be silently dropped.

        Raises:
            ValueError: if a new summary is needed but no LLM is configured.
        """
        initial_len = sum(self.compute_message_size(msg) for msg in messages)

        logger.info("Considering whether to compress %d messages (total size approx %d chars) with active buffer size %d and summary max chars %d",
                    len(messages), initial_len, self.active_buffer_size, self.compression_trigger_chars)

        logger.info("Applying cached summaries if any (have %d in cache)", len(cache_entries))
        reduced_messages = self._apply_cached_summaries(messages, cache_entries)
        # _apply_cached_summaries returns its input object unchanged when nothing matched.
        cache_applied = reduced_messages is not messages
        messages = reduced_messages

        logger.info("After applying cached summaries, %d messages remain", len(messages))

        # We are below active buffer, no new compression needed
        if self.active_buffer_size <= 0 or len(messages) <= self.active_buffer_size:
            logger.info("Number of messages is below active buffer size, no compression needed")
            return cache_applied, messages

        updated_len = sum(self.compute_message_size(msg) for msg in messages)
        logger.info("After applying cached summaries, total size approx %d chars", updated_len)

        if updated_len <= self.compression_trigger_chars:
            logger.info("Total size is below summary max chars (%s), no compression needed", self.compression_trigger_chars)
            return cache_applied, messages

        logger.info("Compressing messages, summarizing all except last %d messages", self.active_buffer_size)

        prefix_len = len(messages) - self.active_buffer_size
        prefix_messages = messages[:prefix_len]
        buffer_messages = messages[prefix_len:]

        prefix_signatures = _messages_signatures(prefix_messages)
        summary_text = self._summarize_messages(prefix_messages, trace)

        summary_message = {
            "role": "user",
            "content": "Summary of our previous exchanges:\n %s" % summary_text
        }

        cache_entries.append({
            "prefix_signatures": prefix_signatures,
            "summary_message": summary_message,
        })

        return True, [summary_message] + buffer_messages
