import json
import threading
from typing import Any, Dict, List, Tuple

import dataiku
from flask import Blueprint
from traces_explorer.backend.utils.dataiku_api import dataiku_api
from traces_explorer.backend.utils.logging import logger
from traces_explorer.backend.utils.response_utils import return_ok

# Held by load_traces() for the whole duration of a (re)load so that two
# loads cannot interleave. The read-only route handlers below iterate
# `traces` without taking this lock.
load_traces_lock = threading.Lock()

# All routes defined in this module are mounted under the "/traces" prefix.
traces_blueprint = Blueprint("traces", __name__, url_prefix="/traces")
# Webapp settings; the two entries below name the dataset and the column that
# hold the serialized LLM responses. load_traces() treats a falsy value for
# either one as "nothing selected" and loads no traces.
config = dataiku_api.webapp_config
llm_response_column = config.get("llm_responses_column")
llm_response_dataset = config.get("llm_responses_dataset")
# In-memory list of loaded traces. Each entry is a dict with the keys "id",
# "value", "displayName", "traceName", "begin" and "duration" (see load_traces).
traces: List[Dict[str, Any]] = []


def assign_unique_ids(node: Dict[str, Any], counter: int) -> Tuple[Dict[str, Any], int]:
    """
    Stamp `node` and every descendant with an id of the form "node_<n>".

    Numbers are handed out depth-first, parent before children, starting at
    `counter`. The nodes are modified in place.

    Returns the same `node` object together with the first number that has
    not been used yet.
    """
    node["id"] = f"node_{counter}"
    next_free = counter + 1
    if "children" not in node:
        return node, next_free
    for descendant in node["children"]:
        # Each subtree reports back where the numbering should resume.
        _, next_free = assign_unique_ids(descendant, next_free)
    return node, next_free


def process_trace_add_ids(trace_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Number every node of a trace tree, starting at "node_1" for the root.

    The tree is modified in place and the same root object is returned.
    """
    return assign_unique_ids(trace_dict, 1)[0]


def accumulate_usage_metadata(node: Dict[str, Any], aggregates: Dict[str, Any]) -> None:
    """
    Add the usage figures of `node` and all of its descendants to `aggregates`.

    `aggregates` is updated in place and must already contain the keys
    "promptTokens", "completionTokens", "totalTokens" and "estimatedCost".
    Nodes without a (non-empty) "usageMetadata" entry contribute nothing, and
    a figure missing from a node's usage metadata counts as zero.
    """
    # (field name, value used when the field is absent)
    tracked_fields = (
        ("promptTokens", 0),
        ("completionTokens", 0),
        ("totalTokens", 0),
        ("estimatedCost", 0.0),
    )
    usage = node.get("usageMetadata")
    if usage:
        for field, fallback in tracked_fields:
            aggregates[field] += usage.get(field, fallback)
    for child in node.get("children", []):
        accumulate_usage_metadata(child, aggregates)


def read_column_content(dataset_name: str, column_name: str) -> List[str]:
    """
    Return the values of one column of a Dataiku dataset as a list.

    The dataset is fetched as a dataframe via `get_dataframe()` and the
    requested column is converted with `tolist()`.
    """
    column = dataiku.Dataset(dataset_name).get_dataframe()[column_name]
    return column.tolist()


def parse_traces(resp_str: str) -> List[Tuple[str, Dict[str, Any]]]:
    """
    Decode one serialized LLM response into a list of (trace_name, trace) pairs.

    Three layouts are recognised, checked in this order:

    1. ``{"trace": {...}}``: the nested object is returned as a single
       trace with an empty name.
    2. ``{"named_traces": [...]}`` or ``{"namedTraces": [...]}``: one pair
       per list entry, built from the entry's "trace_name" and "trace_value"
       (either may come back as None when the entry lacks the key).
    3. Anything else: the whole decoded object is returned as a single
       trace with an empty name.

    Raises json.JSONDecodeError when `resp_str` is not valid JSON.
    """
    data = json.loads(resp_str)
    if "trace" in data and isinstance(data["trace"], dict):
        return [("", data["trace"])]
    # Use whichever spelling is actually present. Re-reading
    # data["named_traces"] here would raise KeyError for payloads that only
    # carry the camelCase key.
    named_traces = data.get("named_traces") or data.get("namedTraces")
    if named_traces and isinstance(named_traces, list):
        return [(trace.get("trace_name"), trace.get("trace_value")) for trace in named_traces]

    return [("", data)]


def extract_display_name(trace_data: Dict[str, Any]) -> str:
    """
    Pick a human-readable label for a trace.

    The label is the "text" of the first entry in ``inputs.messages`` that is
    a dict with role "user". "?" is returned when there is no such message or
    when that message has no "text" key.
    """
    inputs = trace_data.get("inputs", {})
    user_texts = (
        message.get("text", "?")
        for message in inputs.get("messages", [])
        if isinstance(message, dict) and message.get("role") == "user"
    )
    # Only the first user message matters; fall back to "?" if there is none.
    return next(user_texts, "?")


def load_traces() -> None:
    """
    Loads and parses LLM response traces, then populates the global 'traces' list.

    The list is emptied first and then refilled from the configured dataset
    column, all while holding `load_traces_lock`. When no dataset or column is
    configured the list is simply left empty.

    Each stored entry has the keys "id" ("<response index>_<trace index>"),
    "value" (the trace object), "displayName", "traceName", "begin" and
    "duration".

    A response that cannot be turned into traces is logged and skipped so that
    one bad cell does not abort the whole load:
    - invalid JSON (json.JSONDecodeError);
    - a cell that is not a string, or JSON that is not an object
      (TypeError / AttributeError raised while parsing);
    - an individual trace whose value is not a JSON object.
    """
    # NOTE(review): if `logger` is a standard logging.Logger, `warn` is a
    # deprecated alias of `warning`; confirm and consider switching.
    global traces
    with load_traces_lock:
        traces.clear()

        # If no dataset or column is selected, we can't load any traces
        if not llm_response_dataset or not llm_response_column:
            logger.info("No dataset or column selected, skipping traces loading")
            return

        responses = read_column_content(llm_response_dataset, llm_response_column)
        for response_index, response_str in enumerate(responses):
            # Keep the try body to the parsing step only, so errors from the
            # bookkeeping below are never mistaken for bad input.
            try:
                parsed_traces = parse_traces(response_str)
            except json.JSONDecodeError as exc:
                logger.warn("Skipping a response due to JSON parsing error at index %d: %s", response_index, exc)
                continue
            except (TypeError, AttributeError) as exc:
                # json.loads rejects non-string cells with TypeError; valid
                # JSON that is not an object fails inside parse_traces.
                logger.warn("Skipping a response with unexpected content at index %d: %s", response_index, exc)
                continue

            for trace_index, (trace_name, trace_data) in enumerate(parsed_traces):
                # A named trace without a "trace_value" comes back as None and
                # has nothing to display; skip it but keep the index so the
                # ids of the remaining traces are unaffected.
                if not isinstance(trace_data, dict):
                    logger.warn(
                        "Skipping trace %d of response %d: trace value is not a JSON object",
                        trace_index,
                        response_index,
                    )
                    continue
                traces.append(
                    {
                        "id": f"{response_index}_{trace_index}",
                        "value": trace_data,
                        "displayName": extract_display_name(trace_data),
                        "traceName": trace_name,
                        "begin": trace_data.get("begin"),
                        "duration": trace_data.get("duration"),
                    }
                )


# Fill the in-memory trace list once, at the moment this module is imported.
load_traces()


@traces_blueprint.route("/reload", methods=["GET"])
def reload_traces():
    """Re-run load_traces() and answer with a confirmation message."""
    load_traces()
    confirmation = {"message": "Traces reloaded"}
    return return_ok(data=confirmation)


@traces_blueprint.route("/list", methods=["GET"])
def list_traces():
    """
    Return a summary of every loaded trace.

    Each summary carries "id", "name" (the stored display name), "begin",
    "duration" and "traceName"; the trace body itself is not included.
    """
    summaries = [
        {
            "id": entry["id"],
            "name": entry["displayName"],
            "begin": entry["begin"],
            "duration": entry["duration"],
            "traceName": entry["traceName"],
        }
        for entry in traces
    ]
    return return_ok(data=summaries)


@traces_blueprint.route("/get_trace/<trace_id>", methods=["GET"])
def get_trace(trace_id: str):
    """
    Return the full detail of the loaded trace whose id equals `trace_id`.

    The response contains the summary fields ("id", "name", "begin",
    "duration", "traceName"), the trace tree under "parentNode" with a
    "node_<n>" id on every node, and usage totals summed over the whole tree.
    The node ids are written into the stored trace object itself.

    When no trace matches, an OK response with empty data and the message
    "Trace not found" is returned.
    """
    entry = next((candidate for candidate in traces if candidate["id"] == trace_id), None)
    if entry is None:
        return return_ok(data={}, message="Trace not found")

    root = process_trace_add_ids(entry["value"])
    totals = {"promptTokens": 0, "completionTokens": 0, "totalTokens": 0, "estimatedCost": 0.0}
    accumulate_usage_metadata(root, totals)

    payload = {
        "id": entry["id"],
        "name": entry["displayName"],
        "begin": entry["begin"],
        "duration": entry["duration"],
        "traceName": entry["traceName"],
        "parentNode": root,
        "overallTotalPromptTokens": totals["promptTokens"],
        "overallTotalCompletionTokens": totals["completionTokens"],
        "overallTotalTokens": totals["totalTokens"],
        "overallEstimatedCost": totals["estimatedCost"],
    }
    return return_ok(payload)
