import asyncio
import datetime
import json
import re
import threading
import time
import uuid
from contextlib import closing
from typing import List

import dataiku
from flask import g, request

from backend.agents.user_agent import UserAgent
from backend.config import agents_as_tools, get_enterprise_agents, get_global_system_prompt
from backend.models.events import EventKind
from backend.services.orchestrator_service import OrchestratorService
from backend.utils.conv_utils import normalise_stream_event
from backend.utils.events_utils import build_actions, get_used_agent_ids
from backend.utils.logging_utils import get_logger
from backend.utils.user_utils import get_agent_context
from backend.utils.utils import (
    build_agent_connect,
    call_dss_agent_full_conversation,
    get_user_base_llm,
    select_agents,
)

# Module-level logger, obtained from the project's logging helper for this module's name.
logger = get_logger(__name__)


class ConversationError(Exception):
    """Signal that a request payload is invalid or that a business rule was violated."""


class ConversationService:
    """Run chat turns: validate requests, stream replies and persist conversations.

    A reply is produced by the user's base LLM (no agent selected), by a single
    user or enterprise agent, or by several agents (agent-connect when
    ``agents_as_tools()`` is enabled, the orchestrator otherwise). Tokens and
    events are pushed to the client through the ``emit`` callback given to
    :meth:`stream_message`.
    """

    # Patterns used by `_clean_title`, compiled once.
    _CODE_FENCE_RE = re.compile(r"```[\w\-]*\s*([\s\S]*?)\s*```")
    _HEADING_RE = re.compile(r"^\s*#{1,6}\s*")  # '#', '##', etc.
    _BLOCKQUOTE_RE = re.compile(r"^\s*>+\s*")  # '>' or '>>'
    _LABEL_PREFIX_RE = re.compile(r"^\s*(?:title|subject|chat|conversation)\s*[:：\-–]\s*", re.IGNORECASE)
    _WHITESPACE_RE = re.compile(r"\s+")
    _TRAILING_PUNCT_RE = re.compile(r"[.:;!?，。；！]+$")
    # Characters trimmed from both ends of a title: space, quotes, backtick, markdown emphasis.
    _TITLE_WRAPPER_CHARS = " \"'`*_"

    def __init__(self, store, draft_mode: bool = False):
        """Bind the service to a conversation store.

        Args:
            store: Persistence backend (conversations, messages, drafts, agents).
            draft_mode: When True, `_stream_single` lets an agent's owner run its
                unpublished version; it is also forwarded to `build_agent_connect`.
        """
        self.store = store
        self.draft_mode = draft_mode  # True only in test contexts
        self.project = dataiku.api_client().get_default_project()
        # Snapshot of the user agents, read once per service instance.
        self._user_agents_list = self.store.get_all_agents()
        self._user_agents_dict = {a["id"]: a for a in self._user_agents_list}

    def _missing_agents(self, agent_ids: List[str], user: str) -> List[str]:
        """Return the ids from *agent_ids* that are neither enterprise nor user agents.

        Enterprise agents are those returned by ``get_enterprise_agents(user)``;
        user agents come from the cached snapshot. Input order is preserved.
        """
        existing = {e["id"] for e in get_enterprise_agents(user)}
        existing.update(self._user_agents_dict.keys())
        return [aid for aid in agent_ids if aid not in existing]

    @staticmethod
    def _blank_conversation(conv_id: str, agent_ids: List[str]) -> dict:
        """Return a new, untitled conversation dict with no messages."""
        return {"id": conv_id, "title": None, "messages": [], "agentIds": agent_ids}

    def _generate_title(self, messages: list[dict], trace=None) -> str:
        """Ask the user's base LLM for a short title summarising *messages*.

        Args:
            messages: Conversation messages, each with ``role`` and ``content``.
            trace: Optional trace. When given, a ``generate_title`` subspan records
                the inputs, the raw and cleaned titles and the LLM trace. The
                subspan is ended even if the LLM call raises.

        Returns:
            The title after `_clean_title` ("Untitled" when nothing usable remains).
        """
        title_span = None
        if trace:
            title_span = trace.subspan("generate_title")
            title_span.begin(int(time.time() * 1000))

        try:
            today = datetime.date.today().strftime("%Y-%m-%d")
            llm_id = get_user_base_llm(self.store)

            if title_span:
                title_span.inputs["llm_id"] = llm_id
                title_span.inputs["message_count"] = len(messages)
                title_span.inputs["date"] = today

            system_msg = (
                "You are a titling assistant. "
                f"Today is {today}. "
                "Return a single short title (≤7 words) that summarizes the conversation. "
                "Plain text ONLY: no prefixes (e.g., 'Title:'), no quotes, no emojis, "
                "no code fences, no extra lines, and no trailing punctuation."
            )
            user_msg = "Generate a short title for this conversation. Output only the title text."

            comp = self.project.get_llm(llm_id).new_completion()
            comp.with_message(system_msg, role="system")
            for m in messages:
                comp = comp.with_message(m["content"], role=m["role"])
            # The transcript usually ends with the assistant's reply. Putting the
            # titling request last makes the model answer it rather than continue
            # that reply.
            comp = comp.with_message(user_msg, role="user")

            logger.info(
                "Generating conversation title:\nLLM id=[%s]\nCompletion Query=%s\nCompletion Settings=%s\n",
                llm_id,
                json.dumps(comp.cq, indent=2, sort_keys=True),
                json.dumps(comp.settings, indent=2, sort_keys=True),
            )

            resp = comp.execute()
            raw_title = resp.text or ""
            logger.info(
                "Conversation title LLM response:\nLLM id=[%s]\nCompletion Response=%s",
                llm_id,
                raw_title,
            )

            if title_span and resp.trace:
                title_span.append_trace(resp.trace)

            cleaned_title = self._clean_title(raw_title)

            if title_span:
                title_span.outputs["raw_title"] = raw_title
                title_span.outputs["cleaned_title"] = cleaned_title
            return cleaned_title
        finally:
            if title_span:
                title_span.end(int(time.time() * 1000))

    def _build_history(self, agent_ids: list[str], conv_messages: list[dict]) -> list[dict]:
        """Return the prompt messages for this turn (a new list).

        * Exactly ONE agent that is a *user* agent: the messages are returned
          unchanged, with no system message (the agent adds its own).
        * Otherwise: a system message with today's date is prepended. The global
          system prompt is appended to it only when no agent is selected (plain
          LLM chat). One enterprise agent, or several agents, get just the date.
        """
        if len(agent_ids) == 1 and agent_ids[0] in self._user_agents_dict:
            return list(conv_messages)  # no global prompt

        today = datetime.date.today().strftime("%Y-%m-%d")
        content = f"Today is {today}."
        if not agent_ids:
            global_prompt = get_global_system_prompt()
            # Skip empty/None prompts so the literal text "None" is never sent.
            if global_prompt:
                content = f"{content} {global_prompt}"
        return [{"role": "system", "content": content}] + list(conv_messages)

    # ------------------------------------------------------------------ #
    #  Streaming send  (WebSocket)
    # ------------------------------------------------------------------ #
    def stream_message(self, payload: dict, emit, cancel_event: threading.Event):
        """Handle one chat turn received over the WebSocket.

        The request is validated first; on failure a ``chat_error`` is emitted.
        The reply is then streamed through *emit* as ``chat_token`` and
        ``chat_event``, followed by ``chat_end`` unless *cancel_event* is set.
        The user and assistant messages are persisted, along with the title and
        agent list when they change. The whole turn is recorded in a
        ``DKU_AGENT_HUB_QUERY`` trace, which is stored on the assistant message.

        Args:
            payload: Request with ``conversationId``, ``userMessage`` and optional
                ``agentIds`` / ``draftMode``.
            emit: Callback ``emit(event_name, data)`` towards the client.
            cancel_event: Set by the caller to stop streaming early.
        """
        logger.info("stream_message called from %s with Payload: %r", request.remote_addr, payload)
        from dataiku.llm.tracing import new_trace

        trace = new_trace("DKU_AGENT_HUB_QUERY")
        trace.begin(int(time.time() * 1000))

        assistant_msg_dict = None
        try:
            draft_mode = bool(payload.get("draftMode"))
            conv_id = payload.get("conversationId")
            user_msg = (payload.get("userMessage") or "").strip()
            agent_ids = self._filter_active_agents(payload.get("agentIds") or [], allow_draft=draft_mode)
            logger.info("_filter_active_agents called, %s", agent_ids)

            # -- validation -------------------------------------------------
            if not conv_id or not user_msg:
                emit("chat_error", {"error": "Missing conversationId or userMessage"})
                return
            # NOTE(review): `agent_ids` has already been filtered to known
            # enterprise/user agents above, so this check cannot report anything.
            # Unknown requested ids are dropped silently. Confirm that is intended.
            miss = self._missing_agents(agent_ids, g.get("authIdentifier"))
            if miss:
                emit("chat_error", {"error": f"Agent(s) not found: {', '.join(miss)}"})
                return

            # -- draft-mode path (legacy full-rewrite) ----------------------
            if draft_mode:
                self._stream_draft(conv_id, agent_ids, user_msg, emit, cancel_event, trace)
                return

            # -- non-draft incremental path ---------------------------------
            conv = self.store.get_conversation(conv_id) or {}
            if not conv:
                self.store.ensure_conversation_exists(conv_id, agent_ids)
                conv = self._blank_conversation(conv_id, agent_ids)

            old_agent_ids = list(conv.get("agentIds", []))

            user_msg_dict = {
                "id": str(uuid.uuid4()),
                "role": "user",
                "content": user_msg,
                "eventLog": [],
            }
            conv["messages"].append(user_msg_dict)

            # Select agents based on settings.
            sel_agents, justification = select_agents(self.store, agent_ids, conv["messages"], trace=trace)
            sel_ids = [a["agentId"] for a in sel_agents]
            events: list[dict] = []
            if justification and not agents_as_tools():
                evt = {
                    "eventKind": EventKind.AGENT_SELECTION,
                    "eventData": {"justification": justification, "selection": sel_agents},
                }
                emit("chat_event", evt)
                events.append(evt)
            elif sel_agents:
                # Recorded in the event log only; not emitted to the client.
                evt = {
                    "eventKind": EventKind.AGENT_SELECTION,
                    "eventData": {"justification": "User selected", "selection": sel_agents},
                }
                events.append(evt)

            history = self._build_history(agent_ids, conv["messages"])

            final_reply, artifacts_meta = self._stream_reply(
                conv_id, sel_agents, history, emit, events, cancel_event, trace
            )
            # TODO maybe could optimize and return these directly instead of recomputing them
            actions = build_actions(events)
            used_ids = get_used_agent_ids(events)
            assistant_msg_dict = {
                "id": str(uuid.uuid4()),
                "role": "assistant",
                "content": final_reply,
                "eventLog": events,
                "actions": actions,
                "artifactsMetadata": artifacts_meta,
                "selectedAgentIds": sel_ids,
                "usedAgentIds": used_ids,
            }
            conv["messages"].append(assistant_msg_dict)

            # Persist both messages of the turn in a single call.
            self.store.append_messages(conv_id, [user_msg_dict, assistant_msg_dict])

            # Title generation is best-effort. The turn is already stored, so a
            # failure is logged and the title stays unset, which means it is
            # retried on the next turn, instead of aborting before `chat_end`.
            title = conv.get("title")
            title_generated = False
            if title is None:
                try:
                    title = self._generate_title(conv["messages"], trace=trace)
                    title_generated = True
                except Exception:
                    logger.exception("Failed to generate a title for conversation %s", conv_id)
            conv["title"] = title

            if title_generated or old_agent_ids != agent_ids:
                self.store.update_conversation_meta(conv_id, title=title, agent_ids=agent_ids)
            if not cancel_event.is_set():
                emit(
                    "chat_end",
                    {
                        "agentIds": agent_ids,
                        "conversationId": conv_id,
                        "title": title,
                        "messages": conv["messages"],
                        "hasEventLog": bool(events),
                    },
                )
        finally:
            trace.end(int(time.time() * 1000))

            # Attach the finished trace to the stored assistant message, if one was built.
            if assistant_msg_dict is not None:
                assistant_msg_dict["trace"] = trace.to_dict() if trace else {}
                self.store.update_message(assistant_msg_dict["id"], {"trace": assistant_msg_dict["trace"]})

    # ------------------------------------------------------------------ #
    #  Draft-mode helper (legacy logic, no optimisation)
    # ------------------------------------------------------------------ #
    def _stream_draft(
        self, conv_id: str, agent_ids: list[str], user_msg: str, emit, cancel_event: threading.Event, trace
    ):
        """Run a draft-mode turn, keeping the original full-rewrite behaviour.

        The draft conversation is stored under the first agent id and rewritten
        in full on every turn. Each agent in *agent_ids* is passed to
        `_stream_reply` with the raw user message as its query. Emits
        ``chat_error`` and returns when *agent_ids* is empty.
        """
        logger.debug("_stream_draft: conv=%s agents=%s user_msg=%r", conv_id, agent_ids, user_msg)
        if not agent_ids:
            # The draft is keyed by its first agent id, so it cannot be stored without one.
            emit("chat_error", {"error": "Draft mode requires at least one active agent"})
            return

        draft_key = agent_ids[0]
        conv = self.store.get_draft_conversation(draft_key) or self._blank_conversation(conv_id, agent_ids)
        user_msg_id = str(uuid.uuid4())
        conv["messages"].append({"id": user_msg_id, "role": "user", "content": user_msg, "hasEventLog": False})

        history = self._build_history(agent_ids, conv["messages"])

        sel_agents = [
            {"agentId": aid, "query": user_msg, "agentName": self._user_agents_dict.get(str(aid), {}).get("name")}
            for aid in agent_ids
        ]
        events: list[dict] = []
        final_reply, artifacts_meta = self._stream_reply(
            conv_id, sel_agents, history, emit, events, cancel_event, trace
        )
        # TODO maybe could optimize and return these directly instead of recomputing them
        actions = build_actions(events)
        used_ids = get_used_agent_ids(events)
        conv["messages"].append(
            {
                "id": str(uuid.uuid4()),
                "role": "assistant",
                "content": final_reply,
                "eventLog": events,
                "actions": actions,
                "artifactsMetadata": artifacts_meta,
                "selectedAgentIds": agent_ids,
                "usedAgentIds": used_ids,
                # NOTE(review): serialised before `stream_message` ends the trace,
                # so the stored draft trace presumably lacks an end time. Confirm.
                "trace": trace.to_dict() if trace else {},
            }
        )
        # Same naive-UTC ISO string as the deprecated `datetime.utcnow().isoformat()`.
        conv["lastUpdated"] = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()
        self.store.upsert_draft_conversation(draft_key, conv)

        if not cancel_event.is_set():
            emit(
                "chat_end",
                {
                    "agentIds": agent_ids,
                    "conversationId": conv_id,
                    "title": conv.get("title"),
                    "messages": conv["messages"],
                    "hasEventLog": bool(events),
                },
            )

    # ---------- streaming low-level -----------------------------------
    def _stream_reply(self, convId, sel_agents, history, emit, events, cancel_event: threading.Event, trace):
        """Stream the assistant reply for *history*; return ``(text, artifacts_meta)``.

        Dispatch depends on how many agents are selected:

        * none: the user's base LLM;
        * one: `_stream_single`;
        * several: agent-connect when ``agents_as_tools()``, the orchestrator otherwise.

        Tokens are emitted as ``chat_token``. Events are emitted as ``chat_event``
        and also appended to *events*, unless pushed with ``store_event=False``.
        """
        artifacts_meta = {}

        def token_cb(tok):
            emit("chat_token", {"token": tok, "conversationId": convId})

        def push_event(ev_dict: dict, store_event: bool = True):
            ev_dict.update({"conversationId": convId})
            if store_event:
                events.append(ev_dict)
            emit("chat_event", ev_dict)

        sel_ids = [a.get("agentId") for a in sel_agents]
        logger.info("_stream_reply  called conv: %s, sel_ids: %s", convId, sel_ids)
        if not sel_ids:
            final = self._stream_base_llm(history, token_cb, push_event, cancel_event)
        elif len(sel_ids) == 1:
            final = self._stream_single(
                sel_agents[0], history, token_cb, push_event, artifacts_meta, cancel_event, convId, trace
            )
        elif agents_as_tools():
            final = self._stream_agent_connect(
                sel_ids, history, token_cb, push_event, artifacts_meta, cancel_event, convId, trace
            )
        else:
            context = get_agent_context(g.get("authIdentifier"), convId)
            final = OrchestratorService.stream_multiple_agents(
                llm_id=get_user_base_llm(self.store),
                sel_agents=sel_agents,
                messages=history,
                context=context,
                tcb=token_cb,
                pcb=push_event,
                artifacts_meta=artifacts_meta,
                cancel_event=cancel_event,
                store=self.store,
            )
        return final, artifacts_meta

    # ---- concrete streamers -----------------------------------------
    def _stream_base_llm(self, msgs, tcb, pcb, cancel_event: threading.Event):
        """Stream a completion of *msgs* from the user's base LLM; return its text.

        Text chunks go to *tcb*. Chunks of type ``event`` go to *pcb*.
        """
        final = ""
        comp = self.project.get_llm(get_user_base_llm(self.store)).new_completion()
        for m in msgs:
            comp = comp.with_message(m["content"], role=m["role"])
        logger.info("Streaming using plain LLM - prompt messages: %s", msgs)
        # wrap in closing(...) so that .close() (and thus HTTP teardown) happens on break
        with closing(comp.execute_streamed()) as stream:
            for chunk in stream:
                if cancel_event.is_set():
                    break
                data = chunk.data
                if "text" in data:
                    tcb(data["text"])
                    final += data["text"]
                elif data.get("type") == "event":
                    pcb({"eventKind": data["eventKind"], "eventData": data.get("eventData", {})})
        return final

    def _stream_single(self, agent_info, msgs, tcb, pcb, artifacts_meta, cancel_event: threading.Event, conv_id, trace):
        """Stream the reply of one agent and return its concatenated text.

        An id found in the cached user agents is run in-process through
        `UserAgent.aprocess_stream`. Its published version is used unless the
        service is in draft mode and the caller owns the agent; an unpublished
        agent yields a fixed notice. Any other id is called as an enterprise
        agent through `call_dss_agent_full_conversation`.
        """
        final = ""
        aid = agent_info.get("agentId")
        aname = agent_info.get("agentName")
        ua = self._user_agents_dict.get(str(aid))
        user = g.get("authIdentifier")
        agent_query = agent_info.get("query", "")

        if ua:
            logger.info("Calling user agent %s %s", aid, aname)
            pcb({"eventKind": EventKind.CALLING_AGENT, "eventData": {"agentId": aid, "agentName": aname}})
            is_owner = ua.get("owner") == user

            # Only an owner in draft mode may run the unpublished version.
            if self.draft_mode and is_owner:
                use_published = False
            else:
                if not ua.get("published_version"):
                    tcb("This agent has not been published yet.")
                    return "This agent has not been published yet."
                use_published = True

            ua_obj = UserAgent(ua, use_published=use_published)
            from dataiku.llm.tracing import new_trace

            async def _run():
                nonlocal final
                stream_id = str(uuid.uuid4())
                async for ev in ua_obj.aprocess_stream(
                    query={"messages": msgs, "context": get_agent_context(user, conv_id)},
                    settings={},
                    trace=trace or new_trace(aid),
                ):
                    if cancel_event.is_set():
                        break
                    normalise_stream_event(
                        ev=ev,
                        tcb=tcb,
                        pcb=pcb,
                        msgs=msgs,
                        aid=aid,
                        trace=trace,
                        aname=aname,
                        query=agent_query,
                        artifacts_meta=artifacts_meta,
                        stream_id=stream_id,
                    )
                    if "chunk" in ev and "text" in ev["chunk"]:
                        final += ev["chunk"]["text"]

            asyncio.run(_run())
        else:
            logger.info("Calling Enterprise Agent: %s", aid)
            pcb({"eventKind": EventKind.CALLING_AGENT, "eventData": {"agentId": aid, "agentName": aname}})
            stream_id = str(uuid.uuid4())
            stream = call_dss_agent_full_conversation(aid, msgs, user, True, conv_id=conv_id, trace=trace)
            try:
                for ev in stream:
                    if cancel_event.is_set():
                        break
                    normalise_stream_event(
                        ev=ev,
                        tcb=tcb,
                        pcb=pcb,
                        msgs=msgs,
                        aid=aid,
                        trace=trace,
                        aname=aname,
                        query=agent_query,
                        artifacts_meta=artifacts_meta,
                        stream_id=stream_id,
                    )
                    if "chunk" in ev and "text" in ev["chunk"]:
                        final += ev["chunk"]["text"]
            finally:
                # Release the stream promptly when we stop early (e.g. on cancel), as
                # `_stream_base_llm` does. This is a no-op if it has no `close`.
                close = getattr(stream, "close", None)
                if callable(close):
                    close()
        return final

    def _stream_agent_connect(
        self, aids, msgs, tcb, pcb, artifacts_meta, cancel_event: threading.Event, conv_id, trace
    ):
        """Stream a multi-agent reply through agent-connect; return its text.

        When *trace* is falsy, a fresh ``DKU_AGENT_HUB_QUERY`` trace is created and
        ended here. A trace supplied by the caller is left open for its owner to end.
        """
        from dataiku.llm.tracing import new_trace

        final = ""
        agent_connect = build_agent_connect(
            self.store, aids, user_agents=self._user_agents_dict, draft_mode=self.draft_mode, conv_id=conv_id
        )
        user = g.get("authIdentifier")

        owns_trace = not trace
        if owns_trace:
            trace = new_trace("DKU_AGENT_HUB_QUERY")
            trace.begin(int(time.time() * 1000))

        async def _run():
            nonlocal final
            stream_id = str(uuid.uuid4())
            async for ev in agent_connect.aprocess_stream(
                query={"messages": msgs, "context": get_agent_context(user, conv_id)},
                settings={},
                artifacts_meta=artifacts_meta,
                trace=trace,
                pcb=pcb,
            ):
                if cancel_event.is_set():
                    break
                normalise_stream_event(ev=ev, tcb=tcb, pcb=pcb, msgs=msgs, stream_id=stream_id)
                if "chunk" in ev and "text" in ev["chunk"]:
                    final += ev["chunk"]["text"]

        try:
            asyncio.run(_run())
        finally:
            if owns_trace:
                trace.end(int(time.time() * 1000))
        return final

    # ------------------------------------------------------------------ #
    #  Utility helpers
    # ------------------------------------------------------------------ #
    def _filter_active_agents(self, agent_ids: list[str], allow_draft: bool = False) -> list[str]:
        """Keep only the agents the current user may chat with, preserving order.

        * Ids returned by ``get_enterprise_agents`` for the user are always kept.
        * A user agent is kept when it has a published version, or when
          *allow_draft* is True and the current user owns it.
        * Every other id (another user's unpublished agents, unknown ids) is
          dropped silently.
        """
        current_user = g.get("authIdentifier")
        user_agents = self._user_agents_dict
        enterprise_ids = {e["id"] for e in get_enterprise_agents(current_user)}

        filtered = []
        for aid in agent_ids:
            if aid in enterprise_ids:
                filtered.append(aid)
            elif aid in user_agents:
                agent = user_agents[aid]
                is_owner = agent.get("owner") == current_user
                if (allow_draft and is_owner) or agent.get("published_version"):
                    filtered.append(aid)
        return filtered

    @staticmethod
    def _clean_title(raw: str) -> str:
        """Normalise a model-generated title to a single plain-text line.

        The steps are:

        1. Unwrap code fences.
        2. Keep the first non-blank line.
        3. Drop a leading markdown heading or blockquote marker and a
           ``Title:``-style label.
        4. Collapse whitespace.
        5. Repeatedly trim wrapper characters (spaces, quotes, backticks, ``*``,
           ``_``) and trailing punctuation until nothing changes.

        Returns "Untitled" when nothing is left (including for ``None``).
        """
        cls = ConversationService
        s = cls._CODE_FENCE_RE.sub(r"\1", (raw or "").strip()).strip()

        s = next((line.strip() for line in s.splitlines() if line.strip()), "")

        s = cls._HEADING_RE.sub("", s)
        s = cls._BLOCKQUOTE_RE.sub("", s)
        s = cls._LABEL_PREFIX_RE.sub("", s)
        s = cls._WHITESPACE_RE.sub(" ", s)

        # Iterate so that inputs like '"Title".' or '**Title**:' lose both the
        # trailing punctuation and the wrapper behind it. Each pass only removes
        # characters, so the loop terminates.
        while True:
            trimmed = cls._TRAILING_PUNCT_RE.sub("", s.strip(cls._TITLE_WRAPPER_CHARS))
            if trimmed == s:
                break
            s = trimmed

        return s or "Untitled"
