import time
from contextvars import ContextVar
from enum import Enum
from typing import Any, Dict, Optional

from common.llm_assist.logging import logger
from dataiku.llm.tracing import SpanBuilder, new_trace

# Datetime format pattern (strftime-style directives). The date part is ordered
# year-DAY-month ("%d" comes before "%m"), not ISO year-month-day.
# NOTE(review): this may be a typo for "%Y-%m-%d"; confirm with the consumers of
# TIME_FORMAT before changing it, since any timestamps already written with this
# pattern would stop parsing.
TIME_FORMAT = "%Y-%d-%m %H:%M:%S.%f"


class LLMStepName(Enum):
    """Identifiers for the LLM processing steps used by this tracing module.

    Each member's value is a human-readable label. The member ``.name`` is what
    ``init_user_trace`` uses by default as the root trace name
    (``LLMStepName.DKU_ANSWERS_QUERY.name``). Iteration order follows the
    declaration order below.
    """

    UNKNOWN = "Unknown"
    DKU_ANSWERS_QUERY = "Dataiku Answers Query"
    DKU_ANSWERS_UPLOAD_FILE = "Dataiku Answers Upload File"
    DKU_AGENT_CONNECT_QUERY = "Dataiku Agent connect query"
    DKU_AGENT_CONNECT_UPLOAD = "Dataiku Agent connect upload"
    CONVERSATION_TITLE = "Get conversation title"
    SUMMARY_TITLE = "Get summary title"
    START_MEDIA_QA_CONVERSATION = "Start a media QA conversation"
    DECISION_SELF_SERVICE = "Get decision as JSON: self-service"
    IMAGE_SUMMARY = "Visual Question Answering (VQA) Chain: get image summary"
    DOC_AS_IMAGE_SUMMARY = "Visual Question Answering (VQA) Chain: get summary"
    DOC_AS_IMAGE_ANSWER = "Visual Question Answering (VQA) Chain: answer"
    TEXT_EXTRACTION_SUMMARY = "Text Extraction: summary"
    TEXT_EXTRACTION_ANSWER = "Text Extraction: answer"
    KB_AUTO_FILTERING = "KBRetrievalChain : auto filtering"
    KB_RETRIEVAL_QUERY = "KBRetrievalChain : retrieval query"
    KB_ANSWER = "KBRetrievalChain : answer"
    DB_RETRIEVAL_QUERY = "DBRetrievalChain : retrieval query"
    DB_RETRIEVAL_GRAPH = "DBRetrievalChain : retrieval graph"
    DB_FIX_RETRIEVAL_QUERY = "DBRetrievalChain : fix retrieval query"
    DB_ANSWER = "DBRetrievalChain : answer"
    DECISION_IMAGE_GENERATION = "Get decision as JSON: image generation"
    IMAGE_GENERATION = "Image Generation Chain (does not count tokens)"
    NO_RETRIEVAL = "No Retrieval Used Chain"
    AGENT_COMPLETION = "Agent Completion Chain"
    DECISION_AGENT = "Get decision as JSON: agent choice"
    DSS_AGENT = "DSS Agent"


# Root trace (a SpanBuilder) of the current execution context; None until
# init_user_trace() creates one. As a ContextVar, a value set inside one context
# (e.g. an asyncio task, which runs in a copy of its creator's context) does not
# leak back into other contexts.
main_trace_var: ContextVar[Optional[SpanBuilder]] = ContextVar("main_trace", default=None)

# Conversation id of the current execution context; None until set_conv_id() is called.
context_conv_id: ContextVar[Optional[str]] = ContextVar("context_conv_id", default=None)

def init_user_trace(
    main_trace_name: str = LLMStepName.DKU_ANSWERS_QUERY.name,
    begin_time: Optional[int] = None,
) -> None:
    """Create the root trace for the current context unless one already exists.

    If ``main_trace_var`` already holds a trace, nothing is created. A debug
    message with the existing trace's name is logged instead.

    Args:
        main_trace_name: Name passed to ``new_trace`` for the new root trace.
        begin_time: Begin timestamp in epoch milliseconds. When ``None``, the
            current time (``time.time()`` converted to ms) is used. An explicit
            ``0`` is honoured rather than being replaced by "now".
    """
    main_trace = main_trace_var.get()
    if main_trace is not None:
        logger.debug(f"Trace already exists: {main_trace.span.get('name', '')}")
        return

    init_trace = new_trace(main_trace_name)
    # NOTE(review): the effect of entering/exiting the SpanBuilder context
    # manager is defined in dataiku.llm.tracing. The original ordering (begin,
    # then publish to the ContextVar, both inside the `with`) is kept as-is.
    with init_trace:
        # Compare against None rather than using `or`: 0 is falsy and would
        # otherwise be silently replaced by the current time.
        start_ms = begin_time if begin_time is not None else int(time.time() * 1000)
        init_trace.begin(start_ms)
        main_trace_var.set(init_trace)


def add_llm_step_trace(trace_value: Dict[str, Any]) -> None:
    """Append a step trace to the current context's root trace.

    Args:
        trace_value: Trace dict forwarded to the root trace's ``append_trace``.

    If no root trace has been initialised (see ``init_user_trace``), a warning
    is logged and nothing else happens.
    """
    main_trace = main_trace_var.get()
    if main_trace is None:
        logger.warning("No main trace found")
        return
    # The ContextVar already holds this exact object, so calling set() with it
    # again would be a no-op; appending is enough.
    main_trace.append_trace(trace_to_append=trace_value)


def log_llm_step_trace() -> None:
    """Log the current context's root trace, as returned by ``to_dict()``, at debug level.

    If no root trace has been initialised, a warning is logged instead.
    """
    main_trace = main_trace_var.get()
    if main_trace is None:
        logger.warning("No main trace found")
        return
    logger.debug(f"LLM steps: {main_trace.to_dict()}")


def get_main_trace() -> Optional[SpanBuilder]:
    """Return the current context's root trace, or None if none was initialised."""
    # The ContextVar defaults to None, so its value already has the right shape.
    return main_trace_var.get()

def get_main_trace_dict(inputs: str, outputs: str) -> Dict[str, Any]:
    """End the current root trace and return it as a dict with the exchange attached.

    The root trace's ``end()`` is called with the current time in epoch ms.
    This is a side effect on the trace held by the context, and it happens
    again on every call.

    Args:
        inputs: User message text. Stored as ``inputs["messages"]``, a list
            holding a single ``{"text": ..., "role": "user"}`` entry.
        outputs: Output text. Stored as ``outputs["text"]``.

    Returns:
        The trace's ``to_dict()`` result with inputs/outputs set. Returns an
        empty dict (and logs a warning) if no root trace exists.
    """
    main_trace = main_trace_var.get()
    if main_trace is None:
        logger.warning("No main trace found")
        return {}

    main_trace.end(int(time.time() * 1000))
    trace_dict: Dict[str, Any] = main_trace.to_dict()
    # setdefault keeps any other keys to_dict() already placed under inputs/outputs.
    trace_dict.setdefault("inputs", {})["messages"] = [{"text": inputs, "role": "user"}]
    trace_dict.setdefault("outputs", {})["text"] = outputs
    return trace_dict

def get_conv_id() -> Optional[str]:
    """Return the conversation id bound to the current context, or None when unset."""
    current_conv_id = context_conv_id.get()
    return current_conv_id

def set_conv_id(conv_id: Optional[str]):
    """Bind ``conv_id`` as the conversation id of the current context.

    Returns the ``contextvars.Token`` produced by the set. Callers MUST hand it
    to ``reset_conv_id()`` when done, so the previous value is restored.
    """
    reset_token = context_conv_id.set(conv_id)
    return reset_token

def reset_conv_id(token) -> None:
    """Restore the conversation id that was in effect before a set_conv_id() call.

    ``token`` must be the value that set_conv_id() returned.
    """
    conv_var = context_conv_id
    conv_var.reset(token)

