from typing import List

from backend import schemas
from backend.database.base import db
from backend.database.crud.message import build_message
from backend.database.models import Conversation, DerivedDocument, Message, StatusEnum


def get_user_conversation(conversation_id: str, user_id: str) -> schemas.FullConversationRead | None:
    """
    Load one ACTIVE conversation owned by a given user, including its messages.

    Args:
        conversation_id: Identifier of the conversation to load.
        user_id: Identifier of the owner; the conversation must belong to this user.

    Returns:
        The conversation with its messages built via ``build_message``, or None when
        either identifier is empty or no matching ACTIVE conversation exists.
    """
    # Empty identifiers can never match, so skip the database round-trip.
    if not (conversation_id and user_id):
        return None

    matching = db.session.query(Conversation).filter(
        Conversation.conversation_id == conversation_id,
        Conversation.user_id == user_id,
        Conversation.status == StatusEnum.ACTIVE,
    )
    record = matching.first()
    if not record:
        return None

    header = schemas.ConversationRead.model_validate(record)
    built_messages = [build_message(message) for message in record.messages]
    return schemas.FullConversationRead(**header.model_dump(), messages=built_messages)


def create_conversation(conversation: schemas.ConversationCreate) -> schemas.FullConversationRead:
    """
    Persist a new conversation and return it in its full read shape.

    Args:
        conversation: Creation payload; its dumped fields are passed as keyword
            arguments to the ``Conversation`` model.

    Returns:
        The stored conversation together with its messages built via ``build_message``.
    """
    record = Conversation(**conversation.model_dump())
    db.session.add(record)
    db.session.commit()
    # Reload the row from the database so values populated on insert are visible.
    db.session.refresh(record)

    header = schemas.ConversationRead.model_validate(record)
    built_messages = [build_message(message) for message in record.messages]
    return schemas.FullConversationRead(**header.model_dump(), messages=built_messages)


def get_conversations_by_user(user_id: str, skip: int = 0, limit: int = 100) -> List[schemas.FullConversationRead]:
    """
    Return one page of a user's ACTIVE conversations, each with its messages.

    Args:
        user_id: Owner whose conversations are listed.
        skip: Number of matching conversations to skip (pagination offset).
        limit: Maximum number of conversations to return.

    Returns:
        Up to ``limit`` conversations in full read shape.
    """
    # NOTE(review): no ORDER BY is applied, so which rows land on a given page is
    # database-dependent; consider ordering by a stable column if callers page through.
    db_convs = (
        db.session.query(Conversation)
        .filter(Conversation.user_id == user_id, Conversation.status == StatusEnum.ACTIVE)
        .offset(skip)
        .limit(limit)
        .all()
    )
    outputs = []
    for db_conv in db_convs:
        conv = schemas.ConversationRead.model_validate(db_conv)
        # Depending on the relationship's loading strategy, this may query per conversation.
        messages = [build_message(msg) for msg in db_conv.messages]
        outputs.append(schemas.FullConversationRead(**conv.model_dump(), messages=messages))
    return outputs


def get_conversations_ids_by_user(user_id: str) -> List[str]:
    """
    List the identifiers of a user's ACTIVE conversations.

    Only the ``conversation_id`` column is selected, so full rows are not loaded.
    """
    id_rows = (
        db.session.query(Conversation.conversation_id)
        .filter(
            Conversation.user_id == user_id,
            Conversation.status == StatusEnum.ACTIVE,
        )
        .all()
    )
    # Each result row holds a single column; unpack it directly.
    return [conversation_id for (conversation_id,) in id_rows]


def get_conversations_metadata(user_id: str) -> List[schemas.ConversationMetadata]:
    """
    Return metadata for every ACTIVE conversation owned by ``user_id``.

    Each ORM row is validated into ``schemas.ConversationMetadata`` by attribute access.
    """
    # NOTE(review): the earlier TODO asked for a ConversationMetadata schema to avoid
    # loading messages; that schema is used below. Whether messages are still loaded
    # depends on the schema's fields and the relationship's loading strategy — confirm.
    active_rows = (
        db.session.query(Conversation)
        .filter(Conversation.user_id == user_id, Conversation.status == StatusEnum.ACTIVE)
        .all()
    )
    to_metadata = schemas.ConversationMetadata.model_validate
    return [to_metadata(row, from_attributes=True) for row in active_rows]


def update_conversation(
    conversation_id: str, conversation_update: schemas.ConversationUpdate
) -> schemas.ConversationUpdate | None:
    """
    Apply a partial update to an ACTIVE conversation.

    Only fields explicitly set on ``conversation_update`` are written to the row.

    Returns:
        The refreshed row re-validated as ``schemas.ConversationUpdate``, or None if
        no ACTIVE conversation has ``conversation_id``.
    """
    # NOTE(review): no user_id filter here, so ownership is not checked; presumably
    # the caller enforces it — confirm.
    target = (
        db.session.query(Conversation)
        .filter(
            Conversation.conversation_id == conversation_id,
            Conversation.status == StatusEnum.ACTIVE,
        )
        .first()
    )
    if not target:
        return None

    changes = conversation_update.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(target, field_name, changes[field_name])
    db.session.commit()
    db.session.refresh(target)
    return schemas.ConversationUpdate.model_validate(target, from_attributes=True)


def update_conversation_meta(conversation_id: str, conv_update: schemas.ConversationMetadataUpdate) -> None:
    """
    Update conversation metadata fields from a ConversationMetadataUpdate model.

    Only fields explicitly set on ``conv_update`` are written, via a single bulk
    UPDATE followed by a commit. Nothing is executed or committed when
    ``conv_update`` is None or sets no fields.

    NOTE(review): a bulk ``Query.update`` bypasses ORM unit-of-work events. Column-level
    ``onupdate`` defaults still apply, but timestamps maintained by ORM event hooks would
    not be — confirm how the model maintains its timestamps.
    NOTE(review): unlike ``update_conversation``, this does not filter on status, so it
    can also update soft-deleted conversations — confirm this is intended.
    """
    if conv_update is None:
        return
    update_data = conv_update.model_dump(exclude_unset=True)
    if not update_data:
        # Nothing to write: an UPDATE with no values is at best a no-op, so skip
        # both the statement and the commit.
        return
    db.session.query(Conversation).filter(Conversation.conversation_id == conversation_id).update(update_data)
    db.session.commit()


def delete_conversation(conversation_id: str, user_id: str, permanent: bool = False) -> bool:
    """
    Delete one conversation owned by ``user_id``, regardless of its current status.

    Args:
        conversation_id: Identifier of the conversation to delete.
        user_id: Owner; the conversation must belong to this user.
        permanent: When True, hard-delete the conversation plus its messages and
            derived documents. When False, soft-delete: flag the conversation and
            its ACTIVE messages as DELETED (derived documents are left untouched).

    Returns:
        True if a matching conversation was found (and deleted), False otherwise.
    """
    target = (
        db.session.query(Conversation)
        .filter(Conversation.conversation_id == conversation_id, Conversation.user_id == user_id)
        .first()
    )
    if target is None:
        return False

    conv_id = target.conversation_id
    if not permanent:
        target.status = StatusEnum.DELETED
        db.session.query(Message).filter(
            Message.conversation_id == conv_id,
            Message.status == StatusEnum.ACTIVE,
        ).update({"status": StatusEnum.DELETED}, synchronize_session=False)
    else:
        # Remove dependent rows before the conversation itself.
        db.session.query(Message).filter(Message.conversation_id == conv_id).delete(
            synchronize_session=False
        )
        db.session.query(DerivedDocument).filter(DerivedDocument.conv_id == conv_id).delete(
            synchronize_session=False
        )
        db.session.delete(target)

    db.session.commit()
    return True


def delete_all_conversations(user_id: str, permanent: bool = False) -> int:
    """
    Delete all conversations owned by a user.

    Args:
        user_id: Owner whose conversations are deleted.
        permanent: When True, hard-delete the user's conversations of any status,
            including previously soft-deleted ones, together with their messages and
            derived documents. When False, soft-delete: flag the user's ACTIVE
            conversations and their ACTIVE messages as DELETED.

    Returns:
        Number of conversation rows removed (permanent) or flagged DELETED (soft).
    """
    id_query = db.session.query(Conversation.conversation_id).filter(Conversation.user_id == user_id)
    if not permanent:
        # A soft delete only needs to touch conversations that are still active. A
        # permanent delete must also purge rows soft-deleted earlier; otherwise they
        # (and their messages/derived documents) would never be removed.
        id_query = id_query.filter(Conversation.status == StatusEnum.ACTIVE)
    conversation_ids = [row[0] for row in id_query.all()]

    query = db.session.query(Conversation).filter(Conversation.conversation_id.in_(conversation_ids))
    if permanent:
        # Delete messages first, then derived documents, then the conversations
        # (presumably so no child rows reference a removed conversation).
        db.session.query(Message).filter(Message.conversation_id.in_(conversation_ids)).delete(
            synchronize_session=False
        )
        db.session.query(DerivedDocument).filter(DerivedDocument.conv_id.in_(conversation_ids)).delete(
            synchronize_session=False
        )
        count = query.delete(synchronize_session=False)
    else:
        count = query.update({"status": StatusEnum.DELETED}, synchronize_session=False)
        # Mark all related messages as deleted.
        db.session.query(Message).filter(
            Message.conversation_id.in_(conversation_ids), Message.status == StatusEnum.ACTIVE
        ).update({"status": StatusEnum.DELETED}, synchronize_session=False)
        # TODO: mark all related derived documents and message attachments as deleted?
    db.session.commit()
    return count
