import pickle
import re
from typing import Any

from backend.config import get_enterprise_agents
from backend.models.artifacts import Artifact, ArtifactsMetadata
from backend.utils.logging_utils import get_logger

logger = get_logger(__name__)

# Keys present on the event dicts handled by this module: the payload of the
# event and its kind (e.g. "references", "chart_plan", "AGENT_SELECTION").
EVENT_DATA: str = "eventData"
EVENT_KIND: str = "eventKind"


def get_references(events: list[dict]) -> list:
    """Collect the payload of every ``references`` event, preserving event order.

    Args:
        events: Event dicts; ``None`` is treated as an empty list.

    Returns:
        The ``EVENT_DATA`` value of each event whose kind is ``"references"``.

    Raises:
        KeyError: if a ``references`` event has no ``EVENT_DATA`` entry.
    """
    return [
        event[EVENT_DATA]
        for event in (events or [])
        if event.get(EVENT_KIND) == "references"
    ]


# Matches `'table_name': '<name>'` (single or double quotes) and captures <name>.
_TABLE_NAME_PATTERN = re.compile(r"""['"]table_name['"]\s*:\s*['"]([^'"]+)['"]""")
# INFO snippets starting with this prefix announce which tables were chosen.
_DECIDED_PREFIX = "Decided to use ["


def get_used_tables(reference: dict) -> list[str]:
    """Return the table names announced in a reference's INFO snippets.

    Scans ``reference["sources"][*]["items"]`` for items of type ``INFO`` whose
    ``textSnippet`` is a string starting with ``"Decided to use ["`` and
    extracts every ``table_name`` value found in that snippet.

    Args:
        reference: A reference payload; ``None``/empty yields ``[]``.

    Returns:
        De-duplicated table names in first-seen order.
    """
    # A dict is used as an insertion-ordered set: it de-duplicates while keeping
    # first-seen order, so the result is deterministic (iterating a set of str
    # varies between processes because of hash randomization).
    tables: dict[str, None] = {}
    for source in (reference or {}).get("sources") or []:
        for item in source.get("items") or []:
            snippet = item.get("textSnippet")
            if item.get("type") != "INFO" or not isinstance(snippet, str):
                continue
            if not snippet.startswith(_DECIDED_PREFIX):
                continue
            for match in _TABLE_NAME_PATTERN.finditer(snippet):
                tables.setdefault(match.group(1), None)
    return list(tables)


def get_used_agent_ids(events: list[dict]) -> list[str]:
    """
    Intended to return the stable agent IDs recorded by explicit AGENT_USED
    events emitted by UserAgent (single) or AgentConnect (multi), without
    parsing tool names.

    Not implemented yet: *events* is currently ignored and the result is always
    an empty list.
    """
    # TODO: IMPLEMENT ME - collect de-duplicated IDs from AGENT_USED events.
    agent_ids: list[str] = []
    return agent_ids


def get_chart_plans(events: list[dict]) -> list[dict]:
    """Collect the payloads of ``chart_plan`` events that contain a chart plan.

    Args:
        events: Event dicts; ``None`` is treated as an empty list.

    Returns:
        The ``EVENT_DATA`` payload of each event whose kind is ``"chart_plan"``
        and whose payload contains a ``"chart_plan"`` entry, in event order.
    """
    plans = []
    for event in events or []:
        # .get (as in get_references) so events without a kind are skipped
        # instead of raising KeyError.
        if event.get(EVENT_KIND) != "chart_plan":
            continue
        data = event.get(EVENT_DATA)
        # A missing/None payload would otherwise raise on the membership test.
        if data and "chart_plan" in data:
            plans.append(data)
    return plans


def extract_artifacts_preview(artifacts: list["Artifact"], max_rows: int = 50):
    """Return a lightweight preview copy of *artifacts*.

    Each preview keeps the artifact's ``name``, ``type`` and ``description`` and
    is flagged with ``preview=True``. Parts of type ``RECORDS`` are reduced to
    their ``columns`` plus the first *max_rows* rows of ``data`` (other keys of
    such parts are dropped); every other part is passed through unchanged (the
    same object, not a copy).

    Args:
        artifacts: Artifact dicts; ``None`` is treated as an empty list.
        max_rows: Number of leading rows kept per RECORDS part (default 50).

    Returns:
        A new list with one preview dict per input artifact.

    Raises:
        KeyError: if a RECORDS part lacks ``records``/``columns``/``data``.
    """
    preview = []
    for artifact in artifacts or []:
        parts = []
        for item in artifact.get("parts") or []:
            if item.get("type") != "RECORDS":
                parts.append(item)
                continue
            records = item["records"]
            parts.append(
                {
                    "type": "RECORDS",
                    "records": {
                        "columns": records["columns"],
                        "data": records["data"][:max_rows],  # Preview only the first max_rows rows
                    },
                }
            )
        preview.append(
            {
                "name": artifact.get("name"),
                "type": artifact.get("type"),
                "parts": parts,
                "description": artifact.get("description"),
                "preview": True,
            }
        )
    return preview


def get_artifacts_metadata(artifacts: dict, max_size_mb: float) -> dict:
    """Build an ``ArtifactsMetadata`` entry for every artifact bundle in *artifacts*.

    Bundles larger than *max_size_mb* carry only a preview of their artifacts
    (see ``extract_artifacts_preview``) and are flagged with ``preview=True``.

    Args:
        artifacts: Mapping of key -> bundle dict. A bundle may contain
            ``size_mb``, ``artifacts_id``, ``agentName``, ``agentId``,
            ``query``, ``has_records`` and ``artifacts``.
        max_size_mb: Size threshold in MB above which only a preview is kept.

    Returns:
        Mapping of the same keys to ``ArtifactsMetadata`` instances, or ``{}``
        when *artifacts* is empty or ``None``.
    """
    if not artifacts:
        return {}
    logger.info("Getting artifacts metadata for artifacts, max size %s MB", max_size_mb)
    meta = {}
    for key, bundle in artifacts.items():
        payload = bundle.get("artifacts")
        size = bundle.get("size_mb")
        if size is None:
            # Only pickle when no size was recorded. dict.get(key, default)
            # would evaluate this expensive default on every iteration.
            size = get_artifacts_size_mb(payload)
        is_preview = size > max_size_mb
        meta[key] = ArtifactsMetadata(
            # Report the same size that drove the preview decision (previously
            # 0 when the bundle carried no stored size).
            size_mb=size,
            artifacts_id=bundle.get("artifacts_id", ""),
            agentName=bundle.get("agentName", ""),
            agentId=bundle.get("agentId", ""),
            query=bundle.get("query", ""),
            has_records=bundle.get("has_records", False),
            artifacts=extract_artifacts_preview(payload) if is_preview else payload,
            preview=is_preview,
        )
    return meta


def get_selected_agents(events: list[dict]) -> list[Any]:
    """Return the agent selection carried by the first ``AGENT_SELECTION`` event.

    Args:
        events: Event dicts; ``None``/empty yields ``[]``.

    Returns:
        The ``selection`` value of the first ``AGENT_SELECTION`` event's payload,
        or ``[]`` when there is no such event or its payload/selection is
        missing or empty (so callers can always call ``len()`` on the result).
    """
    for event in events or []:
        if event.get(EVENT_KIND) == "AGENT_SELECTION":
            # Previously a missing "selection" (or a None payload) produced None
            # or an AttributeError; normalise both to an empty list.
            return (event.get(EVENT_DATA) or {}).get("selection") or []
    return []


def get_artifacts_size_mb(artifacts: dict) -> float:
    """Approximate the size of *artifacts* in megabytes (1 MB = 1024 * 1024 bytes).

    The size is the byte length of the default-protocol pickle of the object.
    """
    bytes_per_mb = 1024 * 1024
    serialized = pickle.dumps(artifacts)
    return len(serialized) / bytes_per_mb


def has_records(artifacts: dict) -> bool:
    """Tell whether any artifact contains at least one part of type ``RECORDS``.

    ``None`` artifacts and artifacts with missing/empty ``parts`` are tolerated.
    """
    return any(
        part.get("type") == "RECORDS"
        for artifact in (artifacts or [])
        for part in (artifact.get("parts") or [])
    )


def build_actions(events: list[dict]) -> dict:
    """Derive the follow-up actions offered for a message from its *events*.

    The result may contain:
      * ``"chart_plans"`` - chart-plan payloads, when any exist;
      * ``"stories"`` - ``{"enable_stories": False}`` unless there is exactly
        one references event, exactly one selected agent, and that agent
        (looked up among the enterprise agents for the current
        ``g.authIdentifier``) has a ``stories_workspace``; in that case the
        stories configuration for the agent.

    Args:
        events: Event dicts; ``None``/empty yields ``{}``.

    Returns:
        The actions dict described above.
    """
    actions = {}
    if not events:
        return actions

    chart_plans = get_chart_plans(events)
    if chart_plans:
        actions["chart_plans"] = chart_plans

    references = get_references(events)
    # Guard against a None selection so len() below cannot fail.
    sel_agents = get_selected_agents(events) or []
    logger.info("Building message actions, references: %d", len(references))
    logger.info("Building message actions, selected agents: %d", len(sel_agents))

    # Only enable stories for one agent and one tool with references. Default
    # to disabled so "stories" is always present, then upgrade below when every
    # precondition holds (previously the key was omitted when the agent was
    # missing or had no stories workspace).
    actions["stories"] = {"enable_stories": False}
    if len(references) != 1 or len(sel_agents) != 1:
        return actions

    tables = get_used_tables(references[0])
    logger.info("Building message actions, tables: %s", tables)
    agent_id = sel_agents[0]["agentId"]

    # Imported lazily: only this path needs Flask's request-scoped `g`.
    from flask import g

    enterprise_agents = get_enterprise_agents(g.get("authIdentifier", None)) or []
    agent = next((ea for ea in enterprise_agents if ea.get("id") == agent_id), None)
    logger.info("Building message actions, agent: %s", agent)
    if agent and agent.get("stories_workspace"):
        actions["stories"] = {
            "enable_stories": True,
            "tables": tables,
            "agent_id": agent_id,
            "stories_workspace": agent.get("stories_workspace"),
            "agent_short_example_queries": agent.get("agent_short_example_queries", []),
            "agent_example_queries": agent.get("agent_example_queries", []),
        }
    return actions
