import asyncio
import base64
import datetime
import json
import re
import threading
import time
import uuid
from contextlib import closing
from typing import Dict, List

import dataiku
from dataikuapi.dss.llm import DSSLLMStreamedCompletionFooter
from flask import g, request

from backend.agents.user_agent import UserAgent
from backend.config import (
    agents_as_tools,
    get_conversation_vision_llm,
    get_enable_upload_documents,
    get_enterprise_agents,
    get_extraction_mode,
    get_global_system_prompt,
    get_guardrails_enabled,
    get_guardrails_pattern,
    get_quota_images_per_conversation,
    get_uploads_managedfolder_id,
)
from backend.constants import ExtractionMode
from backend.models.events import EventKind
from backend.schemas import schemas
from backend.services.derived_documents_service import (
    count_conversation_images,
    get_structured_documents,
    predict_new_image_count,
    process_derived_documents,
)
from backend.services.guardrails_service import process_documents_for_guardrails, emit_guardrails_filter_events
from backend.services.orchestrator_service import OrchestratorService
from backend.utils.conv_utils import normalise_stream_event
from backend.utils.events_utils import (
    get_chart_plans,
    get_references,
    get_selected_agents,
    get_used_agent_ids,
    get_used_tables,
)
from backend.utils.llm_utils import (
    get_memory_fragment_msg,
    get_tool_validation_requests_msg,
    get_tool_validation_responses_msg,
)
from backend.utils.logging_utils import extract_error_message, get_logger
from backend.utils.user_utils import get_agent_context
from backend.utils.utils import (
    build_agent_connect,
    call_dss_agent_full_conversation,
    get_user_base_llm,
    select_agents,
)
from dataikuapi.dss.llm import DSSLLMStreamedCompletionFooter

logger = get_logger(__name__)


class ConversationError(Exception):
    """Signals an invalid conversation payload or a violated business rule."""


def build_actions(events: list[dict]) -> dict:
    """Derive the "actions" payload stored alongside an assistant message.

    Args:
        events: Event log collected while producing the reply. May be empty.

    Returns:
        A dict with zero or more of the following keys:

        * ``"chart_plans"``: the result of ``get_chart_plans(events)`` when it
          is truthy.
        * ``"stories"``: ``{"enable_stories": False}`` unless the events hold
          exactly one reference and exactly one selected agent. In that case,
          and only if the matching enterprise agent declares a
          ``stories_workspace``, the full stories descriptor is returned.

        An empty dict is returned when ``events`` is empty.
    """
    actions: dict = {}
    if not events:
        return actions

    chart_plans = get_chart_plans(events)
    if chart_plans:
        actions["chart_plans"] = chart_plans

    references = get_references(events)
    sel_agents = get_selected_agents(events)
    logger.info("Building message actions, references: %s", len(references))
    logger.info("Building message actions, selected agents: %s", len(sel_agents))

    if len(references) != 1 or len(sel_agents) != 1:
        # Only enable stories for one agent and one tool with references
        actions["stories"] = {"enable_stories": False}
        return actions

    tables = get_used_tables(references[0])
    logger.info("Building message actions, tables: %s", tables)
    # ``g`` is the module-level flask import; no function-scope import needed.
    enterprise_agents = get_enterprise_agents(g.get("authIdentifier", None))
    agent_id = sel_agents[0]["agentId"]
    agent = next((ea for ea in enterprise_agents if ea.get("id") == agent_id), None)
    logger.info("Building message actions, agent: %s", agent)

    # NOTE(review): when the agent is unknown or has no stories workspace, no
    # "stories" key is emitted at all (rather than enable_stories=False).
    # Kept as-is; confirm this is what the frontend expects.
    if agent and agent.get("stories_workspace"):
        actions["stories"] = {
            "enable_stories": True,
            "tables": tables,
            "agent_id": agent_id,
            "stories_workspace": agent.get("stories_workspace"),
            "agent_short_example_queries": agent.get("agent_short_example_queries", []),
            "agent_example_queries": agent.get("agent_example_queries", []),
        }
    return actions


class ConversationService:
    def __init__(self, store, draft_mode: bool = False):
        """Bind the service to a user store and cache that store's agents.

        Args:
            store: Persistence backend; annotated as ``IUserStore``.
            draft_mode: Stored on the instance as ``self.draft_mode``.
        """
        # Imported here rather than at module level; only needed for the annotation below.
        from backend.database.user_store_protocol import IUserStore

        self.store: IUserStore = store
        self.draft_mode = draft_mode  # True only in test contexts
        self.project = dataiku.api_client().get_default_project()

        # Keep the agents both as returned by the store and indexed by id.
        user_agents = self.store.get_all_agents()
        self._user_agents_list = user_agents
        self._user_agents_dict = {agent.id: agent for agent in user_agents}

    def _missing_agents(self, agent_ids: List[str], user: str) -> List[str]:
        """Return the ids from ``agent_ids`` that match no known agent.

        An id is known when it belongs either to an enterprise agent returned
        by ``get_enterprise_agents(user)`` or to one of the user agents cached
        on this instance. Input order is preserved.
        """
        known_ids = {ea["id"] for ea in get_enterprise_agents(user)} | set(self._user_agents_dict)
        return [agent_id for agent_id in agent_ids if agent_id not in known_ids]

    @staticmethod
    def _blank_conversation(conv_id: str, agent_ids: List[str]) -> schemas.FullConversationRead:
        """Build an in-memory conversation with no title and no messages.

        The owner is the stringified ``authIdentifier`` read from flask's ``g``.
        """
        owner = str(g.get("authIdentifier"))
        fields = {
            "conversation_id": conv_id,
            "user_id": owner,
            "title": None,
            "agent_ids": agent_ids,
            "llm_id": "",
            "agents_enabled": 1,
            "status": schemas.StatusEnum.ACTIVE,
            "messages": [],
        }
        return schemas.FullConversationRead(**fields)

    def _generate_title(self, messages: list[schemas.MessageBase], trace=None) -> str:
        """Ask the user's base LLM for a short title summarising the conversation.

        Args:
            messages: Conversation messages. Each one's ``content`` and ``role``
                are appended to the completion after the titling instructions.
            trace: Optional parent trace. When truthy, a ``generate_title``
                sub-span records the inputs and outputs of the call.

        Returns:
            The LLM answer (empty string when the LLM returned no text) passed
            through ``self._clean_title``.

        Raises:
            Any exception raised while resolving or executing the LLM. The
            sub-span, if any, is still ended before the exception propagates.
        """
        title_span = None
        if trace:
            title_span = trace.subspan("generate_title")
            title_span.begin(int(time.time() * 1000))

        try:
            today = datetime.date.today().strftime("%Y-%m-%d")
            llm_id = get_user_base_llm(self.store)

            # Add inputs to trace
            if title_span:
                title_span.inputs["llm_id"] = llm_id
                title_span.inputs["message_count"] = len(messages)
                title_span.inputs["date"] = today

            system_msg = (
                "You are a titling assistant. "
                f"Today is {today}. "
                "Return a single short title (≤7 words) that summarizes the conversation. "
                "Plain text ONLY: no prefixes (e.g., 'Title:'), no quotes, no emojis, "
                "no code fences, no extra lines, and no trailing punctuation."
            )
            user_msg = "Generate a short title for this conversation. Output only the title text."

            comp = self.project.get_llm(llm_id).new_completion()
            comp.with_message(system_msg, role="system")
            comp.with_message(user_msg, role="user")

            # The conversation itself follows the titling instructions.
            for m in messages:
                comp = comp.with_message(m.content, role=m.role)

            logger.info(
                "Generating conversation title:\nLLM id=[%s]\nCompletion Query=%s\nCompletion Settings=%s\n",
                llm_id,
                json.dumps(comp.cq, indent=2, sort_keys=True),
                json.dumps(comp.settings, indent=2, sort_keys=True),
            )

            resp = comp.execute()
            raw_title = resp.text or ""
            logger.info(
                "Conversation title LLM response:\nLLM id=[%s]\nCompletion Response=%s",
                llm_id,
                raw_title,
            )

            if title_span and resp.trace:
                title_span.append_trace(resp.trace)

            cleaned_title = self._clean_title(raw_title)

            # Add outputs to trace
            if title_span:
                title_span.outputs["raw_title"] = raw_title
                title_span.outputs["cleaned_title"] = cleaned_title
            return cleaned_title
        finally:
            # Always close the sub-span, including when the LLM call fails;
            # otherwise the span would be left open forever.
            if title_span:
                title_span.end(int(time.time() * 1000))

    # TODO this is redundant with add_history_to_completion in llm_utils.py
    def _build_history(
        self,
        agent_ids: list[str],
        conv_messages: list[dict],
        *,
        documents_map: dict[str, dict] | None = None,
    ) -> list[dict]:
        """Turn stored conversation messages into the message list sent to the LLM.

        Args:
            agent_ids: Agents attached to the conversation. When there is exactly
                one and it is an enterprise agent, its own system instructions
                are used instead of the global system prompt.
            conv_messages: Conversation messages, either plain dicts or objects
                exposing ``model_dump(mode="json")``.
            documents_map: Optional mapping ``document_path -> document info``
                used to attach document payloads to messages with attachments.

        Returns:
            A list of ``{"role", "content"}`` dicts (plus ``"__documents__"``
            when documents were resolved), preceded by a system message except
            when the single enterprise agent defines no system instructions.
            Messages carrying tool-validation requests/responses are skipped;
            a message with a memory fragment is preceded by the fragment message.
        """
        docs_map = documents_map or {}
        history: list[dict] = []
        for raw_msg in conv_messages:
            msg = raw_msg if isinstance(raw_msg, dict) else raw_msg.model_dump(mode="json")

            # skip tool validation messages
            if msg.get("tool_validation_requests") or msg.get("tool_validation_responses"):
                continue

            # A memory fragment is sent back as part of the history, before the
            # message it belongs to.
            if msg.get("memory_fragment"):
                history.append(get_memory_fragment_msg(msg.get("memory_fragment", {})))

            entry = {
                "role": msg.get("role"),
                "content": msg.get("content", ""),
            }
            doc_entries = self._history_document_entries(msg.get("attachments") or [], docs_map)
            if doc_entries:
                entry["__documents__"] = doc_entries
            history.append(entry)

        today = datetime.date.today().strftime("%Y-%m-%d")

        if len(agent_ids) == 1:
            agent_id = agent_ids[0]
            # Try to get enterprise agent first
            # TODO this is inefficient; we should already have the info about the agent type somewhere
            enterprise_agents = get_enterprise_agents(g.get("authIdentifier", None))
            enterprise_agent = next((ea for ea in enterprise_agents if ea.get("id") == agent_id), None)
            if enterprise_agent:
                # Enterprise agent: use agent_system_instructions
                agent_system_instructions = enterprise_agent.get("agent_system_instructions", "")
                if not agent_system_instructions:
                    return history  # no system instructions
                sys_msg = {"role": "system", "content": f"Today is {today}. {agent_system_instructions}"}
                return [sys_msg] + list(history)

        # Fall back to the global system prompt. The date sentence and the
        # prompt are separated by a space (as in the enterprise branch above),
        # and an empty/None prompt yields the date sentence alone.
        global_prompt = get_global_system_prompt()
        content = f"Today is {today}. {global_prompt}" if global_prompt else f"Today is {today}."
        sys_msg = {"role": "system", "content": content}
        return [sys_msg] + list(history)

    @staticmethod
    def _history_document_entries(attachments: list[dict], docs_map: dict[str, dict]) -> list[dict]:
        """Resolve message attachments to document payloads known in ``docs_map``.

        Attachments without a ``document_path``, or whose path has no truthy
        entry in ``docs_map``, are ignored.
        """
        doc_entries: list[dict] = []
        for att in attachments:
            path = att.get("document_path")
            if not path:
                continue
            doc_info = docs_map.get(path)
            if not doc_info:
                continue
            doc_entries.append(
                {
                    # Fall back to the attachment name, then to the file name in the path.
                    "name": doc_info.get("name") or att.get("name") or path.rsplit("/", 1)[-1],
                    "snapshots": doc_info.get("snapshots") or [],
                    "text": doc_info.get("text") or "",
                }
            )
        return doc_entries

    def handle_tool_validation_response(self, payload: dict, emit, cancel_event: threading.Event):
        """Record a tool-validation answer and, once all are in, resume the reply.

        Reads ``conversationId``, ``selectedLLM``, ``agentIds``, ``draftMode``
        and ``toolValidationResponse`` from ``payload``.

        Flow:
            1. Reject the call (``chat_error``) when the conversation id or the
               agent ids are missing.
            2. In draft mode, delegate to ``_stream_draft`` and stop.
            3. Otherwise append the response to the ones already stored on the
               last message. While fewer responses than requests are stored,
               only persist them and return.
            4. Once every request is answered, rebuild the history, stream the
               assistant reply, persist it with its trace and emit ``chat_end``
               (unless cancelled).

        Args:
            payload: Client payload (see keys above).
            emit: Callable ``emit(event_name, data)`` used to push events to the client.
            cancel_event: When set, ``chat_end`` is not emitted.
        """
        logger.info("handle_tool_validation_response with Payload: %r", payload)
        conv_id = payload.get("conversationId")
        selected_llm = payload.get("selectedLLM", "")
        agent_ids = payload.get("agentIds")
        draft_mode = bool(payload.get("draftMode"))
        # Both the conversation and its agents are required to resume a reply.
        if not conv_id or not agent_ids:
            emit("chat_error", {"error": "Missing conversationId or agent_ids"})
            return
        from dataiku.llm.tracing import new_trace

        trace = new_trace("DKU_AGENT_HUB_QUERY")
        trace.begin(int(time.time() * 1000))
        events: list[dict] = []
        try:
            # -- draft-mode path (legacy full-rewrite) ----------------------
            if draft_mode:
                self._stream_draft(conv_id, agent_ids, "", emit, cancel_event, trace, payload)
                # Draft chats are fully handled by _stream_draft; do not fall
                # through to the persisted-conversation path below.
                return

            conv = self.store.get_conversation(conv_id)
            if not conv:
                emit("chat_error", {"error": f"Conversation {conv_id} not found"})
                return

            last_msg = conv.messages[-1] if conv.messages else None
            if not last_msg or not last_msg.tool_validation_requests:
                emit("chat_error", {"error": "No tool validation requests found in last message"})
                return

            # Get current responses and requests
            new_responses = [payload.get("toolValidationResponse")]

            current_responses = last_msg.tool_validation_responses or []
            total_requests = last_msg.tool_validation_requests or []
            # Accumulate responses
            all_tool_responses = current_responses + new_responses

            # Persisted in both cases: still waiting, or all responses received.
            self.store.update_message(last_msg.id, {"tool_validation_responses": all_tool_responses})

            # If not all responses received, wait for more
            if len(all_tool_responses) < len(total_requests):
                logger.info(
                    "Not all tool validation responses received yet (%d/%d), updating message and waiting for more",
                    len(all_tool_responses),
                    len(total_requests),
                )
                return

            user_login = g.get("authIdentifier") or ""
            documents_map: dict[str, dict] = {}
            try:
                documents_map = get_structured_documents(self.store, conv_id, user_login)
            except Exception:
                # Best effort: continue without documents.
                logger.exception("Failed to build structured prompt for conversation %s", conv_id)

            # prompt
            history = self._build_history(
                agent_ids,
                conv.messages,
                documents_map=documents_map,
            )
            # Important to keep this order
            history.append(get_memory_fragment_msg(last_msg.memory_fragment or {}))
            history.append(get_tool_validation_requests_msg(last_msg.tool_validation_requests or {}))
            history.append(get_tool_validation_responses_msg(all_tool_responses))

            sel_agents, _ = select_agents(self.store, agent_ids, conv.messages, trace=trace)
            sel_ids = [a["agentId"] for a in sel_agents]
            # assistant streaming
            final_reply, artifacts_meta, tool_validation_requests, memory_fragment, context_upsert = (
                self._stream_reply(conv_id, sel_agents, history, emit, events, cancel_event, trace)
            )

            actions = build_actions(events)
            used_ids = get_used_agent_ids(events)
            assistant_msg = schemas.MessageCreate(
                id=str(uuid.uuid4()),
                conversation_id=conv_id,
                role="assistant",
                content=final_reply,
                event_log=events,
                actions=actions,
                artifacts=artifacts_meta,
                tool_validation_requests=tool_validation_requests,
                memory_fragment=memory_fragment,
                selected_agent_ids=sel_ids,
                used_agent_ids=used_ids,
                llm_id=selected_llm,
                agents_enabled=True,
                context_upsert=context_upsert,
            )
            # atomic write
            self.store.append_messages(conv_id, [assistant_msg])
        finally:
            # End tracing on every exit path (early return, error or success).
            trace.end(int(time.time() * 1000))

        # Only reached when the assistant message was produced and persisted:
        # attach the finished trace to it.
        assistant_msg.trace = trace.to_dict() if trace else {}
        self.store.update_message(assistant_msg.id, {"trace": assistant_msg.trace})

        # Emit chat_end so the frontend stops waiting (conv_id is set and this
        # is never the draft path at this point).
        if not cancel_event.is_set():
            fresh_conv = self.store.get_conversation(conv_id) or conv
            title = fresh_conv.title if fresh_conv else "New Conversation"
            messages = fresh_conv.messages if fresh_conv else []

            emit(
                "chat_end",
                {
                    "agentIds": agent_ids or [],
                    "conversationId": conv_id,
                    "title": title,
                    "messages": [msg.model_dump(mode="json", by_alias=True) for msg in messages],
                    "modeAgents": True,
                    "selectedLLM": selected_llm or "",
                    "hasEventLog": bool(events),
                },
            )

    # ------------------------------------------------------------------ #
    #  Streaming send  (WebSocket)
    # ------------------------------------------------------------------ #
    def stream_message(self, payload: dict, emit, cancel_event: threading.Event):
        """Handle one user turn: persist it, stream the assistant reply, emit events.

        Reads ``conversationId``, ``userMessage``, ``agentIds``, ``modeAgents``,
        ``selectedLLM``, ``attachments``, ``draftMode`` and
        ``toolValidationResponse`` from ``payload``.

        Flow:
            1. Payloads carrying ``toolValidationResponse`` are delegated to
               ``handle_tool_validation_response``.
            2. Validate the payload and the requested agents (``chat_error`` on
               failure).
            3. Draft mode is delegated to ``_stream_draft``.
            4. Otherwise: record attachments, run document extraction (falling
               back to text-only when the image quota would be exceeded), run
               guardrails, select agents, build the history, stream the reply
               and persist the user and assistant messages together.
            5. In all cases the trace is ended; when a conversation id is known,
               the stream was not cancelled and this is not a draft, ``chat_end``
               is emitted and a title is generated afterwards if still missing.

        Args:
            payload: Client payload (see keys above).
            emit: Callable ``emit(event_name, data)`` used to push events to the client.
            cancel_event: When set, ``chat_end`` and title generation are skipped.
        """
        logger.info("stream_message called from %s with Payload: %r", request.remote_addr, payload)
        # Tool-validation answers have their own flow (and open their own
        # trace): delegate before opening a trace here that would never be ended.
        if payload.get("toolValidationResponse"):
            return self.handle_tool_validation_response(payload, emit, cancel_event)

        from dataiku.llm.tracing import new_trace

        trace = new_trace("DKU_AGENT_HUB_QUERY")
        trace.begin(int(time.time() * 1000))

        assistant_msg_dict = None
        conv = {}
        events: list[dict] = []
        agent_ids = None
        draft_mode = False
        # Everything read by the ``finally`` block is bound up front so that an
        # early return or an exception cannot trigger an UnboundLocalError
        # there (which would also mask the original error).
        conv_id = None
        agents_enabled = True
        selected_llm = ""
        needs_title_generation = False
        try:
            draft_mode = bool(payload.get("draftMode"))
            conv_id = payload.get("conversationId")
            user_msg = (payload.get("userMessage") or "").strip()
            agents_enabled = payload.get("modeAgents", True)
            selected_llm = payload.get("selectedLLM", "")
            payload_attachments = payload.get("attachments") or []
            agent_ids = self._filter_active_agents(payload.get("agentIds") or [], allow_draft=draft_mode)
            logger.info("_filter_active_agents called, %s", agent_ids)
            # -- validation -------------------------------------------------
            # Allow empty user message if attachments are provided
            if not conv_id or (not user_msg and not payload_attachments):
                emit("chat_error", {"error": "Missing conversationId or userMessage"})
                return
            miss = self._missing_agents(agent_ids, g.get("authIdentifier"))
            if miss:
                emit("chat_error", {"error": f"Agent(s) not found: {', '.join(miss)}"})
                return

            # -- draft-mode path (legacy full-rewrite) ----------------------
            if draft_mode:
                self._stream_draft(conv_id, agent_ids, user_msg, emit, cancel_event, trace)
                return

            # -------- non-draft incremental path ---------------------------
            conv = self.store.ensure_conversation_exists(conv_id, agent_ids, agents_enabled, selected_llm)
            if not conv:
                self.store.ensure_conversation_exists(conv_id, agent_ids, agents_enabled, selected_llm)
                conv = self._blank_conversation(conv_id, agent_ids)
            # Snapshot taken only once ``conv`` is known to be usable; compared
            # against the requested agents when updating the metadata below.
            old_agent_ids = list(conv.agent_ids or [])

            # user msg
            user_msg_dict = schemas.MessageCreate(
                id=str(uuid.uuid4()),
                conversation_id=conv_id,
                role="user",
                content=user_msg,
                event_log=[],
                agents_enabled=agents_enabled,
                llm_id=selected_llm,
            )
            # attachments provided by the client (fresh upload metadata),
            if payload_attachments:
                user_msg_dict.attachments = payload_attachments
                # Store attachments in message_attachments table
                self.store.insert_or_update_message_attachments(user_msg_dict.id, json.dumps(payload_attachments))
                # Mark these attachments as 'processed' in derived_documents (if relevant)
                docs = self.store.get_derived_documents(conv_id)
                payload_doc_names = {a["document_name"] for a in payload_attachments if "document_name" in a}
                for doc in docs:
                    if (
                        doc.document_name in payload_doc_names
                        and doc.document_metadata
                        and doc.document_metadata.get("status") == "uploaded"
                    ):
                        doc.document_metadata["status"] = "processed"
                        self.store.upsert_derived_document(
                            conv_id, doc.document_name, doc.document_path, doc.document_metadata
                        )

            conv.messages.append(user_msg_dict)

            # -------- Document Extraction and Quota Tracking --------
            # Shared events buffer (persisted in assistant message eventLog)
            # Get configuration parameters
            enable_document_upload = get_enable_upload_documents()
            image_quota = get_quota_images_per_conversation()
            extraction_mode = get_extraction_mode()

            # Only process documents if document upload is enabled
            if payload_attachments and enable_document_upload:
                # Count existing images in the conversation
                current_image_count = count_conversation_images(self.store, conv_id)

                # Initialize quota_exceeded flag (will be set to True if quota is exceeded)
                quota_exceeded = False

                # Predict how many new images would be generated from attachments (only if screenshots mode)
                predicted_new_images = 0
                if extraction_mode == ExtractionMode.PAGES_SCREENSHOTS.value:
                    folder_id = get_uploads_managedfolder_id()
                    predicted_new_images = predict_new_image_count(payload_attachments, folder_id)
                    logger.info(
                        f"Predicted {predicted_new_images} new images from {len(payload_attachments)} attachments"
                    )

                    # Calculate total predicted image count
                    predicted_total = current_image_count + predicted_new_images

                    # Check if quota would be exceeded with new documents
                    if predicted_total > image_quota:
                        extraction_mode = ExtractionMode.PAGES_TEXT.value
                        quota_exceeded = True
                        logger.warning(
                            f"Image quota WOULD BE EXCEEDED for conversation {conv_id}. "
                            f"Predicted total: {predicted_total} > Quota: {image_quota}. "
                            f"Switching to TEXT-ONLY mode for all documents."
                        )
                        # Emit event to notify frontend that extraction mode changed to text-only
                        evt_mode_change = {
                            "eventKind": EventKind.EXTRACTION_MODE_CHANGED,
                            "eventData": {
                                "extractionMode": ExtractionMode.PAGES_TEXT.value,
                                "reason": "quota_exceeded",
                            },
                        }
                        emit("chat_event", {**evt_mode_change, "conversationId": conv_id})
                        events.append(evt_mode_change)
                    else:
                        logger.info(
                            f"Document extraction mode: {extraction_mode} for conversation {conv_id}. "
                            f"Current images: {current_image_count}/{image_quota}, "
                            f"Predicted new: {predicted_new_images}, "
                            f"Predicted total: {predicted_total}/{image_quota}"
                        )

                # Generate derived context (screenshots or text only) based on extraction mode
                # Emit document analysis event (live) and persist it in events buffer
                evt_da = {
                    "eventKind": EventKind.DOCUMENT_ANALYSIS,
                    "eventData": {"count": len(payload_attachments)},
                }
                emit("chat_event", {**evt_da, "conversationId": conv_id})
                events.append(evt_da)
                try:
                    process_derived_documents(
                        self.store,
                        conv_id,
                        g.get("authIdentifier") or "",
                        payload_attachments,
                        extraction_mode=extraction_mode,
                    )
                    # Store extraction_mode as JSON with quota_exceeded flag
                    self.store.insert_or_update_message_attachments(
                        user_msg_dict.id,
                        json.dumps(payload_attachments),
                        extraction_mode=extraction_mode,
                        quota_exceeded=quota_exceeded if extraction_mode == ExtractionMode.PAGES_TEXT.value else False,
                    )
                    # Mark all attachments as ready after successful processing
                    for att in payload_attachments:
                        att["uploadStatus"] = "ready"

                    # Log the actual new image count
                    new_image_count = count_conversation_images(self.store, conv_id)
                    images_added = new_image_count - current_image_count

                    if extraction_mode == ExtractionMode.PAGES_SCREENSHOTS.value:
                        logger.info(
                            f"Added {images_added} images to conversation {conv_id}. "
                            f"Total: {new_image_count}/{image_quota} "
                            f"(predicted: {predicted_new_images}, actual: {images_added})"
                        )
                    else:
                        logger.info(
                            f"Processed {len(payload_attachments)} documents in TEXT-ONLY mode. "
                            f"Total images unchanged: {new_image_count}/{image_quota}"
                        )

                    evt_da = {
                        "eventKind": EventKind.DOCUMENT_ANALYSIS_COMPLETED,
                        "eventData": {"conv_id": conv_id, "documents": payload_attachments},
                    }
                    emit("chat_event", {**evt_da, "conversationId": conv_id})
                    events.append(evt_da)
                except Exception:
                    # Best effort: a failed extraction must not abort the turn.
                    logger.exception("Failed to process derived documents for conversation %s", conv_id)

            self._process_conversation_guardrails(conv_id, emit, events)

            # Select based on settings
            sel_agents, justification = select_agents(self.store, agent_ids, conv.messages, trace=trace)
            sel_ids = [a["agentId"] for a in sel_agents]
            if justification and not agents_as_tools():
                evt = {
                    "eventKind": EventKind.AGENT_SELECTION,
                    "eventData": {"justification": justification, "selection": sel_agents},
                }
                emit("chat_event", evt)
                events.append(evt)
            elif len(sel_agents):
                evt = {
                    "eventKind": EventKind.AGENT_SELECTION,
                    "eventData": {"justification": "User selected", "selection": sel_agents},
                }
                # Not emitted live; only recorded in the persisted event log.
                events.append(evt)
            user_login = g.get("authIdentifier") or ""
            documents_map: dict[str, dict] = {}
            try:
                documents_map = get_structured_documents(self.store, conv_id, user_login)
            except Exception:
                # Best effort: continue without documents.
                logger.exception("Failed to build structured prompt for conversation %s", conv_id)

            # prompt
            history = self._build_history(
                agent_ids,
                conv.messages,
                documents_map=documents_map,
            )

            # Take the merged context of the most recent assistant message
            # (only that one is inspected), if it is a dict.
            last_merged_context = {}
            for msg in reversed(conv.messages):
                msg_data = msg if isinstance(msg, dict) else msg.model_dump(mode="json")
                if msg_data.get("role") != "assistant":
                    continue
                candidate = msg_data.get("merged_context")
                if isinstance(candidate, dict):
                    last_merged_context = candidate
                break
            # assistant streaming
            final_reply, artifacts_meta, tool_validation_requests, memory_fragment, context_upsert = self._stream_reply(
                conv_id, sel_agents, history, emit, events, cancel_event, trace, last_merged_context
            )
            # Compute merged context for this turn
            base_context = {}
            try:
                base_context = get_agent_context(g.get("authIdentifier"), conv_id)
            except Exception:
                logger.exception("Failed to build base agent context for conversation %s", conv_id)
            # Later sources win: previous context < base context < this turn's upserts.
            merged_context = {
                **(last_merged_context or {}),
                **(base_context or {}),
                **(context_upsert or {}),
            }
            # TODO maybe could optimize and return these directly instead of recomputing them
            actions = build_actions(events)
            used_ids = get_used_agent_ids(events)
            assistant_msg_dict = schemas.MessageCreate(
                id=str(uuid.uuid4()),
                conversation_id=conv_id,
                role="assistant",
                content=final_reply,
                event_log=events,
                actions=actions,
                artifacts=artifacts_meta,
                tool_validation_requests=tool_validation_requests,
                memory_fragment=memory_fragment,
                selected_agent_ids=sel_ids,
                used_agent_ids=used_ids,
                llm_id=selected_llm,
                agents_enabled=agents_enabled,
                merged_context=merged_context or None,
            )

            # Add inputs and outputs to trace
            # Inputs: only the last user message (current query)
            trace.inputs["messages"] = [{"role": "user", "text": user_msg}]
            # Outputs: the generated assistant response
            trace.outputs["text"] = final_reply

            conv.messages.append(assistant_msg_dict)
            # atomic write
            self.store.append_messages(conv_id, [user_msg_dict, assistant_msg_dict])

            # Check if title needs to be generated (but don't generate it yet - do it async after chat_end)
            needs_title_generation = not conv.title or (conv.title.strip() == "Untitled")

            # Update meta for agent/llm changes immediately (title will be updated async if needed)
            if (
                old_agent_ids != agent_ids
                or conv.agents_enabled != agents_enabled
                or conv.llm_id != selected_llm
            ):
                self.store.update_conversation_meta(
                    conv_id,
                    title=conv.title,
                    agent_ids=agent_ids,
                    agents_enabled=agents_enabled,
                    llm_id=selected_llm,
                )
        finally:
            # End tracing
            trace.end(int(time.time() * 1000))

            # If we have an assistant message, update it with the trace info
            if assistant_msg_dict is not None:
                # Add the trace data to the assistant message dictionary.
                assistant_msg_dict.trace = trace.to_dict() if trace else {}
                # Update the message in the store using the update_message function.
                self.store.update_message(assistant_msg_dict.id, {"trace": assistant_msg_dict.trace})

            # Emit chat_end FIRST to unblock the frontend immediately
            # Title generation will happen async afterwards
            if conv_id and not cancel_event.is_set() and not draft_mode:
                # Reload conversation from database to get latest messages with extraction_mode (including quota_exceeded)
                # This ensures extraction_mode JSON is parsed and included in messages sent to frontend
                fresh_conv = self.store.get_conversation(conv_id) or conv
                # Get safe values for chat_end payload
                title = fresh_conv.title if fresh_conv else "New Conversation"
                messages = fresh_conv.messages if fresh_conv else []

                emit(
                    "chat_end",
                    {
                        "agentIds": agent_ids or [],
                        "conversationId": conv_id,
                        "title": title,
                        "messages": [msg.model_dump(mode="json", by_alias=True) for msg in messages],
                        "modeAgents": agents_enabled if agents_enabled is not None else True,
                        "selectedLLM": selected_llm or "",
                        "hasEventLog": bool(events),
                    },
                )

                # Generate title AFTER chat_end is sent (async from frontend's perspective)
                # This unblocks the user from sending more messages while title is being generated
                # NOTE(review): the trace was already ended and stored above, so the
                # title sub-span is presumably not persisted - confirm.
                if needs_title_generation:
                    try:
                        generated_title = self._generate_title(conv.messages, trace=trace)
                        conv.title = generated_title
                        # Update title in database
                        self.store.update_conversation_meta(
                            conv_id,
                            title=generated_title,
                            agent_ids=agent_ids,
                            agents_enabled=agents_enabled,
                            llm_id=selected_llm,
                        )
                        # Emit title update event so frontend can update the UI
                        emit(
                            "chat_title_updated",
                            {
                                "conversationId": conv_id,
                                "title": generated_title,
                            },
                        )
                    except Exception as e:
                        # If title generation fails (e.g., broken LLM connection), keep existing title
                        logger.exception("Failed to generate title for conversation %s: %s", conv_id, e)

    # ------------------------------------------------------------------ #
    #  Draft-mode helper (legacy logic, no optimisation)
    # ------------------------------------------------------------------ #
    def _stream_draft(
        self, conv_id: str, agent_ids: list[str], user_msg: str, emit, cancel_event: threading.Event, trace,
        tool_validation_payload: dict | None = None,
    ):
        """
        Run one chat turn for a draft conversation.

        Keeps the original behaviour for draft chats: the full conversation
        object is rewritten (upserted) on every turn. The draft is stored
        under the first entry of ``agent_ids``.

        Args:
            conv_id: Conversation identifier echoed back in emitted events.
            agent_ids: Selected agent ids; ``agent_ids[0]`` keys the draft.
            user_msg: Text of the new user message. Ignored when
                ``tool_validation_payload`` is given; the last stored user
                message is replayed instead.
            emit: Callable ``emit(event_name, payload)`` used to push events.
            cancel_event: When set, the final ``chat_end`` event is suppressed.
            trace: Trace object forwarded to ``_stream_reply``; may be falsy.
            tool_validation_payload: Optional dict carrying a
                ``"toolValidationResponse"`` answering a pending tool
                validation request on the last stored message.

        Returns:
            None. Results are delivered through ``emit`` and the store.
        """
        logger.debug("_stream_draft: conv=%s agents=%s user_msg=%r", conv_id, agent_ids, user_msg)
        conv = self.store.get_draft_conversation(agent_ids[0])
        if not conv:
            conv = self._blank_conversation(conv_id, agent_ids).model_dump(mode="json")
        else:
            conv = conv.convo
        if tool_validation_payload:
            stored_messages = conv["messages"]
            last_msg = stored_messages[-1] if stored_messages else None
            if not last_msg or not last_msg.get("toolValidationRequests"):
                emit("chat_error", {"error": "No tool validation requests found in last message"})
                return

            new_responses = [tool_validation_payload.get("toolValidationResponse")]
            # `or []` also covers a key that is present but holds None.
            current_responses = last_msg.get("toolValidationResponses") or []
            all_tool_responses = current_responses + new_responses
            total_requests = last_msg.get("toolValidationRequests", [])

            if len(all_tool_responses) < len(total_requests):
                logger.info(
                    "Not all tool validation responses received yet (%d/%d), updating message and waiting for more",
                    len(all_tool_responses),
                    len(total_requests),
                )
                last_msg["toolValidationResponses"] = all_tool_responses
                self.store.upsert_draft_conversation(agent_ids[0], conv)
                return

            last_msg["toolValidationResponses"] = all_tool_responses
            history = self._build_history(agent_ids, conv.get("messages", []))
            history.append(get_memory_fragment_msg(last_msg.get("memoryFragment") or {}))
            history.append(get_tool_validation_requests_msg(last_msg.get("toolValidationRequests") or {}))
            history.append(get_tool_validation_responses_msg(all_tool_responses))
            # Replay the most recent user message as the query for this resumed turn.
            user_msg = next(
                (msg.get("content", "") for msg in reversed(conv.get("messages", [])) if msg.get("role") == "user"),
                "",
            )
        else:
            user_msg_id = str(uuid.uuid4())
            conv["messages"].append({"id": user_msg_id, "role": "user", "content": user_msg, "hasEventLog": False})

            history = self._build_history(agent_ids, conv.get("messages", []))

        # Agents missing from _user_agents_dict get a None name instead of
        # crashing on `{}.name`.
        sel_agents = [
            {
                "agentId": aid,
                "query": user_msg,
                "agentName": getattr(self._user_agents_dict.get(str(aid)), "name", None),
            }
            for aid in agent_ids
        ]

        events: list[dict] = []
        final_reply, artifacts_meta, tool_validation_requests, memory_fragment, _context_upsert = self._stream_reply(
            conv_id,
            sel_agents,
            history,
            emit,
            events,
            cancel_event,
            trace,
            # TODO (clement): Currently we don't handle context upsert in draft mode.
            # Probably not a big issue since we don't build structured agents as Quick Agents
            {},
        )
        # TODO maybe could optimize and return these directly instead of recomputing them
        actions = build_actions(events)
        used_ids = get_used_agent_ids(events)
        conv["messages"].append(
            {
                "id": str(uuid.uuid4()),
                "role": "assistant",
                "content": final_reply,
                "eventLog": events,
                "actions": actions,
                "artifactsMetadata": artifacts_meta,
                "toolValidationRequests": tool_validation_requests,
                "memoryFragment": memory_fragment,
                "selectedAgentIds": agent_ids,
                "usedAgentIds": used_ids,
                "trace": trace.to_dict() if trace else {},
            }
        )
        self.store.upsert_draft_conversation(agent_ids[0], conv)

        if not cancel_event.is_set():
            emit(
                "chat_end",
                {
                    "agentIds": agent_ids,
                    "conversationId": conv_id,
                    "title": conv["title"],
                    "messages": conv["messages"],
                    "hasEventLog": bool(events),
                },
            )

    # ---------- streaming low-level -----------------------------------
    def _stream_reply(
        self, convId, sel_agents, history, emit, events, cancel_event: threading.Event, trace, last_merged_context=None
    ):
        """
        Dispatch one reply to the right streamer and relay tokens/events.

        The streamer is chosen from the number of selected agents: none ->
        base LLM, one -> single agent, several -> agent-connect when
        ``agents_as_tools()`` is truthy, otherwise the orchestrator.

        Args:
            convId: Conversation id added to every emitted token/event.
            sel_agents: List of dicts; ``"agentId"`` is read from each.
            history: Message history handed to the streamer.
            emit: Callable ``emit(event_name, payload)``.
            events: List that stored events are appended to (mutated in place).
            cancel_event: Cancellation flag forwarded to the streamer.
            trace: Trace object forwarded to the streamer.
            last_merged_context: Optional context from a previous turn; keys
                from the freshly built agent context take precedence.

        Returns:
            Always a 5-tuple ``(final, artifacts_meta, tool_validation_requests,
            memory_fragment, context_upsert)``.
        """
        final = ""
        artifacts_meta = {}
        tool_validation_requests = []
        memory_fragment = {}
        generic_error = "Error during agent execution"
        context_upsert = {}

        def token_cb(tok):
            nonlocal final
            final += tok
            emit("chat_token", {"token": tok, "conversationId": convId})

        def push_event(ev_dict: dict, store_event: bool = True):
            ev_dict.update({"conversationId": convId})
            if store_event:
                events.append(ev_dict)
            emit("chat_event", ev_dict)

        sel_ids = [a.get("agentId") for a in sel_agents]
        logger.info(f"_stream_reply  called conv: {convId}, sel_ids: {sel_ids}")
        base_context = get_agent_context(g.get("authIdentifier"), convId)
        merged_context = {
            **(last_merged_context or {}),
            **(base_context or {}),
        }
        if not sel_ids:
            final, context_upsert = self._stream_base_llm(history, merged_context, token_cb, push_event, cancel_event, None, None, trace)
        elif len(sel_ids) == 1:
            final, context_upsert = self._stream_single(
                sel_agents[0],
                history,
                merged_context,
                token_cb,
                push_event,
                artifacts_meta,
                tool_validation_requests,
                memory_fragment,
                cancel_event,
                convId,
                trace,
            )
        else:
            if agents_as_tools():
                final, tool_validation_requests, memory_fragment, context_upsert = self._stream_agent_connect(
                    sel_ids,
                    history,
                    merged_context,
                    token_cb,
                    push_event,
                    artifacts_meta,
                    cancel_event,
                    convId,
                    trace,
                )
            else:
                context = self._get_agent_context_safe(
                    g.get("authIdentifier"),
                    convId,
                    pcb=push_event,
                    agent_name="Agent Hub",
                )
                if context is None:
                    token_cb(generic_error)
                    # Same 5-tuple shape as the normal return, so callers that
                    # unpack five values do not fail on this error path.
                    return generic_error, artifacts_meta, tool_validation_requests, memory_fragment, context_upsert
                final, tool_validation_requests, memory_fragment = OrchestratorService.stream_multiple_agents(
                    llm_id=get_user_base_llm(self.store),
                    sel_agents=sel_agents,
                    messages=history,
                    context=merged_context,
                    tcb=token_cb,
                    pcb=push_event,
                    artifacts_meta=artifacts_meta,
                    cancel_event=cancel_event,
                    store=self.store,
                    trace=trace,  # Pass trace to orchestrator
                )
        return final, artifacts_meta, tool_validation_requests, memory_fragment, context_upsert

    # ---- concrete streamers -----------------------------------------
    def _stream_base_llm(self, msgs, context: Dict, tcb, pcb, cancel_event: threading.Event, emit=None, conv_id=None, trace=None):
        """
        Stream a reply from a plain LLM (no agent selected).

        The Vision LLM is used when the extraction mode is
        ``ExtractionMode.PAGES_SCREENSHOTS`` and at least one message carries a
        document with snapshots; otherwise the user's base LLM is used.

        Args:
            msgs: List of message dicts (``"role"``, ``"content"``, optional
                ``"__documents__"``).
            context: Context attached to the completion when truthy.
            tcb: Token callback, called with each streamed text fragment.
            pcb: Event callback, called with event dicts.
            cancel_event: Streaming stops as soon as this is set.
            emit: Unused; kept for interface compatibility.
            conv_id: Unused; kept for interface compatibility.
            trace: Optional trace; footer traces are appended to it.

        Returns:
            ``(final, context_upsert)``: the accumulated text (or the generic
            error message on failure) and any ``contextUpsert`` dict collected
            from stream footers.
        """
        generic_error = "Error during llm call"
        final = ""
        context_upsert = {}
        try:
            # Check if we should use Vision LLM for screenshots:
            # does any message carry a document with snapshots?
            has_snapshots = any(
                doc.get("snapshots")
                for message in msgs
                for doc in (message.get("__documents__") or [])
            )

            # Use Vision LLM if extraction mode is pagesScreenshots and there are snapshots
            extraction_mode = get_extraction_mode()
            if extraction_mode == ExtractionMode.PAGES_SCREENSHOTS and has_snapshots:
                llm_id = get_conversation_vision_llm()
            else:
                llm_id = get_user_base_llm(self.store)

            comp = self.project.get_llm(llm_id).new_completion()
            folder = dataiku.Folder(get_uploads_managedfolder_id())
            inline_cache: Dict[str, str] = {}

            def load_inline_image(path):
                """Return the base64 image stored at *path*, or None if there is no path or the download fails."""
                if not path:
                    return None
                cached = inline_cache.get(path)
                if cached is not None:
                    return cached
                try:
                    with folder.get_download_stream(path) as stream:
                        image_bytes = stream.read()
                    encoded = base64.b64encode(image_bytes).decode("utf-8")
                except Exception as err:  # pragma: no cover
                    # Best effort: a missing snapshot must not abort the whole completion.
                    logger.warning("Failed to retrieve snapshot %s: %s", path, err)
                    return None
                inline_cache[path] = encoded
                return encoded

            def append_documents(documents: list[dict]) -> None:
                nonlocal comp
                for doc in documents:
                    doc_name = doc.get("name") or "Document"
                    text = (doc.get("text") or "").strip()
                    snapshots = doc.get("snapshots") or []
                    # If text content exists, prefer a single user-role message with the content
                    if text:
                        comp = comp.with_message(f"[Document: {doc_name}]\n{text}", role="user")
                        continue
                    # If a single image snapshot, add filename then the image once
                    if len(snapshots) == 1:
                        inline_data = load_inline_image(snapshots[0].get("screenshot_path"))
                        msg = comp.new_multipart_message()
                        msg.with_text(f"[Image: {doc_name}]")
                        if inline_data:
                            msg.with_inline_image(inline_data)
                        msg.add()
                        continue
                    # Default: multi-page (pdf/docx/pptx) → per-page captions and images
                    if snapshots:
                        msg = comp.new_multipart_message()
                        for idx, snap in enumerate(snapshots, start=1):
                            msg.with_text(f"Document {doc_name}, page {snap.get('page') or idx}")
                            inline_data = load_inline_image(snap.get("screenshot_path"))
                            if inline_data:
                                msg.with_inline_image(inline_data)
                        msg.add()

            for message in msgs:
                comp = comp.with_message(message.get("content", ""), role=message.get("role"))
                docs = message.get("__documents__") or []
                if docs:
                    append_documents(docs)

            if context:
                comp = comp.with_context(context)

            from backend.utils.logging_utils import sanitize_messages_for_log

            logger.info("Streaming using plain LLM - prompt messages: %s", sanitize_messages_for_log(msgs))
            # Emit thinking event before starting stream (this will be stored in events list via pcb)
            pcb({"eventKind": EventKind.AGENT_THINKING, "eventData": {"agentName": "Agent Hub"}})
            # Track if we've emitted the responding event (only on first text token)
            responding_emitted = False

            # wrap in closing(...) so that .close() (and thus HTTP teardown) happens on break
            with closing(comp.execute_streamed()) as stream:
                for chunk in stream:
                    if cancel_event.is_set():
                        break
                    if isinstance(chunk, DSSLLMStreamedCompletionFooter):
                        trace_data = chunk.trace
                        if trace_data and trace is not None:
                            trace.append_trace(trace_data)

                        if chunk.data.get("contextUpsert") and isinstance(chunk.data["contextUpsert"], dict):
                            context_upsert.update(chunk.data["contextUpsert"])

                        continue
                    data = chunk.data
                    if "text" in data:
                        # Emit responding event on first text token (replaces thinking)
                        # This must happen BEFORE token callback to ensure event is processed first
                        if not responding_emitted:
                            # Emit ANSWER_STREAM_START event (will be stored via pcb)
                            pcb({"eventKind": EventKind.ANSWER_STREAM_START, "eventData": {"agentName": "Agent Hub"}})
                            responding_emitted = True
                            logger.info("Emitted ANSWER_STREAM_START event for Agent Hub")
                        tcb(data["text"])
                        final += data["text"]
                    elif data.get("type") == "event":
                        pcb({"eventKind": data["eventKind"], "eventData": data.get("eventData", {})})
        except Exception as e:
            # Catch any exceptions during streaming and emit error message as text and error event
            error_message = f"{generic_error}: {extract_error_message(str(e))}"
            logger.exception(f"Exception in _stream_base_llm: {error_message}")
            # Emit error message as text so it appears in the assistant response
            tcb(generic_error)
            # Emit BASE_LLM_ERROR event for the event log (similar to how user_agent emits AGENT_ERROR)
            pcb(
                {
                    "eventKind": EventKind.BASE_LLM_ERROR,
                    "eventData": {
                        "message": error_message,
                        "agentName": "Agent Hub",
                    },
                }
            )
            final = generic_error
        return final, context_upsert

    def _stream_single(
        self,
        agent_info,
        msgs,
        context: Dict,
        tcb,
        pcb,
        artifacts_meta,
        tool_validation_requests,
        memory_fragment,
        cancel_event: threading.Event,
        conv_id,
        trace,
    ):
        """
        Stream a reply from exactly one agent.

        If the agent id is found in ``self._user_agents_dict`` it is run as a
        user agent (after a publication check and a guardrails check);
        otherwise it is called as an Enterprise Agent.

        Args:
            agent_info: Dict with ``"agentId"``, ``"agentName"`` and optional
                ``"query"``.
            msgs: Message history sent to the agent.
            context: Merged context; forwarded on the Enterprise Agent path.
            tcb: Token callback.
            pcb: Event callback.
            artifacts_meta, tool_validation_requests, memory_fragment:
                Containers handed to ``normalise_stream_event`` (presumably
                filled in place; confirm against that helper).
            cancel_event: Streaming stops as soon as this is set.
            conv_id: Conversation id.
            trace: Optional trace; a new one is created for user agents when falsy.

        Returns:
            Always a 2-tuple ``(final, context_upsert)``, including on the
            early-exit paths (unpublished agent, guardrails block).
        """
        generic_error = "Error during agent execution"
        final = ""
        context_upsert = {}
        aid = agent_info.get("agentId")
        aname = agent_info.get("agentName")
        ua = self._user_agents_dict.get(str(aid))
        user = g.get("authIdentifier")
        agent_query = agent_info.get("query", "")

        def handle_event(ev, stream_id):
            """Collect any context upsert from the footer, relay the event, and accumulate streamed text."""
            nonlocal final
            footer = ev.get("footer") or {}
            upsert = footer.get("contextUpsert")
            if upsert and isinstance(upsert, dict):
                context_upsert.update(upsert)
            normalise_stream_event(
                ev=ev,
                tcb=tcb,
                pcb=pcb,
                msgs=msgs,
                aid=aid,
                trace=trace,
                aname=aname,
                query=agent_query,
                artifacts_meta=artifacts_meta,
                tool_validation_requests=tool_validation_requests,
                memory_fragment=memory_fragment,
                stream_id=stream_id,
            )
            if "chunk" in ev and "text" in ev["chunk"]:
                final += ev["chunk"]["text"]

        if ua:
            logger.info(f"Calling user agent {aid} {aname}")
            pcb({"eventKind": EventKind.CALLING_AGENT, "eventData": {"agentId": aid, "agentName": aname}})
            is_owner = ua.owner == user

            # Owners in draft mode run the working copy; everyone else needs a published version.
            if self.draft_mode and is_owner:
                use_published = False
            else:
                if not ua.published_version:
                    not_published_msg = "This agent has not been published yet."
                    tcb(not_published_msg)
                    # Keep the (final, context_upsert) shape the caller unpacks.
                    return not_published_msg, context_upsert
                use_published = True

            # Check guardrails for agent documents at runtime
            # This re-validates documents in case the guardrails pattern has changed
            guardrails_result = self._enforce_guardrails([aid], pcb)
            if guardrails_result["blocked_agents"]:
                error_msg = self._emit_guardrails_block_events(guardrails_result["blocked_agents"], pcb, tcb)
                # Keep the (final, context_upsert) shape the caller unpacks.
                return error_msg, context_upsert

            ua_obj = UserAgent(ua, use_published=use_published)
            from dataiku.llm.tracing import new_trace

            async def _run():
                nonlocal final
                stream_id = str(uuid.uuid4())
                # NOTE(review): this path sends a freshly built context and does not
                # forward the merged `context` argument, unlike the Enterprise Agent
                # branch below — confirm whether that is intended.
                agent_context = self._get_agent_context_safe(user, conv_id, pcb=pcb, agent_id=aid, agent_name=aname)
                if agent_context is None:
                    tcb(generic_error)
                    final = generic_error
                    return
                async for ev in ua_obj.aprocess_stream(
                    query={"messages": msgs, "context": agent_context},
                    settings={},
                    trace=trace or new_trace(aid),
                ):
                    if cancel_event.is_set():
                        break
                    handle_event(ev, stream_id)

            asyncio.run(_run())
        else:
            logger.info(f"Calling Enterprise Agent: {aid}")
            pcb({"eventKind": EventKind.CALLING_AGENT, "eventData": {"agentId": aid, "agentName": aname}})
            stream_id = str(uuid.uuid4())
            for ev in call_dss_agent_full_conversation(aid, aname, msgs, user, True, conv_id=conv_id, trace=trace, context=context):
                if cancel_event.is_set():
                    break
                handle_event(ev, stream_id)
        return final, context_upsert

    def _stream_agent_connect(
        self,
        aids,
        msgs,
        context: Dict,
        tcb,
        pcb,
        artifacts_meta,
        cancel_event: threading.Event,
        conv_id,
        trace,
    ):
        """
        Stream a reply from several agents through an agent-connect object.

        Agents that violate guardrails are removed before the call. If every
        agent is blocked, block events are emitted and the error message is
        returned instead of a reply.

        Args:
            aids: Ids of the selected agents.
            msgs: Message history sent in the query.
            context: Merged context sent in the query.
            tcb: Token callback.
            pcb: Event callback.
            artifacts_meta: Container forwarded to the stream and to
                ``normalise_stream_event``.
            cancel_event: Streaming stops as soon as this is set.
            conv_id: Conversation id.
            trace: Optional trace; one is created and begun here when falsy.
                The trace is ended when streaming finishes.

        Returns:
            Always a 4-tuple ``(final, tool_validation_reqs, mem_frag,
            context_upsert)``, including when every agent is blocked.
        """
        final = ""
        tool_validation_reqs = []
        mem_frag = {}
        context_upsert = {}
        from dataiku.llm.tracing import new_trace

        user = g.get("authIdentifier")
        generic_error = "Error during agent execution"

        # Check guardrails for all user agents with documents before starting
        # In multi-agent mode, we filter out violating agents instead of blocking entirely
        guardrails_result = self._enforce_guardrails(aids, pcb)
        blocked_agent_ids = guardrails_result["blocked_agent_ids"]
        blocked_agents = guardrails_result["blocked_agents"]

        if blocked_agents:
            # Filter out blocked agents
            filtered_aids = [aid for aid in aids if aid not in blocked_agent_ids]

            if not filtered_aids:
                # All agents are blocked - emit block events and return error.
                # Same 4-tuple shape as the normal return so the caller's unpacking works.
                error_msg = self._emit_guardrails_block_events(blocked_agents, pcb, tcb)
                return error_msg, [], {}, context_upsert

            # Some agents remain - emit filter events and continue with remaining agents
            emit_guardrails_filter_events(blocked_agents, pcb)
            aids = filtered_aids

        agent_connect = build_agent_connect(
            self.store, aids, user_agents=self._user_agents_dict, draft_mode=self.draft_mode, conv_id=conv_id
        )

        if not trace:
            trace = new_trace("DKU_AGENT_HUB_QUERY")
            trace.begin(int(time.time() * 1000))
        stream_id = str(uuid.uuid4())
        try:
            for ev in agent_connect.process_stream(
                query={"messages": msgs, "context": context},
                settings={},
                artifacts_meta=artifacts_meta,
                trace=trace,
                pcb=pcb,
            ):
                if cancel_event.is_set():
                    break
                footer = ev.get("footer") or {}
                upsert = footer.get("contextUpsert")
                if upsert and isinstance(upsert, dict):
                    context_upsert.update(upsert)
                normalise_stream_event(
                    ev=ev,
                    tcb=tcb,
                    pcb=pcb,
                    msgs=msgs,
                    stream_id=stream_id,
                    artifacts_meta=artifacts_meta,
                    tool_validation_requests=tool_validation_reqs,
                    memory_fragment=mem_frag,
                    aname="Agent Hub",
                )
                if "chunk" in ev and "text" in ev["chunk"]:
                    final += ev["chunk"]["text"]
        except Exception as e:
            logger.exception("Error in _stream_agent_connect for conversation %s: %s", conv_id, str(e))
            # Only replace the reply with the generic error if nothing was streamed yet.
            if not final:
                tcb(generic_error)
                final = generic_error
        finally:
            if trace:
                trace.end(int(time.time() * 1000))

        return final, tool_validation_reqs, mem_frag, context_upsert

    # ------------------------------------------------------------------ #
    #  Utility helpers
    # ------------------------------------------------------------------ #
    def _get_agent_context_safe(
        self,
        user: str,
        conv_id: str,
        *,
        pcb=None,
        agent_id: str | None = None,
        agent_name: str | None = None,
    ) -> dict | None:
        """
        Build the agent context without letting a failure propagate.

        On success the result of ``get_agent_context(user, conv_id)`` is
        returned unchanged. On any exception the failure is logged, an
        ``AGENT_ERROR`` event is sent through ``pcb`` (when one is given), and
        ``None`` is returned.

        Args:
            user: User identifier passed to ``get_agent_context``.
            conv_id: Conversation identifier passed to ``get_agent_context``.
            pcb: Optional event callback used to report the failure.
            agent_id: Agent id reported in the error event ("" when missing).
            agent_name: Agent name reported in the error event
                ("Agent Hub" when missing).
        """
        try:
            built_context = get_agent_context(user, conv_id)
        except Exception as exc:
            error_message = f"Failed to build agent context: {extract_error_message(str(exc))}"
            logger.exception("Failed to build agent context for conversation %s", conv_id)
            if not pcb:
                return None
            failure_details = {
                "message": error_message,
                "agentId": agent_id or "",
                "agentName": agent_name or "Agent Hub",
            }
            pcb({"eventKind": EventKind.AGENT_ERROR, "eventData": failure_details})
            return None
        else:
            return built_context

    def _filter_active_agents(self, agent_ids: list[str], allow_draft: bool = False) -> list[str]:
        """
        Return the subset of ``agent_ids`` that can be used, in input order.

        An id is kept when it is an enterprise agent available to the current
        user, or when it is a known user agent that either has a published
        version or is owned by the current user while ``allow_draft`` is true.
        Any other id is dropped.
        """
        viewer = g.get("authIdentifier")
        enterprise_ids = {entry["id"] for entry in get_enterprise_agents(viewer)}

        usable = []
        for candidate in agent_ids:
            # Enterprise agents are always usable.
            if candidate in enterprise_ids:
                usable.append(candidate)
                continue
            # Ids that are neither enterprise nor user agents are ignored.
            if candidate not in self._user_agents_dict:
                continue
            record = self._user_agents_dict[candidate]
            is_owner = record.owner == viewer
            # Owners may use their draft; everyone else needs a published version.
            if (allow_draft and is_owner) or record.published_version:
                usable.append(candidate)

        return usable

    def _process_conversation_guardrails(self, conv_id: str, emit, events: list[dict]) -> None:
        """
        Run guardrails over every document derived for a conversation.

        Does nothing unless guardrails are enabled with a pattern and the
        conversation has at least one derived document. Progress and outcome
        are reported as ``chat_event`` emissions; the "checks", "violation" and
        "extraction failed" events are also appended to ``events``, while the
        per-document status event is emitted only. Any exception raised while
        checking is logged and swallowed.
        """
        if not get_guardrails_enabled() or not get_guardrails_pattern():
            return

        # Every document is re-checked (not only new attachments): the pattern may
        # have changed since the last run, and the guardrails service caches
        # results when it has not.
        derived_docs = self.store.get_derived_documents(conv_id) or []
        all_attachments_for_guardrails = []
        for doc in derived_docs:
            all_attachments_for_guardrails.append(
                {
                    "document_path": doc.document_path,
                    "document_name": doc.document_name,
                }
            )

        if not all_attachments_for_guardrails:
            return

        def publish(event: dict, *, persist: bool = True) -> None:
            """Emit *event* tagged with the conversation id; optionally keep it in the event log."""
            emit("chat_event", {**event, "conversationId": conv_id})
            if persist:
                events.append(event)

        try:
            count = len(all_attachments_for_guardrails)
            doc_word = "document" if count == 1 else "documents"
            publish(
                {
                    "eventKind": EventKind.GUARDRAILS_CHECKS,
                    "eventData": {
                        "message": f"Applying guardrails checks to {count} {doc_word}.",
                    },
                }
            )

            outcome = process_documents_for_guardrails(
                self.store,
                conv_id,
                all_attachments_for_guardrails,
            )
            content_violations = outcome.get("content_violations", [])
            extraction_failures = outcome.get("extraction_failures", [])
            checked_docs = outcome.get("all_checked", [])

            # Per-document statuses are for live frontend updates only, so they
            # are not stored in the event log.
            if checked_docs:
                publish(
                    {
                        "eventKind": EventKind.GUARDRAILS_CHECKED,
                        "eventData": {
                            "documents": checked_docs,
                        },
                    },
                    persist=False,
                )

            # Content violations and extraction failures are reported separately.
            if content_violations:
                violation_names = [d.get("name", "Unknown") for d in content_violations]
                publish(
                    {
                        "eventKind": EventKind.GUARDRAILS_VIOLATION,
                        "eventData": {
                            "failed_documents": violation_names,
                            "message": f"Documents filtered due to content restrictions: {', '.join(violation_names)}",
                        },
                    }
                )
                logger.warning(
                    f"Guardrails filtered {len(content_violations)} documents for conversation {conv_id}: {violation_names}"
                )

            if extraction_failures:
                extraction_failure_names = [d.get("name", "Unknown") for d in extraction_failures]
                publish(
                    {
                        "eventKind": EventKind.GUARDRAILS_EXTRACTION_FAILED,
                        "eventData": {
                            "failed_documents": extraction_failure_names,
                            "message": f"Documents excluded due to text extraction failure: {', '.join(extraction_failure_names)}",
                        },
                    }
                )
                logger.warning(
                    f"Text extraction failed for {len(extraction_failures)} documents in conversation {conv_id}: {extraction_failure_names}"
                )
        except Exception:
            logger.exception("Failed to process guardrails for conversation %s", conv_id)

    def _enforce_guardrails(self, agent_ids: list[str], pcb) -> dict:
        """
        Validates that the specified agents do not violate guardrails.

        Nothing is checked, and an empty result is returned, unless guardrails
        are enabled and a guardrails pattern is configured.

        Args:
            agent_ids: IDs of the agents to validate. IDs without an entry in
                ``self._user_agents_dict`` are skipped.
            pcb: Push callback, forwarded to ``_check_agent_runtime_guardrails``.

        Returns a dict with:
            - blocked_agent_ids: set of agent IDs that have violations
            - blocked_agents: list of dicts with ``agentId``, ``agentName``,
              ``content_violations`` and ``extraction_failures`` (the last two
              are lists of document names)

        The caller is responsible for deciding whether to block or filter,
        and for emitting appropriate events.
        """
        result = {"blocked_agent_ids": set(), "blocked_agents": []}

        if not (get_guardrails_enabled() and get_guardrails_pattern()):
            return result

        def _doc_names(docs) -> list[str]:
            # The names are later concatenated with str.join, which rejects None;
            # fall back to "Unknown" like the other guardrails events in this file.
            return [d.get("name") or "Unknown" for d in (docs or [])]

        user = g.get("authIdentifier")

        for aid in agent_ids:
            ua = self._user_agents_dict.get(str(aid))
            # Only user agents have documents to check here; Enterprise agents are skipped or have no ua entry
            if not ua:
                continue

            is_owner = ua.owner == user
            # Standard logic: Use published unless in draft mode AND owner
            use_published = not (self.draft_mode and is_owner)

            if use_published:
                published = ua.published_version
                docs_to_check = (published.documents or []) if published else []
            else:
                docs_to_check = ua.documents or []

            if not docs_to_check:
                continue

            check_result = self._check_agent_runtime_guardrails(aid, docs_to_check, use_published, pcb)
            if not check_result.get("has_violations"):
                continue

            result["blocked_agent_ids"].add(aid)
            result["blocked_agents"].append(
                {
                    "agentId": aid,
                    "agentName": ua.name,
                    "content_violations": _doc_names(check_result.get("content_violations")),
                    "extraction_failures": _doc_names(check_result.get("extraction_failures")),
                }
            )

        return result

    def _emit_guardrails_block_events(self, blocked_agents: list[dict], pcb, tcb=None) -> str:
        """
        Emit guardrails violation/extraction failure events for blocking scenario.
        Returns the combined error message.
        """
        if len(blocked_agents) == 1:
            # Single agent: use single-agent event structure
            ba = blocked_agents[0]
            content_violations = ba["content_violations"]
            extraction_failures = ba["extraction_failures"]
            messages = []

            if content_violations:
                msg = f"Cannot use agent '{ba['agentName']}': some knowledge base documents are blocked by content policy: {', '.join(content_violations)}"
                messages.append(msg)
                evt = {
                    "eventKind": EventKind.GUARDRAILS_VIOLATION,
                    "eventData": {
                        "failed_documents": content_violations,
                        "message": msg,
                        "agentId": ba["agentId"],
                        "agentName": ba["agentName"],
                    },
                }
                pcb(evt)

            if extraction_failures:
                msg = f"Cannot use agent '{ba['agentName']}': text extraction failed for some knowledge base documents: {', '.join(extraction_failures)}"
                messages.append(msg)
                evt = {
                    "eventKind": EventKind.GUARDRAILS_EXTRACTION_FAILED,
                    "eventData": {
                        "failed_documents": extraction_failures,
                        "message": msg,
                        "agentId": ba["agentId"],
                        "agentName": ba["agentName"],
                    },
                }
                pcb(evt)

            combined_msg = " ".join(messages)
            if tcb:
                tcb(combined_msg)
            return combined_msg
        else:
            # Multi-agent event structure
            content_violation_parts = []
            extraction_failure_parts = []

            for ba in blocked_agents:
                agent_name = ba["agentName"]
                content_violations = ba.get("content_violations", [])
                extraction_failures = ba.get("extraction_failures", [])

                if content_violations:
                    content_violation_parts.append(f"'{agent_name}' ({', '.join(content_violations)})")
                if extraction_failures:
                    extraction_failure_parts.append(f"'{agent_name}' ({', '.join(extraction_failures)})")

            messages = []

            # Emit event for content violations if any
            if content_violation_parts:
                msg = f"Cannot proceed: some agents have knowledge base documents blocked by content policy: {'; '.join(content_violation_parts)}"
                messages.append(msg)
                evt = {
                    "eventKind": EventKind.GUARDRAILS_VIOLATION,
                    "eventData": {
                        "blocked_agents": blocked_agents,
                        "message": msg,
                    },
                }
                pcb(evt)

            # Emit event for extraction failures if any
            if extraction_failure_parts:
                msg = f"Cannot proceed: text extraction failed for knowledge base documents in some agents: {'; '.join(extraction_failure_parts)}"
                messages.append(msg)
                evt = {
                    "eventKind": EventKind.GUARDRAILS_EXTRACTION_FAILED,
                    "eventData": {
                        "blocked_agents": blocked_agents,
                        "message": msg,
                    },
                }
                pcb(evt)

            combined_msg = " ".join(messages)
            if tcb:
                tcb(combined_msg)
            return combined_msg

    def _check_agent_runtime_guardrails(
        self,
        agent_id: str,
        documents: list,
        use_published: bool,
        pcb,
    ) -> dict:
        """
        Run the runtime guardrails check on an agent's documents.

        The documents are re-validated at call time in case the guardrails
        pattern has changed since they were indexed. Any exception raised
        along the way is logged and the agent is allowed to proceed (fail open).

        Args:
            agent_id: The agent ID
            documents: List of document dicts
            use_published: Whether to use published zone
            pcb: Push callback for events

        Returns:
            Dict with 'content_violations', 'extraction_failures', 'passed', and 'has_violations' keys
        """
        from backend.services.guardrails_service import check_agent_guardrails_at_runtime

        try:
            # Announce the check only when at least one document is active and not pending deletion
            checkable_count = sum(1 for doc in documents if doc.get("active") and not doc.get("deletePending"))
            if checkable_count:
                announcement = {
                    "eventKind": EventKind.GUARDRAILS_CHECKS,
                    "eventData": {
                        "message": f"Checking {checkable_count} knowledge base document(s) against content policy.",
                    },
                }
                pcb(announcement, store_event=True)

            outcome = check_agent_guardrails_at_runtime(agent_id, documents, use_published)
            if not outcome.get("has_violations"):
                return outcome

            policy_count = len(outcome.get("content_violations", []))
            unreadable_count = len(outcome.get("extraction_failures", []))
            logger.warning(
                f"Runtime guardrails check failed for agent {agent_id}: "
                f"{policy_count + unreadable_count} documents blocked ({policy_count} content violations, {unreadable_count} extraction failures)"
            )
            return outcome
        except Exception as e:
            logger.exception(f"Error during runtime guardrails check for agent {agent_id}: {e}")
            # On error, allow the agent to proceed (fail open)
            return {"content_violations": [], "extraction_failures": [], "passed": documents, "has_violations": False}

    @staticmethod
    def _clean_title(raw: str) -> str:
        s = (raw or "").strip()
        s = re.sub(r"```[\w\-]*\s*([\s\S]*?)\s*```", r"\1", s).strip()

        for line in s.splitlines():
            if line.strip():
                s = line.strip()
                break
        else:
            s = ""

        s = re.sub(r"^\s*#{1,6}\s*", "", s)  # '#', '##', etc.
        s = re.sub(r"^\s*>+\s*", "", s)  # '>' or '>>'

        s = re.sub(
            r"^\s*(?:title|subject|chat|conversation)\s*[:：\-–]\s*",
            "",
            s,
            flags=re.IGNORECASE,
        )

        s = s.strip(" \"'`*_")

        s = re.sub(r"\s+", " ", s)

        s = re.sub(r"[.:;!?，。；！]+$", "", s)

        return s or "Untitled"
