import json
import zlib
from typing import Any, Optional

from backend import schemas
from backend.database.base import db
from backend.database.models import Conversation, Message, MessageAgent
from backend.utils.logging_utils import get_logger

# Module-level logger, named after this module.
logger = get_logger(__name__)


def _get_db_message(message_id: str, user_id: str):
    """
    Load an active message by id, scoped to conversations owned by ``user_id``.

    Returns the ``Message`` row, or ``None`` when ``user_id`` is falsy, when no such
    message exists, when its status is not ``"active"``, or when its conversation
    belongs to a different user.
    """
    # Without a user there is nothing to scope by, so refuse to look anything up.
    if user_id:
        ownership_join = Message.conversation_id == Conversation.conversation_id
        return (
            db.session.query(Message)
            .join(Conversation, ownership_join)
            .filter(
                Message.id == message_id,
                Conversation.user_id == user_id,
                Message.status == "active",
            )
            .first()
        )
    return None


def _get_message_attachments(message_id: str) -> schemas.MessageAttachment | None:
    """
    Fetch the attachment metadata stored for ``message_id``, without any user scoping.

    Mirrors ``_get_message_attachments`` in store.py. Returns ``None`` when no
    ``MessageAttachment`` row exists for the message.
    """
    from backend.database.models import MessageAttachment

    row = db.session.query(MessageAttachment).filter(MessageAttachment.message_id == message_id).first()
    if not row:
        return None

    # TODO: the column appears to be JSON-encoded already; confirm why a str payload can still show up here.
    raw_attachments = row.attachments
    if isinstance(raw_attachments, str):
        raw_attachments = json.loads(raw_attachments)

    mode_info = row.extraction_mode
    if mode_info:
        mode = mode_info.get("mode")
        quota_exceeded = mode_info.get("quotaExceeded", False)
    else:
        mode = None
        quota_exceeded = False

    return schemas.MessageAttachment(
        attachments=raw_attachments,
        extraction_mode=mode,
        quota_exceeded=quota_exceeded,
    )


def build_message(msg: Message) -> schemas.MessageRead:
    """
    Convert a ``Message`` ORM row into a ``schemas.MessageRead``.

    Decompresses the zlib-compressed JSON ``artifacts`` blob. Attaches any stored
    attachment metadata (via ``_get_message_attachments``). Reshapes the flat
    ``feedback_*`` columns into a ``feedback`` dict, which is empty when no rating
    is set.

    Args:
        msg: The message row to serialize.

    Returns:
        The populated ``schemas.MessageRead``.
    """
    # Artifacts are stored as zlib-compressed JSON (see create_message / update_message).
    artifacts = {}
    if msg.artifacts:
        try:
            artifacts = json.loads(zlib.decompress(msg.artifacts).decode("utf-8"))
        except Exception:
            # Best effort: a corrupt blob should not prevent the message from being read.
            logger.exception("Error decompressing artifacts for message id: %s", msg.id)
            artifacts = {}
        else:
            # The payload can be large; keep it out of INFO logs and format lazily.
            logger.debug("artifacts : %s", artifacts)

    # Build the schema object, including computed fields.
    message_data = {
        "id": msg.id,
        "role": msg.role,
        "content": msg.content,
        "event_log": msg.event_log,
        "has_event_log": bool(msg.event_log),
        "actions": msg.actions or {},
        "artifacts_metadata": artifacts,
        "merged_context": msg.merged_context or {},
        "selected_agent_ids": msg.selected_agent_ids,
        "used_agent_ids": msg.used_agent_ids,
        "llm_id": msg.llm_id,
        "agents_enabled": msg.agents_enabled,
        "status": msg.status,
        "created_at": msg.created_at,
        "tool_validation_requests": msg.tool_validation_requests,
        "tool_validation_responses": msg.tool_validation_responses,
        "memory_fragment": msg.memory_fragment,
    }

    # Attachments live in a separate table; only set the keys when a row exists.
    msg_attachments = _get_message_attachments(msg.id)
    if msg_attachments:
        message_data["attachments"] = msg_attachments.attachments
        message_data["extraction_mode"] = msg_attachments.extraction_mode
        message_data["quota_exceeded"] = msg_attachments.quota_exceeded

    # Feedback: only populated when a rating has been recorded.
    fb = {}
    if msg.feedback_rating is not None:
        fb = {
            "rating": int(msg.feedback_rating),
            "by": msg.feedback_by,
            "updatedAt": msg.feedback_updated_at,
        }
        if msg.feedback_text:
            fb["text"] = msg.feedback_text
    message_data["feedback"] = fb
    return schemas.MessageRead(**message_data)


def get_message(message_id: str, user_id: str) -> Optional[schemas.MessageRead]:
    """
    Return the user's active message ``message_id`` as a ``schemas.MessageRead``.

    Returns ``None`` if the message does not exist, is not active, or is not owned
    by ``user_id``.
    """
    msg = _get_db_message(message_id, user_id)
    if not msg:
        return None
    # Build exactly once: build_message runs an extra attachments query, so a
    # discarded duplicate call only doubles the DB and serialization work.
    return build_message(msg)


def create_message(message: schemas.MessageCreate) -> Message:
    """
    Persist a single new message and touch its parent conversation.

    ``artifacts`` and ``trace`` are stored as zlib-compressed JSON when present and
    non-empty; empty values are dropped. The parent ``Conversation`` row receives a
    no-op update so that its auto-updated timestamps move.

    Returns:
        The committed and refreshed ``Message`` row.
    """
    fields = message.model_dump(exclude_unset=True)

    # Compress the JSON blobs; falsy payloads are removed from the insert entirely.
    for blob_key in ("artifacts", "trace"):
        payload = fields.pop(blob_key, None)
        if payload:
            fields[blob_key] = zlib.compress(json.dumps(payload).encode("utf-8"))

    created = Message(**fields)
    db.session.add(created)

    # Self-assigning the key is a deliberate no-op write that triggers the
    # conversation's timestamp auto-update.
    db.session.query(Conversation).filter(
        Conversation.conversation_id == message.conversation_id
    ).update({"conversation_id": message.conversation_id})

    db.session.commit()
    db.session.refresh(created)
    return created


def update_message(message_id: str, user_id: str, message_update: schemas.MessageUpdate) -> Optional[Message]:
    """
    Apply a partial update to the user's active message and commit it.

    Only fields explicitly set on ``message_update`` are written. ``artifacts`` and
    ``trace`` are stored as zlib-compressed JSON; setting either to an empty/None
    value clears the column.

    Returns:
        The refreshed ``Message``, or ``None`` when the message is missing, inactive,
        or not owned by ``user_id``.
    """
    db_message = _get_db_message(message_id, user_id)
    if not db_message:
        return None

    update_data = message_update.model_dump(exclude_unset=True)

    # Compressed blob columns: encode non-empty payloads, clear empty ones.
    for blob_key in ("artifacts", "trace"):
        if blob_key in update_data:
            payload = update_data[blob_key]
            update_data[blob_key] = zlib.compress(json.dumps(payload).encode("utf-8")) if payload else None

    for key, value in update_data.items():
        setattr(db_message, key, value)

    db.session.commit()
    db.session.refresh(db_message)
    return db_message


def delete_message(message_id: str, user_id: str):
    """
    Delete the user's active message ``message_id``.

    Returns ``True`` if a row was deleted and committed. Returns ``False`` if no
    active message owned by ``user_id`` matched.
    """
    target = _get_db_message(message_id, user_id)
    if not target:
        return False
    db.session.delete(target)
    db.session.commit()
    return True


def update_message_feedback(message_id: str, user_id: str, rating: int | None, text: str | None) -> None:
    """
    Record ``user_id``'s feedback on one of their own active messages.

    ``rating`` is kept only when it equals 0 or 1; anything else (including ``None``)
    is stored as ``None``. An empty ``text`` is stored as ``None``.

    Raises:
        PermissionError: if the message does not exist, is not active, or is not
            owned by ``user_id``.
    """
    target = _get_db_message(message_id=message_id, user_id=user_id)
    if not target:
        raise PermissionError("Message not found or not owned by user")

    valid_rating = rating in (0, 1)
    target.feedback_rating = rating if valid_rating else None
    target.feedback_text = text if text else None
    target.feedback_by = user_id
    # feedback_updated_at is expected to be refreshed automatically by SQLAlchemy on this write.
    db.session.commit()


def clear_message_feedback(message_id: str, user_id: str) -> None:
    """
    Remove the rating, text and author of the feedback on the user's own message.

    Raises:
        PermissionError: if the message does not exist, is not active, or is not
            owned by ``user_id``.
    """
    target = _get_db_message(message_id=message_id, user_id=user_id)
    if not target:
        raise PermissionError("Message not found or not owned by user")

    for column in ("feedback_rating", "feedback_text", "feedback_by"):
        setattr(target, column, None)
    # feedback_updated_at is expected to be refreshed automatically by SQLAlchemy on this write.
    db.session.commit()


def append_messages(conv_id: str, messages: list[schemas.MessageCreate]) -> None:
    """
    Bulk-insert messages into conversation ``conv_id`` with a single commit.

    Each ``schemas.MessageCreate`` may omit ``conversation_id``; it then defaults to
    ``conv_id``. ``artifacts`` and ``trace`` are stored as zlib-compressed JSON when
    not ``None``. ``attachments`` are dropped here because they are stored separately.

    One ``MessageAgent`` row is created per distinct agent id in
    ``selected_agent_ids`` and ``used_agent_ids``, flagged as selected and/or used.
    The parent conversation row receives a no-op update so its auto-updated
    timestamps move.

    Note: the agent links read ``id`` from each message payload, so any message
    carrying agent ids must have its ``id`` set explicitly.
    """
    new_messages = []
    agent_links = []
    for msg in messages:
        msg_data = msg.model_dump(exclude_unset=True)
        # Callers are not required to repeat the conversation id on every message.
        if not msg_data.get("conversation_id"):
            msg_data["conversation_id"] = conv_id

        # Handle artifacts compression if present.
        artifacts = msg_data.pop("artifacts", None)
        if artifacts is not None:
            msg_data["artifacts"] = zlib.compress(json.dumps(artifacts).encode("utf-8"))
        # Handle trace compression if present.
        trace = msg_data.pop("trace", None)
        if trace is not None:
            msg_data["trace"] = zlib.compress(json.dumps(trace).encode("utf-8"))
        msg_data.pop("attachments", None)  # attachments are stored separately
        new_messages.append(Message(**msg_data))

        # Prepare message_agents rows; build each set once rather than per membership test.
        selected_list = msg_data.get("selected_agent_ids") or []
        used_list = msg_data.get("used_agent_ids") or []
        logger.info("Selected IDs: %s, Used IDs: %s", selected_list, used_list)
        selected_ids = set(selected_list)
        used_ids = set(used_list)
        agent_links.extend(
            MessageAgent(
                message_id=msg_data["id"],
                agent_id=aid,
                selected=1 if aid in selected_ids else 0,
                used=1 if aid in used_ids else 0,
            )
            for aid in selected_ids | used_ids
        )

    db.session.add_all(new_messages)
    db.session.add_all(agent_links)
    # Self-assigning the key is a deliberate no-op write that triggers the
    # conversation's timestamp auto-update.
    db.session.query(Conversation).filter(Conversation.conversation_id == conv_id).update(
        {
            "conversation_id": conv_id,
        }
    )
    db.session.commit()


def get_message_events(message_id: str, user_id: str | None = None) -> list:
    """
    Return the stored event log for ``message_id``, or ``[]`` if there is none.

    If ``user_id`` is truthy, the lookup is restricted to active messages in
    conversations owned by that user. Otherwise the message is matched by id alone.
    """
    query = db.session.query(Message.event_log)
    if user_id:
        # Ownership scoping: join to the conversation and require an active message.
        query = query.join(Conversation, Message.conversation_id == Conversation.conversation_id).filter(
            Conversation.user_id == user_id, Message.status == "active"
        )
    row = query.filter(Message.id == message_id).first()
    if not row:
        return []
    return row[0] or []


def get_message_trace(message_id: str, user_id: str | None = None) -> Any:
    """
    Return the decompressed trace stored for a message.

    When ``user_id`` is passed (even as an empty string), only an active message in a
    conversation owned by that user is considered. An empty string therefore never
    matches. Only when ``user_id`` is left as ``None`` is the message looked up by id
    alone, as ``get_message_events`` does.

    Returns:
        The JSON-decoded trace (whatever value was stored), or ``""`` if the message is
        missing, has no trace, or the stored blob cannot be decoded.
    """
    if user_id is None:
        # _get_db_message rejects a falsy user_id, so the documented unscoped mode
        # needs its own lookup.
        msg = db.session.query(Message).filter(Message.id == message_id).first()
    else:
        msg = _get_db_message(message_id, user_id)
    if not msg or not msg.trace:
        return ""
    try:
        return json.loads(zlib.decompress(msg.trace).decode("utf-8"))
    except Exception:
        # Best effort: a corrupt trace blob yields "" rather than an error.
        logger.exception("Error decompressing trace for message id: %s", message_id)
        return ""


# Document uploads
def insert_or_update_message_attachments(
    message_id: str,
    attachments: list[dict],
    extraction_mode: Optional[dict] = None,
) -> None:
    """
    Insert or update the attachment metadata row for ``message_id``, then commit.

    Args:
        message_id: Message the attachments belong to.
        attachments: Attachment metadata list, stored as-is.
        extraction_mode: Optional extraction settings dict, stored as-is on both insert
            and update. ``_get_message_attachments`` reads its ``"mode"`` and
            ``"quotaExceeded"`` keys.
    """
    from backend.database.models import MessageAttachment

    obj = db.session.query(MessageAttachment).filter_by(message_id=message_id).first()
    if obj is None:
        # Persist extraction_mode on insert too, so the very first upload's mode is readable.
        obj = MessageAttachment(
            message_id=message_id,
            attachments=attachments,
            extraction_mode=extraction_mode,
        )
        db.session.add(obj)
    else:
        obj.attachments = attachments
        obj.extraction_mode = extraction_mode
    db.session.commit()
