import datetime
import json
import time
from functools import wraps
from typing import List, Optional, Union

import dataiku
from backend.agents.agent_connect import (
    AgentConnect,
    DSSToolAgent,
)
from backend.config import (
    agents_as_tools,
    get_charts_generation_llm_id,
    get_default_embedding_llm,
    get_default_llm_id,
    get_enterprise_agent_last_modified,
    get_enterprise_agents,
)
from backend.utils.conv_utils import get_selected_agents_as_objs
from backend.utils.json_utils import extract_json_string
from backend.utils.llm_utils import add_history_to_completion
from backend.utils.logging_utils import get_logger
from backend.utils.user_utils import get_agent_context
from dataikuapi.dss.llm import (
    DSSLLMStreamedCompletionFooter,
)
from flask import current_app, g, has_request_context
from langchain_core.tools import StructuredTool

logger = get_logger(__name__)

# --------------------------------------------------------------------------- #
# Today's date formatted as YYYY-MM-DD.
# NOTE: evaluated once, when this module is first imported; it is not refreshed afterwards.
current_date = datetime.date.today().strftime("%Y-%m-%d")


# --------------------------------------------------------------------------- #
def _freeze(obj):
    """
    Recursively turn lists → tuples, dicts → frozensets, so everything
    becomes hashable.  Fallback for unhashable objects is str().
    """
    if isinstance(obj, (str, int, float, bool, type(None))):
        return obj
    if isinstance(obj, dict):
        # freeze dict as frozenset of (key, frozen value)
        return frozenset((k, _freeze(v)) for k, v in obj.items())
    if isinstance(obj, (list, tuple)):
        # turn list/tuple into tuple of frozen elements
        return tuple(_freeze(x) for x in obj)
    # assume other types are hashable
    try:
        hash(obj)
        return obj
    except TypeError:
        # last resort
        return str(obj)


def request_cached(fn):
    """
    Memoise a read-only store method for the duration of the current request.

    Results live in a dict stored on ``g._store_cache`` (created on first use)
    and are keyed by ``(fn.__name__, frozen positional args, frozen kwargs)``.
    When no request context is active the wrapped method is simply called.
    """

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        # Without a request context there is no `g` to hold the memo → call through.
        if not has_request_context():
            return fn(self, *args, **kwargs)

        memo = getattr(g, "_store_cache", None)
        if memo is None:
            memo = g._store_cache = {}

        cache_key = (fn.__name__, _freeze(args), _freeze(kwargs))
        try:
            return memo[cache_key]
        except KeyError:
            pass

        value = fn(self, *args, **kwargs)
        memo[cache_key] = value
        return value

    return wrapper


def invalidate_request_cache(*method_names: str):
    """
    Decorator factory for store *write* methods.

    Entries of the per-request cache (``g._store_cache``) whose key starts with
    one of `method_names` are dropped twice:

      1. *before* the write, so reads performed inside the write hit the
         underlying store instead of the cache;
      2. *after* the write (even when it raises), so results cached by such
         in-method reads cannot outlive the write with pre-write data.

    Outside a request context, or when no cache exists yet, this is a no-op
    around the wrapped method.
    """

    def _drop_matching_entries():
        if has_request_context() and hasattr(g, "_store_cache"):
            cache = g._store_cache
            # materialise the keys first: the dict must not change size while iterated
            for k in [k for k in cache if k[0] in method_names]:
                del cache[k]

    def decorator(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            _drop_matching_entries()
            try:
                return fn(self, *args, **kwargs)
            finally:
                _drop_matching_entries()

        return wrapper

    return decorator


def current_request_login() -> Optional[str]:
    """Return ``g.authIdentifier`` when it is set and truthy, otherwise ``None``."""
    login = getattr(g, "authIdentifier", None)
    return login if login else None


# =============================================================================
#  FLASK-SCOPED STORE HELPERS
# =============================================================================
def get_store():
    """Return ``g.store`` when set, otherwise the ``"STORE"`` entry of the app config."""
    # The app-level fallback is looked up on every call, even when g.store exists.
    app_level_store = current_app.config.get("STORE")
    return getattr(g, "store", app_level_store)


def get_user_base_llm(store=None) -> str:
    """
    Return the configured default base-model id.

    `store` is accepted for call-site compatibility but is not consulted here.

    Raises:
        ValueError: if no default LLM id is configured.
    """
    default_id = get_default_llm_id()
    if default_id is not None:
        return default_id
    raise ValueError("No User Base LLM ID found")


def get_charts_generation_llm():
    """
    Return the configured charts-generation LLM id.

    Raises:
        ValueError: if no charts-generation LLM id is configured.
    """
    charts_llm_id = get_charts_generation_llm_id()
    if charts_llm_id is not None:
        return charts_llm_id
    raise ValueError("No Charts Generation LLM ID found")


def get_user_base_embedding_llm(store) -> str:
    """
    Return the embedding-model id to use.

    The ``"baseEmbeddingLlmId"`` entry of ``store.get_preferences()`` wins when
    it is set and truthy; otherwise the global default embedding model is used.

    Args:
        store: object exposing ``get_preferences()``. A store without that
            method (or ``None``) is tolerated and treated as "no preferences".
    """
    try:
        # `or {}` guards against get_preferences() returning None
        prefs = store.get_preferences() or {}
    except AttributeError:
        # Best effort: store is None / has no get_preferences (an AttributeError
        # raised inside get_preferences itself is also absorbed here).
        prefs = {}
    llm_id = prefs.get("baseEmbeddingLlmId") or get_default_embedding_llm()
    logger.info("get_user_base_embedding_llm: %s", llm_id)
    return llm_id


def get_user_and_ent_agents() -> dict:
    """
    Build the agent listing for the current request's user.

    Returns:
        A dict with two keys:
          • ``"enterpriseAgents"``: one entry per enterprise agent visible to
            ``g.authIdentifier``, enriched with ``owner`` / ``published_at``
            when that metadata can be resolved (``None`` otherwise).
          • ``"userAgents"``: the version rows produced by
            ``_agent_version_rows`` for every agent of the store, most
            recently created first.
    """
    store = get_store()
    # sorted() instead of list.sort(): the list belongs to the store (and may be
    # a cached result), so it must not be reordered in place.
    # `or ""` keeps the comparison valid when createdAt is present but None.
    user_agents = sorted(
        store.get_all_agents(),
        key=lambda a: a.get("createdAt") or "",
        reverse=True,
    )
    ids = [a["id"] for a in user_agents]
    share_counts = store.get_share_counts(ids)

    current_user = g.authIdentifier
    enterprise_agents = get_enterprise_agents(current_user)

    rows: list[dict] = []
    for a in user_agents:
        rows.extend(_agent_version_rows(a, share_counts.get(a["id"], 0), current_user))

    enriched_enterprise = []
    for a in enterprise_agents:
        owner, last_modified = None, None
        try:
            # metadata is only looked up for ids shaped "<project>:<kind>:<object>"
            parts = (a.get("id") or "").split(":")
            if len(parts) == 3:
                project_id, kind, object_id = parts
                is_augmented_llm = kind == "retrieval-augmented-llm"
                owner, last_modified = get_enterprise_agent_last_modified(project_id, object_id, is_augmented_llm)
        except Exception:
            # Best effort: the listing must still be returned, but leave a trace
            # of the failure instead of swallowing it silently.
            logger.warning(
                "Could not resolve owner/last-modified for enterprise agent %s",
                a.get("id"),
                exc_info=True,
            )
            owner, last_modified = None, None
        enriched_enterprise.append(
            {
                "id": a["id"],
                "name": a["name"],
                "description": a["tool_agent_description"],
                "stories_workspace": a.get("stories_workspace"),
                "short_example_queries": a.get("agent_short_example_queries"),
                "example_queries": a.get("agent_example_queries"),
                "owner": owner,
                "published_at": last_modified,
            }
        )
    return {
        "enterpriseAgents": enriched_enterprise,
        "userAgents": rows,
    }


# NOTE(review): the bare string literal below is never assigned or referenced; it is a
# no-op expression statement. It appears to preserve an earlier, flat shape of the
# payload returned by get_user_and_ent_agents() — TODO confirm and consider deleting it.
"""
    return {
        "enterpriseAgents": [{"id": a["id"], "name": a["name"]} for a in enterprise_agents],
        "userAgents": [
            {
                "id": a["id"],
                "name": a["name"],
                "status": a.get("status", "active"),
                "createdAt": a.get("createdAt"),
                "indexing": a.get("indexing"),
                "isShared": a["owner"] != g.authIdentifier,
                "shareCount": share_counts.get(a["id"], 0),
                "published_at": a.get("published_at"),
                "published_version": bool(a.get("published_version")),
                "publishing_status": a.get("publishing_status"),
                "hasUnpublishedChanges": _has_unpublished_changes(a),
            }
            for a in user_agents
        ],
    }
    """


def _has_unpublished_changes(agent: dict) -> bool:
    """
    Return **True** iff *any* editable field deviates from the last
    `published_version`
    """
    pv = agent.get("published_version")
    if not pv:  # never published → always consider it as having unpublished changes
        return True

    keys = (
        "name",
        "description",
        "system_prompt",
        "kb_description",
        "sample_questions",
        "llmid",
        "tools",
        "documents",
    )
    return any(agent.get(k) != pv.get(k) for k in keys)


# --------------------------------------------------------------------------- #
#  Helper -- expand ONE agent into 1 or 2 “version rows”
# --------------------------------------------------------------------------- #
def _agent_version_rows(agent: dict, share_count: int, current_user: str) -> list[dict]:
    """
    Expand one agent into its displayable "version rows".

    Returns a list with:
      • one "published" row  → if the agent has a published snapshot
      • one "draft" row      → if the agent was never published, or its current
                               values diverge from the published snapshot

    so the result holds one or two rows. Each row carries the attributes of the
    corresponding version so that callers can display them independently.

    Args:
        agent: agent record; must contain ``"id"``, ``"name"`` and ``"owner"``.
        share_count: number reported as ``shareCount`` on every row.
        current_user: login compared with the agent's owner to set ``isShared``.
    """
    rows: list[dict] = []

    # Values shared by both rows are computed once.
    pub = agent.get("published_version")
    has_changes = _has_unpublished_changes(agent)
    is_shared = agent["owner"] != current_user

    # -------- PUBLISHED VERSION -------------------------------------------
    if pub:
        rows.append(
            {
                "id": agent["id"],
                "version": "published",
                "status": "active",
                "name": pub.get("name", agent["name"]),
                "description": pub.get("description", ""),
                "createdAt": agent.get("published_at"),
                "shareCount": share_count,
                "hasUnpublishedChanges": has_changes,
                "isShared": is_shared,
                "indexing": agent.get("indexing"),
                "published_at": agent.get("published_at"),
                "publishing_status": agent.get("publishing_status"),
                "published_version": bool(pub),
                "owner": agent.get("owner"),
                # Like name/description, take the snapshot's value; fall back to
                # the agent's own value when the snapshot does not carry the key.
                "sample_questions": pub.get("sample_questions", agent.get("sample_questions", [])),
            }
        )

    # -------- DRAFT VERSION ------------------------------------------------
    if has_changes or not pub:
        rows.append(
            {
                "id": agent["id"],
                "version": "draft",
                "status": "draft",
                "name": agent["name"],
                "description": agent.get("description", ""),
                "createdAt": agent.get("createdAt"),
                "shareCount": share_count,
                "hasUnpublishedChanges": has_changes,
                "isShared": is_shared,
                "indexing": agent.get("indexing"),
                "published_at": agent.get("published_at"),
                "publishing_status": agent.get("publishing_status"),
                "published_version": bool(pub),
                "owner": agent.get("owner"),
                "sample_questions": agent.get("sample_questions", []),
                "short_example_queries": agent.get("short_example_queries", []),
            }
        )
    return rows


def make_json_serializable(obj):
    """
    Return a copy of *obj* made only of values ``json.dumps`` accepts.

    • dicts are rebuilt recursively; keys that are not str / int / float /
      bool / None (which ``json.dumps`` would reject) are replaced by ``str(key)``
    • lists and tuples become lists of converted elements (JSON arrays)
    • str / int / float / bool / None are returned unchanged
    • anything else is replaced by its ``str()`` representation
    """
    primitives = (str, int, float, bool, type(None))
    if isinstance(obj, dict):
        return {
            (k if isinstance(k, primitives) else str(k)): make_json_serializable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        # tuples are emitted as arrays, the same way json.dumps treats them
        return [make_json_serializable(x) for x in obj]
    if isinstance(obj, primitives):
        return obj
    return str(obj)


# =============================================================================
#  AGENT-SELECTION PROMPT (unchanged)
# =============================================================================
def select_agents(store, agent_ids, messages, trace=None):
    """
    Decide which of the candidate agents should be called, and with what query.

    Every id of `agent_ids` is resolved against the enterprise agents visible to
    ``g.authIdentifier`` and the agents of `store`; ids that match neither are
    ignored. When agents are exposed as tools, or when at most one id was
    requested, all resolved agents are returned with the user's last message as
    query and no LLM call is made. Otherwise the base LLM is asked to pick
    agents and to formulate one query per agent.

    Args:
        store: object exposing ``get_all_agents()``.
        agent_ids: candidate agent ids.
        messages: chat history as dicts with a ``"content"`` entry; the last
            one is taken as the user's current query.
        trace: optional trace object; when given, the LLM preselection is
            recorded in an ``"agent_preselection"`` sub-span.

    Returns:
        ``(selected, justification)`` where `selected` is a list of
        ``{"agentId", "query", "agentName"}`` dicts and `justification` is the
        LLM's explanation (``None`` when no LLM call was made).
    """
    # Index enterprise agents by id once; setdefault keeps the first entry for
    # an id, like a first-match linear search would.
    enterprise_by_id: dict = {}
    for ent in get_enterprise_agents(g.authIdentifier):
        enterprise_by_id.setdefault(ent.get("id"), ent)
    user_agents = {a["id"]: a for a in store.get_all_agents()}
    query = messages[-1]["content"] if messages else ""

    lines = []
    agents_names = {}
    for aid in agent_ids:
        if cfg := enterprise_by_id.get(aid):
            desc = cfg["tool_agent_description"]
            name = cfg["name"]
        elif aid in user_agents:
            ua = user_agents[aid]
            desc = ua.get("description", "")
            # A missing *or empty/None* published_version means "never published":
            # use the agent's own name in that case.
            published = ua.get("published_version")
            name = (published or ua).get("name", "Unnamed Agent")
        else:
            # unknown id → not a candidate
            continue
        agents_names[aid] = name
        lines.append(f"• id: {aid}\n  description: {desc}.")

    # Only resolved agents can be returned: unknown ids have no name entry.
    all_agents = [
        {"agentId": aid, "query": query, "agentName": agents_names[aid]}
        for aid in agent_ids
        if aid in agents_names
    ]
    if agents_as_tools() or len(agent_ids) <= 1:
        return all_agents, None

    # Create a subspan if trace is provided
    if trace:
        agent_selection_span = trace.subspan("agent_preselection")
        agent_selection_span.begin(int(time.time() * 1000))
    else:
        agent_selection_span = None

    prompt = (
        "You are an orchestration assistant whose sole job is to decide whether specialised agents should be called, "
        "and if so, which ones, and with what query. "
        "Carefully analyze the chat history, the current user query, and the available agents with their descriptions. "
        "\n\n"
        "# Output format:\n"
        "Produce a JSON object with two fields: 'selected_agents' and 'justification'.\n"
        "- 'selected_agents' must be a list of dictionaries, each with 'agentId' and 'query'.\n"
        "- 'agentId' must match exactly one of the IDs listed under ### Candidate agents.\n"
        "- If no queries are needed, set 'selected_agents' to [] and explain why in 'justification'.\n\n"
        "# Decision criteria:\n"
        "- Generate a query for each agent whenever there is any reasonable chance the agent could help.\n"
        "- Err on the side of including *more agent calls* rather than risking missing a relevant one.\n"
        "- The query should be faithful to the user’s request but may be refined, clarified, or enriched with context from the chat history.\n"
        "- If the user’s query is already sufficient, keep it as-is.\n"
        "- If in doubt, formulate queries to the closest matching agent(s) instead of returning none.\n\n"
        "- Agent calls are executed in parallel. They do not share information with each other and cannot rely on the outputs of other agents. Therefore, each query must be fully self-contained and independent.\n"
        "# Constraints:\n"
        "- Return ONLY valid JSON, no extra text.\n"
        "- Do not invent agents; only use IDs from the provided list.\n"
        "\n"
        "### Candidate agents\n" + "\n".join(lines) + "\n\n"
        "Your JSON response:"
    )
    # Add prompt to trace inputs if available
    if agent_selection_span:
        agent_selection_span.inputs["prompt"] = prompt
        agent_selection_span.inputs["candidate_agent_ids"] = agent_ids

    client = dataiku.api_client()
    project = client.get_default_project()
    llm_id = get_user_base_llm(store)
    comp = project.get_llm(llm_id).new_completion()
    comp.with_message(prompt, role="system")
    comp = add_history_to_completion(completion=comp, messages=messages)
    logger.info(
        "Executing Agents preselection query:\nLLM id=%s\nCompletion Query=%s\nCompletion Settings=%s\n",
        llm_id,
        json.dumps(comp.cq, indent=2, sort_keys=True),
        json.dumps(comp.settings, indent=2, sort_keys=True),
    )
    resp = comp.execute()
    # the response text may be missing; treat that as an empty (unparseable) reply
    raw = (resp.text or "").strip()
    logger.info("Agents preselection response:\nLLM id=%s\nCompletion Response=%s", llm_id, raw)
    # Add LLM trace to the span if available
    if agent_selection_span and resp.trace:
        agent_selection_span.append_trace(resp.trace)

    candidate = extract_json_string(raw)
    try:
        obj = json.loads(candidate)
        selected = obj.get("selected_agents", [])
        justification = obj.get("justification", "")
        logger.info("Agents selection decision: %s \n Justification : %s", selected, justification)
        if not isinstance(selected, list):
            raise ValueError("'selected_agents' is not a list")
    except Exception:
        # Any malformed answer falls back to calling every resolved agent.
        logger.warning("Could not parse the agents preselection response; using all agents", exc_info=True)
        selected = all_agents
        justification = "Failed to parse selection JSON; using all agents"

    # Keep only well-formed entries that target a known agent; an entry without
    # a usable query falls back to the user's own query.
    selected = [
        {
            "agentId": item["agentId"],
            "query": item.get("query") or query,
            "agentName": agents_names[item["agentId"]],
        }
        for item in selected
        if isinstance(item, dict) and item.get("agentId") in agents_names
    ]
    # Add outputs to trace if available
    if agent_selection_span:
        agent_selection_span.outputs["selected_agents"] = selected
        agent_selection_span.outputs["justification"] = justification
        agent_selection_span.outputs["raw_response"] = raw
        agent_selection_span.end(int(time.time() * 1000))

    return selected, justification


# =============================================================================
#  BUILD AgentConnect
# =============================================================================
def build_agent_connect(
    store, agent_ids, *, user_agents: dict | None = None, draft_mode: bool = False, conv_id: str, tracer=None
):
    """
    Build an AgentConnect wrapping one DSSToolAgent per selected agent.

    Args:
        store: Data store
        agent_ids: List of agent IDs to include
        user_agents: Optional pre-loaded user agents dict (not used by this implementation)
        draft_mode: True only when called from edit page (TestAgent); only logged here
        conv_id: Conversation id (required keyword; not used by this implementation)
        tracer: Optional tracer (not used by this implementation)

    Returns:
        An AgentConnect using the base LLM and the built tool agents.
    """
    items: List[Union[DSSToolAgent, StructuredTool]] = []

    logger.debug("build_agent_connect called with agent_ids=%s, draft_mode=%s", agent_ids, draft_mode)

    # Compute the date per call: a value captured at import time goes stale in a
    # long-running process and would tell agents the wrong "today".
    today = datetime.date.today().strftime("%Y-%m-%d")

    agents_objs = get_selected_agents_as_objs(store, agent_ids)
    for agent in agents_objs:
        if agent.get("tool_agent_description"):
            # Enterprise agent (recognised by its tool_agent_description)
            items.append(
                DSSToolAgent(
                    dss_agent_id=agent["id"],
                    # instructions are optional: default to an empty suffix
                    agent_system_instructions=f"Today is {today}. {agent.get('agent_system_instructions', '')}",
                    tool_agent_description=agent["tool_agent_description"],
                    agent_name=agent["name"],
                )
            )
        else:
            # User agent
            items.append(
                DSSToolAgent(
                    dss_agent_id=agent.get("id"),
                    agent_system_instructions=f"Today is {today}. {agent.get('system_prompt', '')}",
                    tool_agent_description=agent.get("description", ""),
                    agent_name=agent.get("name", "Unnamed Agent"),
                )
            )

    return AgentConnect(base_model=get_user_base_llm(store), agents=items)


# =============================================================================
#  ONE-AGENT DIRECT CALL (enterprise)
# =============================================================================
def _prepare_enterprise_completion(agent_id: str, messages: list[dict], user: str, conv_id: str):
    """Validate `agent_id` and return a completion loaded with context and messages."""
    if ":" not in agent_id:
        raise ValueError(f"Invalid agent identifier: {agent_id}")

    client = dataiku.api_client()
    # "<project_key>:<llm id>" — only the first ':' separates the project key
    project_key, short = agent_id.split(":", 1)
    llm = client.get_project(project_key).get_llm(short)
    comp = llm.new_completion()
    context = get_agent_context(user, conv_id)
    comp.with_context(context)
    for m in messages:
        comp = comp.with_message(m["content"], role=m["role"])
    logger.info(
        "Calling enterprise agent id=[%s]\ncompletion_query=%s\nsettings=%s",
        agent_id,
        json.dumps(comp.cq, indent=2, sort_keys=True),
        json.dumps(comp.settings, indent=2, sort_keys=True),
    )
    return comp


def _stream_enterprise_agent(agent_id: str, messages: list[dict], user: str, conv_id: str, trace=None):
    """Generator behind the streaming mode: nothing runs until it is iterated."""
    comp = _prepare_enterprise_completion(agent_id, messages, user, conv_id)
    for chunk in comp.execute_streamed():
        if isinstance(chunk, DSSLLMStreamedCompletionFooter):
            # Record the trace *before* yielding the footer so it is not lost if
            # the consumer stops iterating once the footer is received.
            trace_data = chunk.trace
            if trace_data and trace is not None:
                trace.append_trace(trace_data)
            yield {"footer": {"additionalInformation": chunk.data.get("additionalInformation", {})}}
        else:
            yield {"chunk": chunk.data}


def call_dss_agent_full_conversation(
    agent_id: str, messages: list[dict], user: str, streaming: bool = False, conv_id: str = "", trace=None
):
    """
    Used for *enterprise* agents only.
    User agents are now handled by UserAgent directly.

    Args:
        agent_id: identifier of the form ``"<project_key>:<llm id>"``.
        messages: conversation as dicts with ``"content"`` and ``"role"``.
        user: user passed to ``get_agent_context``.
        streaming: when True, return a generator of ``{"chunk": ...}`` dicts
            followed by a ``{"footer": ...}`` dict; when False, return the
            reply text.
        conv_id: conversation id passed to ``get_agent_context``.
        trace: optional trace object receiving the streamed footer's trace.

    Returns:
        The reply text (``streaming=False``) or a lazy generator
        (``streaming=True``).

    Raises:
        ValueError: if `agent_id` contains no ``":"``. In streaming mode this is
            raised when the generator is first iterated.
    """
    # NOTE: this function must not contain `yield` itself — that would turn it
    # into a generator function and the non-streaming call would hand back a
    # generator instead of the reply text.
    if streaming:
        return _stream_enterprise_agent(agent_id, messages, user, conv_id, trace)

    comp = _prepare_enterprise_completion(agent_id, messages, user, conv_id)
    reply = comp.execute()
    return reply.text


# =============================================================================
#  FILE UTILITIES
# =============================================================================
def get_file_size(file_obj) -> int:
    """
    Get the size of an uploaded file in bytes.

    The stream position is restored to where it was before the call.

    Args:
        file_obj: Flask file object from request.files (any object exposing
            ``tell()`` and ``seek()`` works)

    Returns:
        int: File size in bytes, or 0 if the size could not be determined
    """
    try:
        # Save current position
        start_pos = file_obj.tell()
        try:
            # whence=2 → seek relative to the end of the stream
            file_obj.seek(0, 2)
            return file_obj.tell()
        finally:
            # Restore original position, even if measuring the size failed
            file_obj.seek(start_pos)
    except Exception as e:
        # Not every file-like object has a `filename`; reading it blindly would
        # raise from inside this handler and defeat the best-effort fallback.
        name = getattr(file_obj, "filename", None)
        logger.warning("Failed to get file size for %s: %s", name, e)
        return 0
