"""
Trace parsing and transformation logic.
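Most functions are pure (no side effects) for easy testing; the exceptions are accumulate_usage_metadata and assign_unique_ids, which mutate their arguments in place.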
"""
import copy
import json
from typing import Any, Dict, List, Tuple


def _get_snake_or_camel(
    mapping: Dict[str, Any], snake_key: str, camel_key: str, default: Any
) -> Any:
    """Return mapping[snake_key], else mapping[camel_key], else default.

    A key whose value is JSON ``null`` (``None``) is treated as missing.
    """
    value = mapping.get(snake_key)
    if value is None:
        value = mapping.get(camel_key)
    return default if value is None else value


def parse_traces(resp_str: str) -> List[Tuple[str, Dict[str, Any]]]:
    """
    Parse trace JSON string into list of (trace_name, trace_data) tuples.

    Recognized payload shapes, checked in this order:

    1. ``{"trace": {...}}`` -> one unnamed trace.
    2. ``{"named_traces": [...]}`` or ``{"namedTraces": [...]}`` (non-empty
       list) -> one tuple per object entry, reading the name from
       ``trace_name``/``traceName`` (default ``""``) and the data from
       ``trace_value``/``traceValue`` (default ``{}``). Entries that are
       not JSON objects are skipped.
    3. Anything else -> the whole decoded document as one unnamed trace.

    Args:
        resp_str: JSON string containing trace data

    Returns:
        List of (trace_name, trace_data) tuples

    Raises:
        json.JSONDecodeError: If resp_str is not valid JSON
    """
    data = json.loads(resp_str)

    # Single trace format
    if "trace" in data and isinstance(data["trace"], dict):
        return [("", data["trace"])]

    # Named traces format. The container key may be snake_case or camelCase,
    # so the per-entry keys are accepted in both spellings as well.
    named_traces = data.get("named_traces") or data.get("namedTraces")
    if named_traces and isinstance(named_traces, list):
        return [
            (
                _get_snake_or_camel(entry, "trace_name", "traceName", ""),
                _get_snake_or_camel(entry, "trace_value", "traceValue", {}),
            )
            for entry in named_traces
            # Non-object entries carry no name/value; skip instead of crashing.
            if isinstance(entry, dict)
        ]

    # Default: treat entire data as single unnamed trace
    return [("", data)]


def extract_display_name(trace_data: Dict[str, Any]) -> str:
    """
    Extract display name from trace data (first user message text).

    Looks at ``trace_data["inputs"]["messages"]`` and returns the ``text`` of
    the first dict entry whose ``role`` is ``"user"``. Missing or JSON-null
    ``inputs``/``messages``/``text`` values all fall back to ``"?"`` rather
    than raising.

    Args:
        trace_data: Trace data dictionary

    Returns:
        Display name or "?" if not found
    """
    inputs = trace_data.get("inputs")
    if not isinstance(inputs, dict):
        # Absent, null, or otherwise unusable inputs: nothing to name.
        return "?"
    messages = inputs.get("messages") or []
    for msg in messages:
        if isinstance(msg, dict) and msg.get("role") == "user":
            text = msg.get("text")
            return "?" if text is None else text
    return "?"


def extract_output(trace_data: Dict[str, Any]) -> str:
    """
    Extract output text from trace data.

    Reads ``trace_data["outputs"]["text"]``. Missing or JSON-null ``outputs``
    or ``text`` values fall back to ``""`` so the ``str`` return contract
    holds.

    Args:
        trace_data: Trace data dictionary

    Returns:
        Output text or empty string if not found
    """
    outputs = trace_data.get("outputs")
    if not isinstance(outputs, dict):
        return ""
    text = outputs.get("text")
    return "" if text is None else text


def accumulate_usage_metadata(
    node: Dict[str, Any], aggregates: Dict[str, Any]
) -> None:
    """
    Accumulate token usage over a trace tree.
    Modifies aggregates in-place.

    Every node (``node`` and all nested ``children``) is visited in
    pre-order. For each node with a truthy ``usageMetadata`` dict, its
    ``promptTokens``, ``completionTokens``, ``totalTokens`` and
    ``estimatedCost`` are added to the same keys in ``aggregates``. Keys
    missing from ``aggregates`` start at 0 / 0.0; missing or null usage
    values count as 0.

    An explicit stack is used instead of recursion so arbitrarily deep trees
    do not hit Python's recursion limit.

    Args:
        node: Trace node (may have children)
        aggregates: Dictionary to accumulate totals into
    """
    token_fields = ("promptTokens", "completionTokens", "totalTokens")
    stack = [node]
    while stack:
        current = stack.pop()
        usage = current.get("usageMetadata")
        if usage:
            for field in token_fields:
                aggregates[field] = aggregates.get(field, 0) + (usage.get(field) or 0)
            aggregates["estimatedCost"] = aggregates.get("estimatedCost", 0.0) + (
                usage.get("estimatedCost") or 0.0
            )
        # Push children reversed so they pop in their original order, keeping
        # the same pre-order (and float summation order) as a recursive walk.
        stack.extend(reversed(current.get("children") or []))


def _child_sort_key(child: Dict[str, Any]) -> Tuple[int, Any, Any]:
    """Sort key: start time (``timestamp``, else ``begin``), then duration.

    Nodes with no start time sort after all timed nodes. A start of 0 is a
    real start time, not a missing one. The leading flag avoids comparing a
    sentinel against the start values themselves, which may be non-numeric.
    """
    start = child.get("timestamp")
    if start is None:
        start = child.get("begin")
    duration = child.get("duration") or 0
    if start is None:
        return (1, 0, duration)
    return (0, start, duration)


def assign_unique_ids(
    node: Dict[str, Any], counter: int
) -> Tuple[Dict[str, Any], int]:
    """
    Assign unique IDs to every node in a trace tree.
    Sorts each node's children by start time (timestamp, else begin), then
    by duration. Untimed children go last.
    MUTATES node in-place - caller should pass a copy if immutability needed.

    IDs are ``"node_<n>"`` with n counting up from ``counter`` in pre-order:
    a node is numbered before its (already sorted) children. The walk uses an
    explicit stack, so deep trees do not hit the recursion limit.

    Args:
        node: Trace node to process
        counter: Current ID counter

    Returns:
        Tuple of (processed_node, next_counter); processed_node is the same
        object that was passed in.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        current["id"] = f"node_{counter}"
        counter += 1
        children = current.get("children")
        if children:
            children.sort(key=_child_sort_key)
            # Reversed push so the first sorted child is numbered next.
            stack.extend(reversed(children))

    return node, counter


def process_trace_add_ids(trace_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a deep copy of ``trace_dict`` with an ``id`` on every node.

    The copy is numbered starting from 1 by ``assign_unique_ids``. The input
    tree is never modified.

    Args:
        trace_dict: Original trace data

    Returns:
        New trace dict with IDs assigned (original unchanged)
    """
    # assign_unique_ids mutates its argument, so hand it an independent copy.
    numbered_tree, _next_id = assign_unique_ids(copy.deepcopy(trace_dict), 1)
    return numbered_tree


class TraceParser:
    """
    Object facade over this module's trace helper functions.

    Holds no state of its own. Each method simply forwards its arguments to
    the module-level function of the same name and hands back that result.
    """

    def parse_traces(self, resp_str: str) -> List[Tuple[str, Dict[str, Any]]]:
        """Decode ``resp_str`` into (trace_name, trace_data) tuples."""
        parsed = parse_traces(resp_str)
        return parsed

    def extract_display_name(self, trace_data: Dict[str, Any]) -> str:
        """Return the first user message text of ``trace_data`` (or "?")."""
        name = extract_display_name(trace_data)
        return name

    def extract_output(self, trace_data: Dict[str, Any]) -> str:
        """Return the output text of ``trace_data`` (or "")."""
        text = extract_output(trace_data)
        return text

    def accumulate_usage_metadata(
        self, node: Dict[str, Any], aggregates: Dict[str, Any]
    ) -> None:
        """Add the token usage found in ``node``'s tree into ``aggregates``."""
        accumulate_usage_metadata(node, aggregates)

    def process_trace_add_ids(self, trace_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Return a deep copy of ``trace_dict`` with node IDs assigned."""
        numbered = process_trace_add_ids(trace_dict)
        return numbered
