import type { Trace, TraceNode, UsageMetadata } from '@/common/interface/trace'
import { LOCAL_TRACE_STORAGE_KEY } from '@/utils/config'

/**
 * Parses pasted trace JSON into a display-ready `Trace`.
 *
 * The JSON may be either a bare root node (an object whose `type` is `'span'`
 * or `'event'`) or a wrapper object whose `trace` property is such a node.
 * This function never throws; every failure is reported through `error`.
 *
 * @param json - Raw JSON text to parse.
 * @param name - Display name for the trace. Ignored when the root node has
 *   `inputs.messages`; the name is then derived from those messages instead.
 * @returns `{ trace, error }`: on success `error` is null and `trace` is the
 *   result of `generateTraceMetadata`; on failure `trace` is null and `error`
 *   describes what went wrong.
 */
export function prepareTrace(
  json: any,
  name: string = 'Local pasted trace'
): { trace: Trace | null; error: string | null } {
  // `err.message` for Error instances, the stringified value otherwise.
  const describe = (err: unknown): string => (err instanceof Error && err.message) || String(err)

  let parsed: unknown
  try {
    parsed = JSON.parse(json)
  } catch (err: unknown) {
    return { trace: null, error: 'Invalid JSON: ' + describe(err) }
  }

  // JSON.parse also yields null, numbers, strings and arrays; only an object
  // shaped like a root node (or wrapping one) is usable. Checking this here
  // avoids a TypeError on `null` being misreported as "Invalid JSON".
  const parentNode = findRootNode(parsed)
  if (!parentNode) {
    return { trace: null, error: "No valid 'trace' object found in JSON." }
  }

  try {
    const begin = parentNode.begin || new Date().toISOString()
    parentNode.begin = begin
    const displayName = parentNode.inputs?.messages ? extractDisplayName(parentNode) : name

    const trace = generateTraceMetadata({
      id: LOCAL_TRACE_STORAGE_KEY + Date.now(),
      begin,
      duration: parentNode.duration || 0,
      parentNode,
      name: displayName
    })
    return { trace, error: null }
  } catch (err: unknown) {
    // The input was valid JSON, so structural problems found while building
    // the trace (e.g. malformed `children`) must not be labelled "Invalid JSON".
    return { trace: null, error: 'Failed to process trace: ' + describe(err) }
  }
}

/**
 * Picks the root trace node out of a parsed JSON value: the value itself when
 * its `type` is `'span'` or `'event'`, otherwise its `trace` property when that
 * has such a `type`, otherwise null. Safe for any parsed value, including null.
 */
function findRootNode(parsed: unknown): TraceNode | null {
  const hasRootType = (value: unknown): boolean => {
    if (typeof value !== 'object' || value === null) return false
    const { type } = value as { type?: unknown }
    return type === 'span' || type === 'event'
  }

  if (hasRootType(parsed)) return parsed as TraceNode

  const wrapped =
    typeof parsed === 'object' && parsed !== null ? (parsed as { trace?: unknown }).trace : undefined
  return hasRootType(wrapped) ? (wrapped as TraceNode) : null
}

/**
 * Finalizes a trace for display.
 *
 * Every node under `trace.parentNode` is numbered; the ids are written onto the
 * nodes in place. The usage metadata of all those nodes is then summed.
 *
 * The four `overall*` totals are written onto the passed-in `trace` object. They
 * also appear on the returned shallow copy.
 *
 * @returns A shallow copy of `trace` whose `parentNode` is the numbered root,
 *   or null when `trace` or `trace.parentNode` is missing.
 */
export function generateTraceMetadata(trace: Trace): Trace | null {
  const root = trace ? trace.parentNode : undefined
  if (!root) {
    return null
  }

  const numberedRoot = assignUniqueIds(root, 1)[0]
  const totals = { promptTokens: 0, completionTokens: 0, totalTokens: 0, estimatedCost: 0 }
  accumulateUsageMetadata(numberedRoot, totals)

  // The totals land on the caller's object as well as on the returned copy.
  Object.assign(trace, {
    overallTotalPromptTokens: totals.promptTokens,
    overallTotalCompletionTokens: totals.completionTokens,
    overallTotalTokens: totals.totalTokens,
    overallEstimatedCost: totals.estimatedCost
  })

  return { ...trace, parentNode: numberedRoot }
}

/**
 * Stamps every node in the subtree rooted at `node` with a sequential id of the
 * form `node_<n>`. Numbering is depth-first pre-order, starting at `counter`.
 * Nodes are modified in place.
 *
 * @returns A tuple of the (same) root node and the next unused counter value.
 */
function assignUniqueIds(node: TraceNode, counter = 1): [TraceNode, number] {
  let nextId = counter
  // Explicit stack; children are pushed in reverse so they pop in original order.
  const pending: TraceNode[] = [node]

  while (pending.length > 0) {
    const current = pending.pop() as TraceNode
    current.id = 'node_' + nextId
    nextId += 1

    const kids = current.children
    if (Array.isArray(kids)) {
      for (let k = kids.length - 1; k >= 0; k--) {
        pending.push(kids[k])
      }
    }
  }

  return [node, nextId]
}

/**
 * Coerces one usage field from trace JSON to a finite number. Pasted JSON is
 * not type-checked, so numeric strings are parsed. Anything else that is not a
 * finite number (missing, null, NaN, non-numeric text) counts as 0.
 */
function toUsageNumber(value: unknown): number {
  const candidate = typeof value === 'string' && value.trim() !== '' ? Number(value) : value
  return typeof candidate === 'number' && Number.isFinite(candidate) ? candidate : 0
}

/**
 * Adds the `usageMetadata` of `node` and of every descendant reachable through
 * `children` into `aggregates`, mutating `aggregates` in place.
 */
function accumulateUsageMetadata(node: TraceNode, aggregates: UsageMetadata) {
  const usage = node.usageMetadata

  if (usage) {
    // Without coercion a string value such as "12" would turn `+=` into string
    // concatenation and corrupt every subsequent total.
    aggregates.promptTokens += toUsageNumber(usage.promptTokens)
    aggregates.completionTokens += toUsageNumber(usage.completionTokens)
    aggregates.totalTokens += toUsageNumber(usage.totalTokens)
    aggregates.estimatedCost += toUsageNumber(usage.estimatedCost)
  }

  if (Array.isArray(node.children)) {
    for (const child of node.children) {
      accumulateUsageMetadata(child, aggregates)
    }
  }
}

/**
 * Derives a display name from a node's `inputs.messages`.
 *
 * Uses the `text` of the first object message whose `role` is `'user'`. Falls
 * back to `'?'` when `messages` is not an array, when no user message exists,
 * or when that first user message has empty/missing `text`. Later user
 * messages are not consulted.
 */
function extractDisplayName(traceData: TraceNode): string {
  const messages = traceData?.inputs?.messages
  if (!Array.isArray(messages)) {
    return '?'
  }

  const firstUserMessage = messages.find(
    (candidate) => candidate && typeof candidate === 'object' && candidate.role === 'user'
  )
  return firstUserMessage?.text || '?'
}