# File: backend/ws_utils.py

import threading

import dataiku
from backend.database.user_store import UserStore
from backend.services.conversation_service import ConversationError, ConversationService
from backend.utils.logging_utils import get_logger, set_request_id
from backend.utils.utils import get_store as _get_store
from flask import current_app, g, request as flask_request
from flask_socketio import disconnect, emit, join_room, leave_room

logger = get_logger(__name__)

# ─── Cancellation registry ───────────────────────────────────────────
# conversationId -> threading.Event. A streaming request registers its event
# here while the reply is in flight; a "cancel_stream" request sets it to ask
# that stream to stop.
_cancel_events: dict[str, threading.Event] = dict()

# ─── Admin settings presence tracking ─────────────────────────────────
# Socket.IO room joined by every client currently editing admin settings.
ADMIN_SETTINGS_ROOM = "admin_settings_editors"
# Active editors keyed by socket id:
#   {sid: {"user_id": <auth identifier>, "joined_at": <ISO-8601 UTC string>}}
_admin_settings_editors: dict[str, dict] = dict()


def get_conversation_store():
    """
    Return the store used to persist conversations for the current request.

    Thin delegate to ``backend.utils.utils.get_store``.
    NOTE(review): this is presumably the per-user ``UserStore`` that
    ``before_request()`` places on ``g.store``; confirm in ``get_store``.
    """
    store = _get_store()
    return store


def _broadcast_admin_presence(socketio):
    """
    Send the current admin-settings editor list to everyone in the presence room.

    Emits ``admin:presence_updated`` with payload
    ``{"editors": [{"userId": ..., "joinedAt": ...}, ...]}``, built from
    ``_admin_settings_editors``, to ``ADMIN_SETTINGS_ROOM``.

    Args:
        socketio: the Flask-SocketIO server instance used to emit.
    """
    # Take a snapshot first: join/leave/disconnect handlers mutate the dict and
    # may run concurrently. Iterating a live view while it changes size raises
    # RuntimeError, whereas list() copies it in a single step.
    snapshot = list(_admin_settings_editors.values())
    editors = [
        {"userId": info["user_id"], "joinedAt": info["joined_at"]}
        for info in snapshot
    ]

    socketio.emit("admin:presence_updated", {"editors": editors}, room=ADMIN_SETTINGS_ROOM)
    logger.debug(f"[admin presence] broadcasted to {len(editors)} editor(s)")


def setup_socketio(socketio):
    """
    Register all Socket.IO event handlers on ``socketio``.

    Registered events:
      - ``connect`` / ``disconnect`` / ``connect_error``: connection lifecycle
        and per-user room membership.
      - ``message``: echoes the received message back to the sender.
      - ``stream_message`` / ``cancel_stream``: stream an assistant reply and
        let the client cancel it through ``_cancel_events``.
      - ``join_admin_settings`` / ``leave_admin_settings``: presence tracking
        for the admin settings page.

    Args:
        socketio: the Flask-SocketIO server instance to attach handlers to.
    """

    def _user_room():
        """
        Return the room that private events for the current client are sent to.

        Must be called after ``before_request()``. Authenticated clients use a
        per-user room, so every socket of that user receives the events. When
        authentication failed (``g.authIdentifier`` is None), fall back to the
        client's own sid room. The alternative, "user:None", would be shared by
        every unauthenticated client and leak replies between them.
        """
        if g.authIdentifier:
            return f"user:{g.authIdentifier}"
        return flask_request.sid

    @socketio.on("connect")
    def handle_connect():
        before_request()
        if g.authIdentifier:
            logger.info(f"socketio: joining room user:{g.authIdentifier}")
            join_room(f"user:{g.authIdentifier}")
        else:
            # Don't pool unauthenticated clients into a shared "user:None" room;
            # Flask-SocketIO already puts each socket in a room named by its sid.
            logger.warning("socketio: client has no authIdentifier; not joining a user room")
        logger.info("socketio: client connected")

    @socketio.on("message")
    def handle_message(msg):
        before_request()
        logger.info("socketio: received message: %s", msg)
        emit("message", {"echo": msg})

    @socketio.on("disconnect")
    def handle_disconnect():
        before_request()
        sid = flask_request.sid
        logger.info("socketio: client disconnected")

        # Clean up admin settings presence if this socket was editing.
        # pop() is a single step, unlike check-then-del, so a concurrent
        # leave_admin_settings for the same sid cannot cause a KeyError.
        editor = _admin_settings_editors.pop(sid, None)
        if editor is not None:
            logger.info(f"[admin presence] removed {editor['user_id']} on disconnect")
            _broadcast_admin_presence(socketio)

        # NOTE(review): the client is already disconnecting when this handler
        # runs, so this call looks redundant; kept to preserve behavior.
        disconnect()

    @socketio.on("connect_error")
    def handle_error():
        # NOTE(review): "connect_error" is normally a client-side event, so this
        # server-side handler may never fire.
        logger.error("socketio: failed to connect to client")

    @socketio.on("stream_message")
    def handle_stream_message(payload):
        """
        Expects payload:
          {
            "agentIds": [list of agent IDs],
            "conversationId": str,
            "userMessage": str,
            "draftMode": bool
          }
        Streams tokens / events back and then appends the assistant reply.
        Errors are reported to the client as a "chat_error" event.
        """
        import time

        start_time = time.time()

        payload = payload or {}
        cid = payload.get("conversationId")
        draft_mode = payload.get("draftMode", False)
        # `or ""` also covers an explicit null userMessage, which would otherwise
        # raise TypeError here before any error could be reported to the client.
        user_message = (payload.get("userMessage") or "")[:50]  # First 50 chars for logging

        logger.info(f"(NEW SOCKET EVENT) stream_message | conversationId: {cid} | message: '{user_message}...'")

        # register a cancel-event for this conversation
        cancel_ev = threading.Event()
        _cancel_events[cid] = cancel_ev

        # Until authentication has run, errors can only be routed to this socket.
        # emit_to_user reads user_room when called, so the reassignment below
        # takes effect for all later emits.
        user_room = flask_request.sid

        def emit_to_user(event, data):
            """
            Wrap the emit() function to ensure we send to the right user room.
            """
            socketio.emit(event, data, room=user_room)

        try:
            # Setup lives inside the try so that a failure here still reports
            # "chat_error" and still unregisters cancel_ev in the finally.
            before_request()
            user_room = _user_room()
            store = get_conversation_store()
            service = ConversationService(store, draft_mode=draft_mode)

            service.stream_message(payload, emit_to_user, cancel_ev)
            execution_time = time.time() - start_time
            logger.info(
                f"(END SOCKET EVENT) stream_message | conversationId: {cid} | status: success | time: {execution_time:.3f}s"
            )
        except ConversationError as ce:
            execution_time = time.time() - start_time
            logger.warning(
                f"(END SOCKET EVENT) stream_message | conversationId: {cid} | status: error | time: {execution_time:.3f}s | {ce}"
            )
            emit_to_user("chat_error", {"error": str(ce)})
        except Exception:
            execution_time = time.time() - start_time
            logger.exception(
                f"(END SOCKET EVENT) stream_message | conversationId: {cid} | status: fatal error | time: {execution_time:.3f}s"
            )
            emit_to_user("chat_error", {"error": "Internal server error"})
        finally:
            # Drop our flag so we don't leak memory, but only if it is still ours.
            # A newer stream for the same conversation may have replaced it, and
            # popping that one would make it uncancellable.
            if _cancel_events.get(cid) is cancel_ev:
                _cancel_events.pop(cid, None)

    @socketio.on("cancel_stream")
    def handle_cancel_stream(payload):
        """
        Client asks us to abort an in-flight stream for conversationId.

        NOTE(review): there is no ownership check here. Any connected client
        that knows a conversationId can cancel that stream.
        """
        cid = (payload or {}).get("conversationId")
        ev = _cancel_events.get(cid)
        logger.info(f"[SOCKET EVENT] cancel_stream | conversationId: {cid}")
        if ev:
            ev.set()

    @socketio.on("join_admin_settings")
    def handle_join_admin_settings():
        """
        Client joins the admin settings presence tracking room.
        Track the user and broadcast updated presence list.
        """
        from datetime import datetime, timezone

        before_request()
        user_id = g.authIdentifier
        sid = flask_request.sid

        if not user_id:
            logger.warning("[admin presence] join attempted without valid user_id")
            return

        # Add user to tracking dict
        _admin_settings_editors[sid] = {
            "user_id": user_id,
            # Same string format as the former utcnow().isoformat() + "Z"
            # (naive UTC time + "Z"), without the utcnow() deprecated in 3.12.
            "joined_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
        }

        # Join the Socket.IO room
        join_room(ADMIN_SETTINGS_ROOM)

        logger.info(f"[admin presence] {user_id} joined (sid={sid})")

        # Broadcast updated presence to all users in room
        _broadcast_admin_presence(socketio)

    @socketio.on("leave_admin_settings")
    def handle_leave_admin_settings():
        """
        Client leaves the admin settings presence tracking room.
        Remove the user and broadcast updated presence list.
        """
        before_request()
        sid = flask_request.sid

        # Single-step removal so a concurrent disconnect can't race a del.
        editor = _admin_settings_editors.pop(sid, None)
        if editor is None:
            logger.debug(f"[admin presence] leave_admin_settings called but sid={sid} not tracked")
            return

        logger.info(f"[admin presence] {editor['user_id']} left (sid={sid})")

        # Leave the Socket.IO room
        leave_room(ADMIN_SETTINGS_ROOM)

        # Broadcast updated presence to remaining users
        _broadcast_admin_presence(socketio)


def before_request():
    """
    Populate ``g`` with the caller's identity and a user-scoped store.

    Sets on ``g``:
      - ``authIdentifier``: identity resolved from the browser headers, or
        ``None`` if any step of auth resolution fails.
      - ``userAgent``: the request's ``User-Agent`` header ("" if absent).
      - ``run_as_login``: identity the backend itself runs as (``None`` on failure).
      - ``userGroups``: the user's groups from DSS user settings ([] on failure).
      - ``store``: ``UserStore(g.authIdentifier, g.userGroups)``.

    Auth failures are logged at debug level and leave the request
    unauthenticated instead of raising.
    """
    # Defaults so every attribute exists even when resolution fails part-way.
    g.authIdentifier = None
    g.userGroups = []
    g.userAgent = ""
    g.run_as_login = None

    try:
        # ensure req-id exists before *any* logs are emitted
        set_request_id()
        headers = dict(flask_request.headers)
        g.userAgent = headers.get("User-Agent", "")

        # One client for all lookups instead of constructing one per call.
        client = dataiku.api_client()
        auth = client.get_auth_info_from_browser_headers(headers)
        auth_identifier = auth["authIdentifier"]
        run_as_login = client.get_auth_info()["authIdentifier"]
        logger.info(f"Backend running as user: {run_as_login}")

        # Commit identity only once every required lookup succeeded. This keeps
        # the original rule that any failure leaves the request unauthenticated.
        g.authIdentifier = auth_identifier
        g.run_as_login = run_as_login

        try:
            g.userGroups = client.get_user(auth_identifier).get_settings().get_raw().get("groups", [])
        except Exception:
            # Group lookup is best-effort: the user stays authenticated with no groups.
            logger.debug("Could not load groups for user %s", auth_identifier, exc_info=True)
            g.userGroups = []
    except Exception:
        # Best-effort: unauthenticated requests proceed with defaults. Logged at
        # debug level so environments without browser auth aren't flooded.
        logger.debug("Authentication details extraction failed", exc_info=True)
        g.authIdentifier = None
        g.userGroups = []

    g.store = UserStore(g.authIdentifier, g.userGroups)
