from typing import List

from sqlalchemy import and_, or_

from backend import schemas
from backend.database.base import db
from backend.database.models import Agent, AgentShare, PrincipalTypeEnum


def get_agent_shares(agent_id: str) -> List[schemas.AgentShareRead] | None:
    """Return every share attached to ``agent_id`` as ``AgentShareRead`` schemas.

    Returns ``None`` (not an empty list) when the agent has no shares.
    """
    rows = (
        db.session.query(AgentShare)
        .filter(AgentShare.agent_id == agent_id)
        .all()
    )
    if not rows:
        return None
    return [schemas.AgentShareRead.model_validate(row) for row in rows]


def get_agents_shared_with(user_id: str, groups: list[str]) -> list[dict]:
    """Return agents shared with ``user_id`` directly or via any of ``groups``.

    A share matches when it is a USER share whose principal is ``user_id``, or
    a GROUP share whose principal is in ``groups``. When ``user_id`` is given,
    agents whose ``owner`` equals it are left out. Duplicate agents produced by
    the join are collapsed with DISTINCT. Each agent is validated through
    ``schemas.AgentRead`` and dumped to a plain dict.

    Returns an empty list without querying when both ``user_id`` and
    ``groups`` are empty.
    """
    if not (user_id or groups):
        return []

    matches_share = []
    if user_id:
        matches_share.append(
            and_(AgentShare.principal_type == PrincipalTypeEnum.USER, AgentShare.principal == user_id)
        )
    if groups:
        matches_share.append(
            and_(AgentShare.principal_type == PrincipalTypeEnum.GROUP, AgentShare.principal.in_(groups))
        )

    # The guard above guarantees at least one share condition, so or_() is never empty.
    query = (
        db.session.query(Agent)
        .join(AgentShare, Agent.id == AgentShare.agent_id)
        .filter(or_(*matches_share))
    )
    if user_id:
        # Successive .filter() calls are ANDed together.
        # NOTE(review): SQL `!=` is never true against NULL, so agents with a NULL
        # owner are excluded as well — confirm Agent.owner is non-nullable.
        query = query.filter(Agent.owner != user_id)

    shared = query.distinct().all()
    return [schemas.AgentRead.model_validate(agent).model_dump() for agent in shared]


def create_agent_share(agent_share: schemas.AgentShareCreate) -> schemas.AgentShareRead:
    """Persist a new agent share and return it as an ``AgentShareRead`` schema.

    Args:
        agent_share: Payload whose dumped fields are passed straight to the
            ``AgentShare`` model constructor.

    Returns:
        The stored share, refreshed from the database after the commit.

    Raises:
        Any error raised by the commit is re-raised after the session has been
        rolled back.
    """
    new_agent_share = AgentShare(**agent_share.model_dump())
    db.session.add(new_agent_share)
    try:
        db.session.commit()
    except Exception:
        # A failed commit/flush leaves the session unusable until it is rolled
        # back; undo here so later work on the shared session does not fail too.
        db.session.rollback()
        raise
    db.session.refresh(new_agent_share)
    return schemas.AgentShareRead.model_validate(new_agent_share)


def replace_agent_shares(agent_id: str, shares: List[schemas.AgentShareBase]) -> None:
    """Replace all shares of ``agent_id`` with ``shares`` in a single commit.

    Existing shares are deleted and the new ones added before one commit, so
    on success the agent ends up with exactly ``shares``. Passing an empty
    list removes every share for the agent.

    Raises:
        Any database error is re-raised after the session is rolled back. The
        rollback also discards the not-yet-committed delete, so the previous
        shares remain in the database.
    """
    try:
        # Bulk delete of the current shares; not visible to others until commit.
        db.session.query(AgentShare).filter(AgentShare.agent_id == agent_id).delete()
        db.session.add_all(
            [
                AgentShare(agent_id=agent_id, principal=share.principal, principal_type=share.principal_type)
                for share in shares
            ]
        )
        db.session.commit()
    except Exception:
        # Without an explicit rollback the session stays in a failed state and
        # every later use of db.session would raise.
        db.session.rollback()
        raise


def get_share_counts(agent_ids: List[str]) -> dict:
    """Count shares per agent for the given agent ids.

    Returns a dict mapping ``agent_id`` to its number of shares. Because the
    counts come from a grouped query, agents with no shares do not appear as
    keys at all. An empty ``agent_ids`` short-circuits to ``{}`` without
    querying the database.
    """
    if not agent_ids:
        return {}
    rows = (
        db.session.query(AgentShare.agent_id, db.func.count().label("cnt"))
        .filter(AgentShare.agent_id.in_(agent_ids))
        .group_by(AgentShare.agent_id)
        .all()
    )
    return {agent_id: count for agent_id, count in rows}

