"""
Analytics CRUD functions for agent, conversation, and message usage.
"""
from datetime import datetime, timedelta
from typing import List, Optional

from sqlalchemy import and_, or_

from backend.database.base import db
from backend.database.models import Agent, AgentShare, Conversation, Message, MessageAgent


# --- Sharing analytics: counts of share recipients and shared agents ---
def analytics_shared_users(owner: str) -> int:
    """
    Return the number of distinct user principals that hold a share on an
    agent owned by `owner`.

    Only shares whose `principal_type` is "user" are considered; a principal
    appearing on several of the owner's agents is counted once.
    """
    principals = db.session.query(AgentShare.principal)
    principals = principals.join(Agent, Agent.id == AgentShare.agent_id)
    principals = principals.filter(Agent.owner == owner)
    principals = principals.filter(AgentShare.principal_type == "user")
    return principals.distinct().count()


def analytics_shared_users_for_all_user_agents() -> int:
    """
    Return the number of distinct user principals that hold a share on any
    agent present in the Agent table, regardless of owner.

    Only shares whose `principal_type` is "user" are considered.
    """
    return (
        db.session.query(AgentShare.principal)
        .join(Agent, Agent.id == AgentShare.agent_id)
        .filter(AgentShare.principal_type == "user")
        .distinct()
        .count()
    )


def analytics_shared_agents(owner: str) -> int:
    """
    Return how many agents owned by `owner` have at least one share row.

    The principal type is not filtered, so shares of any type count.
    """
    shared_ids = db.session.query(AgentShare.agent_id)
    shared_ids = shared_ids.join(Agent, Agent.id == AgentShare.agent_id)
    shared_ids = shared_ids.filter(Agent.owner == owner)
    return shared_ids.distinct().count()


def analytics_shared_agents_for_all_user_agents() -> int:
    """
    Return how many agents present in the Agent table (any owner) have at
    least one share row.

    The principal type is not filtered, so shares of any type count.
    """
    return (
        db.session.query(AgentShare.agent_id)
        .join(Agent, Agent.id == AgentShare.agent_id)
        .distinct()
        .count()
    )

def analytics_usage_buckets(
    owner: str,
    start: Optional[str],
    end: Optional[str],
    bucket: str,
    agent_id: Optional[str],
    *,
    is_project_admin: bool = False,
    agent_type: Optional[str] = None,
    owner_id: Optional[str] = None,
    enterprise_ids: Optional[List[str]] = None,
) -> list[dict]:
    """
    Count selected assistant messages per time bucket for the agents in scope.

    Args:
        owner: Agent owner used for scoping when the caller is not a project
            admin and no `agent_id` is given.
        start: Inclusive lower bound as an ISO date/datetime string. Falls back
            to 56 days before the end bound when missing or unparseable.
        end: Upper bound as an ISO date/datetime string; one day is added to
            obtain the exclusive limit. Falls back to the current UTC time
            when missing or unparseable.
        bucket: "day" or "month"; any other value buckets by ISO week.
        agent_id: When truthy, restrict to this agent id and ignore every
            other scoping argument.
        is_project_admin: Switches scoping from the `owner` restriction to the
            `agent_type` / `owner_id` / `enterprise_ids` rules.
        agent_type: "enterprise" or "user" (case-insensitive); any other
            value, including None, means "all".
        owner_id: Optional owner restriction for the user-agent part of the
            admin scope.
        enterprise_ids: Agent ids matched for the enterprise part of the
            admin scope.

    Returns:
        `{"periodStart": key, "count": n}` dicts sorted by key, where key is
        "YYYY-MM-DD", "YYYY-MM" or "<ISO year>-W<ISO week>" depending on
        `bucket`.
    """
    # --- Date normalization ---
    def _parse(s: Optional[str]) -> Optional[datetime]:
        """Best-effort parse to a naive datetime; None when nothing matches."""
        if not s:
            return None
        try:
            # NOTE(review): a UTC offset is dropped here, not converted —
            # confirm stored and requested timestamps are all UTC.
            return datetime.fromisoformat(s.replace("Z", "+00:00")).replace(tzinfo=None)
        except (AttributeError, TypeError, ValueError):
            pass
        for width, fmt in ((19, "%Y-%m-%dT%H:%M:%S"), (10, "%Y-%m-%d")):
            try:
                return datetime.strptime(s[:width], fmt)
            except (TypeError, ValueError):
                continue
        return None

    end_dt = _parse(end) or datetime.utcnow()
    start_dt = _parse(start) or (end_dt - timedelta(days=56))
    end_excl_dt = end_dt + timedelta(days=1)
    # Day-granular string bounds for the SQL pre-filter; the exact bounds are
    # re-checked in Python below.
    d0 = start_dt.strftime("%Y-%m-%d")
    d1 = end_excl_dt.strftime("%Y-%m-%d")

    # --- Bucket key builder ---
    def key_for(dt: datetime) -> str:
        # Works on the already-parsed datetime so that every timestamp accepted
        # by `_parse` yields a key (re-parsing the raw string could raise).
        day = dt.date().isoformat()
        if bucket == "day":
            return day
        if bucket == "month":
            return day[:7]
        iso_y, iso_w, _ = dt.isocalendar()
        return f"{iso_y}-W{iso_w:02d}"

    # --- Agent scope clause ---
    clause = []
    if agent_id:
        # NOTE(review): no owner check is applied on this path — presumably the
        # caller authorizes access to `agent_id`; confirm.
        clause.append(MessageAgent.agent_id == agent_id)
    elif not is_project_admin:
        clause.append(and_(Agent.id == MessageAgent.agent_id, Agent.owner == owner))
    else:
        agent_type = (agent_type or "all").lower()
        enterprise_ids = list(enterprise_ids or [])
        if agent_type == "enterprise":
            if enterprise_ids:
                clause.append(MessageAgent.agent_id.in_(enterprise_ids))
            else:
                # No enterprise ids: constant-false criterion, nothing matches.
                clause.append(False)
        else:
            user_clause = Agent.id == MessageAgent.agent_id
            if owner_id:
                user_clause = and_(Agent.id == MessageAgent.agent_id, Agent.owner == owner_id)
            if agent_type != "user" and enterprise_ids:
                # "all": user agents plus the listed enterprise agents.
                clause.append(or_(user_clause, MessageAgent.agent_id.in_(enterprise_ids)))
            else:
                clause.append(user_clause)

    # --- Query ---
    q = (
        db.session.query(Message.id, Message.created_at)
        .join(Conversation, Message.conversation_id == Conversation.conversation_id)
        .join(MessageAgent, MessageAgent.message_id == Message.id)
        .outerjoin(Agent, Agent.id == MessageAgent.agent_id)
        .filter(
            Message.role == "assistant",
            MessageAgent.selected == 1,
            Message.created_at >= d0,
            Message.created_at < d1,
            *clause
        )
        .distinct()
    )

    counts = {}
    for _msg_id, created_at in q:
        ts = created_at if isinstance(created_at, str) else created_at.isoformat()
        dt = _parse(ts)
        if dt is None or not (start_dt <= dt < end_excl_dt):
            continue
        k = key_for(dt)
        counts[k] = counts.get(k, 0) + 1
    return [{"periodStart": k, "count": counts[k]} for k in sorted(counts)]

def analytics_feedback_counts(
    owner: str,
    start: Optional[str],
    end: Optional[str],
    agent_id: Optional[str],
    *,
    is_project_admin: bool = False,
    agent_type: Optional[str] = None,
    owner_id: Optional[str] = None,
    enterprise_ids: Optional[List[str]] = None,
) -> dict:
    """
    Tally feedback ratings of selected assistant messages for the agents in scope.

    The window runs from `start` (default: 56 days before the end bound) to
    one day past `end` (default: current UTC time), exclusive. A message with
    no rating counts as "none", a rating equal to 1 as "positive", and any
    other rating as "negative".

    Returns a dict with the keys "positive", "negative" and "none".
    """
    def _to_naive(text: Optional[str]) -> Optional[datetime]:
        # Tries full ISO parsing first, then two progressively coarser formats.
        if not text:
            return None
        try:
            return datetime.fromisoformat(text.replace("Z", "+00:00")).replace(tzinfo=None)
        except Exception:
            pass
        for width, fmt in ((19, "%Y-%m-%dT%H:%M:%S"), (10, "%Y-%m-%d")):
            try:
                return datetime.strptime(text[:width], fmt)
            except Exception:
                continue
        return None

    def _scope():
        # Exactly one criterion restricts which agents are counted.
        if agent_id:
            return MessageAgent.agent_id == agent_id
        if not is_project_admin:
            return and_(Agent.id == MessageAgent.agent_id, Agent.owner == owner)
        kind = (agent_type or "all").lower()
        ent_ids = list(enterprise_ids or [])
        if kind == "enterprise":
            return MessageAgent.agent_id.in_(ent_ids) if ent_ids else False
        user_scope = Agent.id == MessageAgent.agent_id
        if owner_id:
            user_scope = and_(Agent.id == MessageAgent.agent_id, Agent.owner == owner_id)
        if kind == "user" or not ent_ids:
            return user_scope
        return or_(user_scope, MessageAgent.agent_id.in_(ent_ids))

    window_end = _to_naive(end) or datetime.utcnow()
    window_start = _to_naive(start) or (window_end - timedelta(days=56))
    window_stop = window_end + timedelta(days=1)
    lower_day = window_start.strftime("%Y-%m-%d")
    upper_day = window_stop.strftime("%Y-%m-%d")

    scope = _scope()

    rows = (
        db.session.query(Message.id, Message.created_at, Message.feedback_rating)
        .join(Conversation, Message.conversation_id == Conversation.conversation_id)
        .join(MessageAgent, MessageAgent.message_id == Message.id)
        .outerjoin(Agent, Agent.id == MessageAgent.agent_id)
        .filter(
            Message.role == "assistant",
            MessageAgent.selected == 1,
            Message.created_at >= lower_day,
            Message.created_at < upper_day,
            scope,
        )
        .distinct()
    )

    tally = {"positive": 0, "negative": 0, "none": 0}
    for _, created_at, rating in rows:
        stamp = created_at if isinstance(created_at, str) else created_at.isoformat()
        when = _to_naive(stamp)
        # The SQL bounds are day-granular; enforce the exact window here.
        if when is None or when < window_start or when >= window_stop:
            continue
        if rating is None:
            label = "none"
        elif int(rating) == 1:
            label = "positive"
        else:
            label = "negative"
        tally[label] += 1
    return tally

def analytics_active_users_buckets(
    owner: str,
    start: Optional[str],
    end: Optional[str],
    bucket: str,
    agent_id: Optional[str],
    *,
    is_project_admin: bool = False,
    agent_type: Optional[str] = None,
    owner_id: Optional[str] = None,
    enterprise_ids: Optional[List[str]] = None,
) -> list[dict]:
    """
    Count distinct conversation users per time bucket for selected assistant
    messages of the agents in scope.

    Args:
        owner: Agent owner used for scoping when the caller is not a project
            admin and no `agent_id` is given.
        start: Inclusive lower bound as an ISO date/datetime string. Falls back
            to 56 days before the end bound when missing or unparseable.
        end: Upper bound as an ISO date/datetime string; one day is added to
            obtain the exclusive limit. Falls back to the current UTC time
            when missing or unparseable.
        bucket: "day" or "month"; any other value buckets by ISO week.
        agent_id: When truthy, restrict to this agent id and ignore every
            other scoping argument.
        is_project_admin: Switches scoping from the `owner` restriction to the
            `agent_type` / `owner_id` / `enterprise_ids` rules.
        agent_type: "enterprise" or "user" (case-insensitive); any other
            value, including None, means "all".
        owner_id: Optional owner restriction for the user-agent part of the
            admin scope.
        enterprise_ids: Agent ids matched for the enterprise part of the
            admin scope.

    Returns:
        `{"periodStart": key, "activeUsers": n}` dicts sorted by key, where
        key is "YYYY-MM-DD", "YYYY-MM" or "<ISO year>-W<ISO week>" depending
        on `bucket`.
    """
    def _parse(s: Optional[str]) -> Optional[datetime]:
        """Best-effort parse to a naive datetime; None when nothing matches."""
        if not s:
            return None
        try:
            # NOTE(review): a UTC offset is dropped here, not converted —
            # confirm stored and requested timestamps are all UTC.
            return datetime.fromisoformat(s.replace("Z", "+00:00")).replace(tzinfo=None)
        except (AttributeError, TypeError, ValueError):
            pass
        for width, fmt in ((19, "%Y-%m-%dT%H:%M:%S"), (10, "%Y-%m-%d")):
            try:
                return datetime.strptime(s[:width], fmt)
            except (TypeError, ValueError):
                continue
        return None

    end_dt = _parse(end) or datetime.utcnow()
    start_dt = _parse(start) or (end_dt - timedelta(days=56))
    end_excl_dt = end_dt + timedelta(days=1)
    # Day-granular string bounds for the SQL pre-filter; the exact bounds are
    # re-checked in Python below.
    d0 = start_dt.strftime("%Y-%m-%d")
    d1 = end_excl_dt.strftime("%Y-%m-%d")

    def key_for(dt: datetime) -> str:
        # Works on the already-parsed datetime so that every timestamp accepted
        # by `_parse` yields a key (re-parsing the raw string could raise).
        day = dt.date().isoformat()
        if bucket == "day":
            return day
        if bucket == "month":
            return day[:7]
        iso_y, iso_w, _ = dt.isocalendar()
        return f"{iso_y}-W{iso_w:02d}"

    clause = []
    if agent_id:
        # NOTE(review): no owner check is applied on this path — presumably the
        # caller authorizes access to `agent_id`; confirm.
        clause.append(MessageAgent.agent_id == agent_id)
    elif not is_project_admin:
        clause.append(and_(Agent.id == MessageAgent.agent_id, Agent.owner == owner))
    else:
        agent_type = (agent_type or "all").lower()
        enterprise_ids = list(enterprise_ids or [])
        if agent_type == "enterprise":
            if enterprise_ids:
                clause.append(MessageAgent.agent_id.in_(enterprise_ids))
            else:
                # No enterprise ids: constant-false criterion, nothing matches.
                clause.append(False)
        else:
            user_clause = Agent.id == MessageAgent.agent_id
            if owner_id:
                user_clause = and_(Agent.id == MessageAgent.agent_id, Agent.owner == owner_id)
            if agent_type != "user" and enterprise_ids:
                # "all": user agents plus the listed enterprise agents.
                clause.append(or_(user_clause, MessageAgent.agent_id.in_(enterprise_ids)))
            else:
                clause.append(user_clause)

    q = (
        db.session.query(Message.id, Message.created_at, Conversation.user_id)
        .join(Conversation, Message.conversation_id == Conversation.conversation_id)
        .join(MessageAgent, MessageAgent.message_id == Message.id)
        .outerjoin(Agent, Agent.id == MessageAgent.agent_id)
        .filter(
            Message.role == "assistant",
            MessageAgent.selected == 1,
            Message.created_at >= d0,
            Message.created_at < d1,
            *clause
        )
        .distinct()
    )

    buckets = {}
    for _, created_at, user_id in q:
        ts = created_at if isinstance(created_at, str) else created_at.isoformat()
        dt = _parse(ts)
        if dt is None or not (start_dt <= dt < end_excl_dt):
            continue
        # A set per bucket de-duplicates users with several messages.
        buckets.setdefault(key_for(dt), set()).add(user_id)
    return [{"periodStart": k, "activeUsers": len(buckets[k])} for k in sorted(buckets)]

def analytics_activity(
    owner: str,
    start: Optional[str],
    end: Optional[str],
    agent_id: Optional[str],
    limit: int,
    offset: int,
    *,
    is_project_admin: bool = False,
    agent_type: Optional[str] = None,
    owner_id: Optional[str] = None,
    enterprise_ids: Optional[List[str]] = None,
    q: Optional[str] = None,
    sort_by: Optional[str] = None,
    sort_dir: Optional[str] = None,
    group_by: Optional[str] = None,
) -> tuple[list[dict], int]:
    """
    Aggregate selected assistant messages into activity rows, then filter,
    group, sort and paginate them in memory.

    Args:
        owner: Agent owner used for scoping when the caller is not a project
            admin and no `agent_id` is given; also the owner whose agents are
            used to resolve agent names.
        start: Inclusive lower bound as an ISO date/datetime string. Falls back
            to 56 days before the end bound when missing or unparseable.
        end: Upper bound as an ISO date/datetime string; one day is added to
            obtain the exclusive limit. Falls back to the current UTC time
            when missing or unparseable.
        agent_id: When truthy, restrict to this agent id and ignore every
            other scoping argument.
        limit: Page size; negative values are treated as 0.
        offset: Index of the first returned row; negative values are treated
            as 0.
        is_project_admin: Switches scoping from the `owner` restriction to the
            `agent_type` / `owner_id` / `enterprise_ids` rules.
        agent_type: "enterprise" or "user" (case-insensitive); any other
            value, including None, means "all".
        owner_id: Optional owner restriction for the user-agent part of the
            admin scope.
        enterprise_ids: Agent ids matched for the enterprise part of the
            admin scope.
        q: Case-insensitive substring matched against agent name and user of
            the per-(agent, user) rows, before grouping.
        sort_by: "agentName", "user", "questions" (default) or "avgFeedback".
        sort_dir: "asc" or "desc" (default).
        group_by: "agent" or "user" to roll rows up; anything else keeps one
            row per (agent, user) pair.

    Returns:
        `(rows, total)` where `rows` is the requested page of dicts with keys
        "agentId", "agentName", "user", "questions" and "avgFeedback" (share
        of rated messages whose rating equals 1, or None when nothing was
        rated) and `total` is the row count before pagination.
    """
    def _parse(s: Optional[str]) -> Optional[datetime]:
        """Best-effort parse to a naive datetime; None when nothing matches."""
        if not s:
            return None
        try:
            # NOTE(review): a UTC offset is dropped here, not converted —
            # confirm stored and requested timestamps are all UTC.
            return datetime.fromisoformat(s.replace("Z", "+00:00")).replace(tzinfo=None)
        except (AttributeError, TypeError, ValueError):
            pass
        for width, fmt in ((19, "%Y-%m-%dT%H:%M:%S"), (10, "%Y-%m-%d")):
            try:
                return datetime.strptime(s[:width], fmt)
            except (TypeError, ValueError):
                continue
        return None

    def _rollup(rows: list, field: str) -> list:
        """Merge per-(agent, user) rows that share the same `field` value."""
        by_agent = field == "agentId"
        merged = {}
        for r in rows:
            ident = r[field]
            e = merged.setdefault(
                ident,
                {
                    "agentId": ident if by_agent else None,
                    "agentName": r["agentName"] if by_agent else None,
                    "user": None if by_agent else ident,
                    "questions": 0,
                    "rated": 0,
                    "sum": 0.0,
                },
            )
            e["questions"] += r.get("questions", 0)
            e["rated"] += r.get("rated", 0)
            e["sum"] += r.get("sum", 0.0)
        return list(merged.values())

    def _finalize(r: dict) -> dict:
        """Replace the running 'rated'/'sum' counters with 'avgFeedback'."""
        rated = r.pop("rated")
        s = r.pop("sum", 0.0)
        r["avgFeedback"] = (s / rated) if rated else None
        return r

    def _tiebreak(r: dict) -> tuple:
        return (
            str(r.get("agentName") or ""),
            str(r.get("agentId") or ""),
            str(r.get("user") or ""),
        )

    end_dt = _parse(end) or datetime.utcnow()
    start_dt = _parse(start) or (end_dt - timedelta(days=56))
    end_excl_dt = end_dt + timedelta(days=1)
    # Day-granular string bounds for the SQL pre-filter; the exact bounds are
    # re-checked in Python below.
    d0 = start_dt.strftime("%Y-%m-%d")
    d1 = end_excl_dt.strftime("%Y-%m-%d")

    clause = []
    if agent_id:
        # NOTE(review): no owner check is applied on this path — presumably the
        # caller authorizes access to `agent_id`; confirm.
        clause.append(MessageAgent.agent_id == agent_id)
    elif not is_project_admin:
        clause.append(and_(Agent.id == MessageAgent.agent_id, Agent.owner == owner))
    else:
        agent_type = (agent_type or "all").lower()
        enterprise_ids = list(enterprise_ids or [])
        if agent_type == "enterprise":
            if enterprise_ids:
                clause.append(MessageAgent.agent_id.in_(enterprise_ids))
            else:
                # No enterprise ids: constant-false criterion, nothing matches.
                clause.append(False)
        else:
            user_clause = Agent.id == MessageAgent.agent_id
            if owner_id:
                user_clause = and_(Agent.id == MessageAgent.agent_id, Agent.owner == owner_id)
            if agent_type != "user" and enterprise_ids:
                # "all": user agents plus the listed enterprise agents.
                clause.append(or_(user_clause, MessageAgent.agent_id.in_(enterprise_ids)))
            else:
                clause.append(user_clause)

    # Imported locally, as in the original code (presumably to avoid an import
    # cycle — confirm before hoisting).
    from backend.database import crud

    # NOTE(review): names are resolved only for agents of `owner`; any other
    # agent in scope (e.g. in admin views) falls back to its id as the name.
    name_map = {a.id: a.name for a in crud.get_agents_by_owner(owner)}

    q_base = (
        db.session.query(
            Message.id,
            Message.created_at,
            Conversation.user_id,
            MessageAgent.agent_id,
            Message.feedback_rating,
        )
        .join(Conversation, Message.conversation_id == Conversation.conversation_id)
        .join(MessageAgent, MessageAgent.message_id == Message.id)
        .outerjoin(Agent, Agent.id == MessageAgent.agent_id)
        .filter(
            Message.role == "assistant",
            MessageAgent.selected == 1,
            Message.created_at >= d0,
            Message.created_at < d1,
            *clause
        )
        .distinct()
    )

    # One accumulator per (agent, user) pair.
    agg = {}
    for _msg_id, created_at, user_id, agent_id_val, feedback_rating in q_base:
        ts = created_at if isinstance(created_at, str) else created_at.isoformat()
        dt = _parse(ts)
        if dt is None or not (start_dt <= dt < end_excl_dt):
            continue
        entry = agg.setdefault(
            (agent_id_val, user_id),
            {
                "agentId": agent_id_val,
                "agentName": name_map.get(agent_id_val, agent_id_val),
                "user": user_id,
                "questions": 0,
                "rated": 0,
                "sum": 0.0,
            },
        )
        entry["questions"] += 1
        if feedback_rating is not None:
            entry["rated"] += 1
            entry["sum"] += 1.0 if int(feedback_rating) == 1 else 0.0

    base_rows = list(agg.values())

    # Free-text filter on the detail rows (applied before grouping).
    q_norm = (q or "").strip().lower()
    if q_norm:
        base_rows = [
            r
            for r in base_rows
            if q_norm in str(r.get("agentName") or "").lower()
            or q_norm in str(r.get("user") or "").lower()
        ]

    # Grouping
    group_by = (group_by or "").lower()
    if group_by == "agent":
        grouped = _rollup(base_rows, "agentId")
    elif group_by == "user":
        grouped = _rollup(base_rows, "user")
    else:
        grouped = base_rows
    rows_all = [_finalize(r) for r in grouped]

    # Sorting
    sort_by = sort_by or "questions"
    sort_dir = (sort_dir or "desc").lower()
    if sort_by not in {"agentName", "user", "questions", "avgFeedback"}:
        sort_by = "questions"
    reverse = sort_dir == "desc"

    # The query has no ORDER BY, so row order is arbitrary. Fix the order of
    # ties first; the stable primary sort below then keeps pages consistent
    # between calls.
    rows_all.sort(key=_tiebreak)

    if sort_by == "avgFeedback":
        # Rows without any rating always sort last, in either direction.
        if sort_dir == "asc":
            rows_all.sort(key=lambda r: (r.get("avgFeedback") is None, r.get("avgFeedback") or 0.0))
        else:
            rows_all.sort(key=lambda r: (r.get("avgFeedback") is None, -(r.get("avgFeedback") or 0.0)))
    elif sort_by == "questions":
        rows_all.sort(key=lambda r: int(r.get("questions", 0)), reverse=reverse)
    else:
        rows_all.sort(key=lambda r: str(r.get(sort_by) or "").casefold(), reverse=reverse)

    total = len(rows_all)
    # Clamp so negative values cannot produce wrap-around slices.
    first = max(offset, 0)
    return rows_all[first : first + max(limit, 0)], total
