from __future__ import annotations

from typing import Optional

import dataiku
from backend.utils.dss_utils import compare_versions
from backend.utils.logging_utils import get_logger
from dataikuapi.dss.agent import DSSAgent
from dataikuapi.dss.managedfolder import DSSManagedFolder
from dataikuapi.dss.project import DSSProject
from dataikuapi.dss.recipe import DSSRecipe

# Module logger obtained from the backend logging helper, named after this module.
logger = get_logger(__name__)

def list_project_agent_tools(project: DSSProject):
    """List the agent tools visible from ``project``.

    The shape of the result depends on the DSS version reported by the instance:

    * DSS >= 14.3.0: the objects returned by
      ``project.list_agent_tools(include_shared=True, as_type="listitems")``.
    * Older or unknown versions: plain dicts with the keys ``id``, ``name``,
      ``type`` and ``projectKey``, built from ``project``'s own tools only.

    NOTE(review): the two branches return different types. Callers using
    attribute access (``tool.id``) only work with the first shape.

    Args:
        project (DSSProject): project whose agent tools are listed.

    Returns:
        list: agent tool entries, shaped as described above.
    """
    instance_info = dataiku.api_client().get_instance_info()
    version = instance_info.raw.get("dssVersion", "0.0.0")
    # "0.0.0" is the default used above when "dssVersion" is not reported.
    if version == "0.0.0":
        logger.warning("Could not retrieve DSS version")
    # include_shared is only requested from DSS 14.3.0 onwards.
    if compare_versions(version, "14.3.0") < 0:
        legacy_tools = []
        for item in project.list_agent_tools(as_type="listitems"):
            legacy_tools.append(
                {"id": item.id, "name": item.name, "type": item.type, "projectKey": project.project_key}
            )
        return legacy_tools
    return project.list_agent_tools(include_shared=True, as_type="listitems")


def get_ua_project(agent_id: str) -> DSSProject:
    """Return the DSS project backing a user agent.

    Args:
        agent_id (str): key of the user-agent project. It must contain the
            ``ac_user_agent_`` marker.

    Returns:
        DSSProject: handle on the project whose key is ``agent_id``.

    Raises:
        Exception: if ``agent_id`` does not contain ``ac_user_agent_``.
    """
    # Validate before touching the DSS API, so that a bad id fails fast and
    # needs no client connection.
    if "ac_user_agent_" not in agent_id:
        raise Exception("Agent project key not found")
    client = dataiku.api_client()
    logger.info(f"Using existing project {agent_id}")
    return client.get_project(agent_id)


def create_embed_doc_recipe(project: DSSProject, zone: str, folder_id: str, llm_id: str, embedding_llm: str):
    """Create (without running) an ``embed_documents`` recipe for a zone.

    The recipe reads ``folder_id`` and writes the knowledge bank
    ``documents_embedded_<zone>`` using ``embedding_llm``. Its VLM is ``llm_id``.

    Args:
        project (DSSProject): project in which the recipe is created.
        zone (str): suffix used for the recipe and knowledge bank names.
        folder_id (str): input managed folder.
        llm_id (str): model passed to ``with_vlm``.
        embedding_llm (str): embedding model for the output knowledge bank.

    Returns:
        The object returned by the recipe creator's ``create()``.
    """
    recipe_name = f"compute_documents_embedded_{zone}"
    kb_name = f"documents_embedded_{zone}"
    creator = project.new_recipe("embed_documents", recipe_name)
    creator.with_input(folder_id)
    creator.with_output_knowledge_bank(kb_name, embedding_llm)
    # Open question: should plain text extraction be offered as an alternative to the VLM?
    creator.with_vlm(llm_id)
    # The recipe is only created here; running it is left to the caller.
    return creator.create()


def create_aug_llm(project, name, kb_id, llm_id):
    """Create a retrieval-augmented LLM in ``project``.

    Thin wrapper around ``project.create_retrieval_augmented_llm``. The
    arguments are forwarded positionally as (name, kb_id, llm_id).

    Returns:
        Whatever ``create_retrieval_augmented_llm`` returns.
    """
    forwarded = (name, kb_id, llm_id)
    return project.create_retrieval_augmented_llm(*forwarded)


def create_kb_tool(
    project: DSSProject, kb_id: str, kb_desc: str = "", maxDocuments: int = 4, zone_name="DRAFT"
) -> DSSAgent:
    """Create a ``VectorStoreSearch`` agent tool over a knowledge bank.

    The tool is named ``"KB Tool <zone_name>"``. Its raw settings are only
    rewritten after creation when ``kb_desc`` is non-empty or ``maxDocuments``
    differs from 4 (presumably the tool's own default — TODO confirm).

    Args:
        project (DSSProject): project in which the tool is created.
        kb_id (str): id of the knowledge bank the tool searches.
        kb_desc (str): stored as ``additionalDescriptionForLLM``.
        maxDocuments (int): stored as ``params.maxDocuments``.
        zone_name (str): suffix used in the tool name.

    Returns:
        The object returned by the tool creator's ``create()``. It comes from
        ``new_agent_tool``, so it is presumably an agent tool; the ``DSSAgent``
        annotation looks inaccurate.
    """
    logger.info(f"Creating KB tool for KB id {kb_id}")
    tool_creator = project.new_agent_tool("VectorStoreSearch", "KB Tool " + zone_name)
    tool_creator.with_knowledge_bank(kb_id)
    vector_search_tool = tool_creator.create()

    # Description and document count are set through the raw settings after creation.
    if kb_desc or maxDocuments != 4:
        tool_settings = vector_search_tool.get_settings()
        tool_settings_raw = tool_settings.get_raw()
        tool_settings_raw["additionalDescriptionForLLM"] = kb_desc
        # Create "params" if the settings lack it instead of failing with KeyError.
        tool_settings_raw.setdefault("params", {})["maxDocuments"] = maxDocuments
        tool_settings.save()

    return vector_search_tool


def _agent_tool_field(tool, key: str):
    """Read ``key`` from an agent-tool entry, whether it exposes attributes or is a plain dict."""
    value = getattr(tool, key, None)
    if value is None and isinstance(tool, dict):
        value = tool.get(key)
    return value


def get_kb_tool_in_zone(project: DSSProject, zone_name: str):
    """Return the ``VectorStoreSearch`` tool whose name contains ``zone_name``.

    Entries from ``list_project_agent_tools`` can be list-item objects
    (attribute access plus ``to_agent_tool()``) or plain dicts, depending on
    the DSS version. Both shapes are accepted.

    Args:
        project (DSSProject): project whose agent tools are searched.
        zone_name (str): text looked for in the tool name.

    Returns:
        The matching agent tool handle, or None if none matches or an error
        occurs (logged).
    """
    try:
        for tool in list_project_agent_tools(project):
            tool_id = _agent_tool_field(tool, "id")
            tool_type = _agent_tool_field(tool, "type")
            logger.info(f"get_kb_tool_in_zone, checking tool {tool_id} of type {tool_type}")
            if tool_type != "VectorStoreSearch":
                continue
            if zone_name not in (_agent_tool_field(tool, "name") or ""):
                continue
            logger.info(f"get_kb_tool_in_zone, found KB tool in zone '{zone_name}': {tool_id}")
            to_agent_tool = getattr(tool, "to_agent_tool", None)
            if callable(to_agent_tool):
                return to_agent_tool()
            # Plain-dict entries come from the legacy listing of this project's own tools.
            return project.get_agent_tool(tool_id)
        logger.warning(f"No KB tool found in zone '{zone_name}'")
        return None

    except Exception as e:
        logger.error(f"Error getting KB tool in zone '{zone_name}': {e}")
        return None


def create_visual_agent(project: DSSProject, agent_name: str, tools_ids: list[str], llm_id: str, prompt: str):
    """Create a ``TOOLS_USING_AGENT`` agent and configure it in one go.

    Args:
        project (DSSProject): project in which the agent is created.
        agent_name (str): name of the new agent.
        tools_ids (list[str]): tool identifiers forwarded to ``update_visual_agent``.
        llm_id (str): LLM id for the agent.
        prompt (str): system prompt text to append.

    Returns:
        The agent returned by ``update_visual_agent``.
    """
    logger.info(f"Creating visual agent {agent_name} with tools {tools_ids} and llm {llm_id}")
    new_agent = project.create_agent(agent_name, "TOOLS_USING_AGENT")
    configured_agent = update_visual_agent(new_agent, tools_ids, llm_id, prompt)
    return configured_agent


def update_visual_agent(agent: DSSAgent, tools_ids: list[str], llm_id: str, prompt: str):
    """Overwrite the prompt, LLM and tools of an agent's ``v1`` version.

    The ``v1`` tool list is emptied in the raw settings before the given tools
    are added through the version settings. Each tool id is either of the form
    ``"PROJECTKEY:tool_id"``, looked up in that project, or a bare ``tool_id``,
    looked up in the client's default project.

    Args:
        agent (DSSAgent): agent to update.
        tools_ids (list[str]): tool identifiers in the formats above.
        llm_id (str): LLM id set on version ``v1``.
        prompt (str): text stored as ``systemPromptAppend`` for ``v1``.

    Returns:
        DSSAgent: the same ``agent`` object, after its settings are saved.
    """
    client = dataiku.api_client()
    logger.info(f"Updating visual agent {agent.id} with tools {tools_ids} and llm {llm_id}")
    agent_settings = agent.get_settings()
    raw_settings = agent_settings.get_raw()

    # There is no direct setter for the prompt, so the raw v1 entry is patched.
    # Its tool list is cleared here and rebuilt below.
    for version in raw_settings["versions"]:
        if version["versionId"] != "v1":
            continue
        tua_settings = version["toolsUsingAgentSettings"]
        tua_settings["systemPromptAppend"] = prompt
        tua_settings["tools"] = []

    v1_settings = agent_settings.get_version_settings("v1")
    v1_settings.llm_id = llm_id
    for ref in tools_ids:
        project_key, sep, qualified_id = ref.partition(":")
        if sep:
            owner_project, bare_id = client.get_project(project_key), qualified_id
        else:
            owner_project, bare_id = client.get_default_project(), ref
        v1_settings.add_tool(owner_project.get_agent_tool(bare_id))

    agent_settings.save()
    return agent


def get_agent_details(project: DSSProject, zone_name: str) -> dict:
    """Get details from the visual agent and KB tool in the zone.

    Args:
        project (DSSProject): project containing the visual agent
        zone_name (str): zone name

    Returns:
        dict: dict with keys: name, llm_id, tools (list of tool ids),
        kb_description, system_prompt.

        - ``name`` is the agent name with `` <zone_name>`` removed.
        - When a KB tool exists in the zone, it is left out of ``tools`` and
          the other tool refs are reduced to the part after their first ".".
          Otherwise every ref is returned with its first "." replaced by ":".
        - ``system_prompt`` is "" when version ``v1`` defines none.

    Raises:
        ValueError: if no visual agent is found in the zone.
    """
    agent = get_visual_agent_in_zone(project, zone_name)
    if agent is None:
        # Fail with a clear message rather than an AttributeError on None.
        raise ValueError(f"No visual agent found in zone '{zone_name}'")

    settings = agent.get_settings()
    version_settings = settings.get_version_settings("v1")
    s_raw = settings.get_raw()
    details = {
        "name": s_raw["name"].strip().replace(f" {zone_name}", ""),
        "llm_id": version_settings.llm_id,
    }

    kb_tool = get_kb_tool_in_zone(project, zone_name)
    # Quick workaround: tool ids are assumed to come from a single project, so
    # only the part after "PROJECTKEY." is kept.
    # TODO: fix this when we have cross project tools -- REMOVE THE PARTITION
    if kb_tool:
        # The KB tool's ref is presumably stored either as a bare id or
        # qualified with the project key, so both forms are excluded.
        kb_refs = {kb_tool.id, f"{project.project_key}.{kb_tool.id}"}
        details["tools"] = [
            tool["toolRef"].partition(".")[2] for tool in version_settings.tools if tool["toolRef"] not in kb_refs
        ]
        tool_settings_raw = kb_tool.get_settings().get_raw()
        details["kb_description"] = tool_settings_raw.get("additionalDescriptionForLLM", "")
    else:
        details["tools"] = [tool["toolRef"].replace(".", ":", 1) for tool in version_settings.tools]
        details["kb_description"] = ""

    # There is no direct accessor for the prompt, so it is read from the raw v1 entry.
    details["system_prompt"] = ""
    for version in s_raw["versions"]:
        if version["versionId"] == "v1":
            details["system_prompt"] = version.get("toolsUsingAgentSettings", {}).get("systemPromptAppend", "")
    return details


def get_visual_agent_in_zone(project: DSSProject, zone_name: str) -> Optional[DSSAgent]:
    """Find the tools-using visual agent placed in a given flow zone.

    Only saved models of type ``TOOLS_USING_AGENT`` located in the zone are
    considered. The first one whose name contains ``zone_name`` wins; otherwise
    the first agent found in the zone is returned. Models whose agent or zone
    cannot be resolved are skipped.

    Args:
        project (DSSProject): project to search.
        zone_name (str): name of the flow zone.

    Returns:
        Optional[DSSAgent]: the agent, or None if the zone does not exist,
        holds no agent, or an unexpected error occurs (logged).
    """
    try:
        zone_id = get_zone_id_by_name(project, zone_name)
        if not zone_id:
            return None

        flow = project.get_flow()
        first_in_zone = None
        agent_models = (m for m in project.list_saved_models() if m["savedModelType"] == "TOOLS_USING_AGENT")
        for model in agent_models:
            try:
                candidate = project.get_agent(model["id"])
                zone = flow.get_zone_of_object(candidate)
                if not (zone and zone.id == zone_id):
                    continue
                # A name carrying the zone name is an exact match.
                if zone_name in model["name"]:
                    return candidate
                # Otherwise remember the first agent seen in the zone.
                first_in_zone = first_in_zone or candidate
            except Exception:
                continue
        return first_in_zone
    except Exception as e:
        logger.error(f"Error getting visual agent in zone '{zone_name}': {e}")
        return None


def get_zone_id_by_name(project: DSSProject, zone_name: str) -> Optional[str]:
    """Look up the id of the flow zone whose name equals ``zone_name``.

    Args:
        project (DSSProject): project whose flow zones are listed.
        zone_name (str): exact zone name.

    Returns:
        Optional[str]: id of the first zone with that name, or None if there is
        none or an error occurs (both logged).
    """
    try:
        zones = project.get_flow().list_zones()
        match = next((zone for zone in zones if zone.name == zone_name), None)
        if match is None:
            logger.warning(f"Zone '{zone_name}' not found in project")
            return None
        logger.info(f"get_zone_id_by_name, found zone '{zone_name}' with ID: {match.id}")
        return match.id
    except Exception as e:
        logger.error(f"Error getting zone ID for '{zone_name}': {e}")
        return None


def get_managed_folder_in_zone(
    project: DSSProject, folder_name: str, zone_name: Optional[str] = None
) -> Optional[DSSManagedFolder]:
    """Find a managed folder by name, optionally restricted to a flow zone.

    Without ``zone_name``, the first folder with that name is returned. With
    ``zone_name``, the first folder with that name located in the zone is
    returned. Folders whose zone cannot be determined are skipped, with an
    info log.

    Args:
        project (DSSProject): project to search.
        folder_name (str): exact folder name.
        zone_name (Optional[str]): flow zone name, or None/"" for any zone.

    Returns:
        Optional[DSSManagedFolder]: the folder, or None if not found, if the
        zone does not exist, or if an error occurs (all logged).
    """
    try:
        target_zone_id = None
        if zone_name:
            target_zone_id = get_zone_id_by_name(project, zone_name)
            if not target_zone_id:
                logger.warning(f"Zone '{zone_name}' not found, cannot get folder")
                return None

        flow = project.get_flow()

        for entry in project.list_managed_folders():
            if entry["name"] != folder_name:
                continue
            candidate = project.get_managed_folder(entry["id"])
            if not zone_name:
                return candidate
            try:
                candidate_zone = flow.get_zone_of_object(candidate)
                if candidate_zone and candidate_zone.id == target_zone_id:
                    logger.info(f"get_managed_folder_in_zone, found folder '{folder_name}' in zone '{zone_name}'")
                    return candidate
            except Exception as e:
                logger.info(f"Could not get zone for folder {folder_name}: {e}")

        logger.warning(f"Folder '{folder_name}' not found in zone '{zone_name}'")
        return None

    except Exception as e:
        logger.error(f"Error getting managed folder: {e}")
        return None


def get_recipe_in_zone(project: DSSProject, recipe_type: str, zone_name: str) -> Optional[DSSRecipe]:
    """Find the first recipe of a given type located in a flow zone.

    The zone is resolved with ``get_zone_id_by_name``, the same lookup used by
    the other zone helpers. A recipe whose zone cannot be determined is skipped
    (info log) instead of aborting the whole search.

    Args:
        project (DSSProject): project to search.
        recipe_type (str): recipe type to match, e.g. ``"embed_documents"``.
        zone_name (str): name of the flow zone.

    Returns:
        Optional[DSSRecipe]: the recipe, or None if the zone is missing, no
        recipe matches, or an error occurs (all logged).
    """
    zone_id = get_zone_id_by_name(project, zone_name)
    if not zone_id:
        logger.warning(f"Zone '{zone_name}' not found")
        return None

    try:
        flow = project.get_flow()
        for item in project.list_recipes():
            if item["type"] != recipe_type:
                continue
            recipe = project.get_recipe(item["name"])
            try:
                recipe_zone = flow.get_zone_of_object(recipe)
            except Exception as e:
                logger.info(f"Could not get zone for recipe {item['name']}: {e}")
                continue

            if recipe_zone and recipe_zone.id == zone_id:
                logger.info(f"get_recipe_in_zone, found {recipe_type} recipe '{recipe.name}' in zone '{zone_name}'")
                return recipe

        logger.warning(f"No {recipe_type} recipe found in zone '{zone_name}'")
        return None

    except Exception as e:
        logger.error(f"Error finding recipe in zone '{zone_name}': {e}")
        return None


def get_augmented_llm_in_zone(project: DSSProject, zone_name: str) -> Optional[str]:
    """Find the retrieval-augmented LLM associated with a zone.

    Matching is by text, not by flow zone. An LLM qualifies when its id
    contains ``retrieval-augmented-llm`` and its description contains
    ``zone_name`` (case-insensitive).

    Args:
        project (DSSProject): project whose LLMs are listed.
        zone_name (str): text looked for in each LLM description.

    Returns:
        Optional[str]: ``"<project_key>:<llm_id>"`` for the first match, or None
        if nothing matches or an error occurs (logged).
    """
    try:
        for llm_def in project.list_llms():
            llm_id = llm_def.get("id", "")

            # Only retrieval-augmented LLMs are candidates; their ids are expected to contain this marker.
            if "retrieval-augmented-llm" not in llm_id:
                continue

            # Treat a missing description as "" so that one entry without a
            # description does not abort the whole search.
            llm_name = llm_def.description or ""
            if zone_name.lower() in llm_name.lower():
                full_id = f"{project.project_key}:{llm_id}"
                logger.info(f"get_augmented_llm_in_zone, found augmented LLM for zone '{zone_name}': {full_id}")
                return full_id

            # NOTE(review): matching on the knowledge bank referenced by the LLM
            # would be more reliable than matching on its description.

        logger.warning(f"No augmented LLM found for zone '{zone_name}'")
        return None

    except Exception as e:
        logger.error(f"Error finding augmented LLM in zone '{zone_name}': {e}")
        return None


def get_kb_in_zone(project: DSSProject, zone_name: str) -> Optional[str]:
    """Return the id of the knowledge bank whose name mentions ``zone_name``.

    The match is a case-insensitive substring test on the knowledge bank name;
    flow-zone membership is not checked.

    Args:
        project (DSSProject): project whose knowledge banks are listed.
        zone_name (str): text looked for in each knowledge bank name.

    Returns:
        Optional[str]: ``id`` of the first matching knowledge bank ("" if that
        entry has no id), or None when nothing matches or an error occurs
        (logged).
    """
    try:
        for kb in project.list_knowledge_banks():
            kb_id = kb.get("id", "")
            if zone_name.lower() not in kb.get("name", "").lower():
                continue
            return kb_id
        logger.warning(f"No knowledge bank found for zone '{zone_name}'")
        return None
    except Exception as e:
        logger.error(f"Error finding knowledge bank in zone '{zone_name}': {e}")
        return None


def update_project_name(project: DSSProject, new_name: str):
    """Set the ``label`` field of ``project``'s metadata to ``new_name``.

    Args:
        project (DSSProject): project to relabel.
        new_name (str): new display label.
    """
    logger.info(f"Renaming project {project.project_key} to '{new_name}'")
    metadata = project.get_metadata()
    metadata.update({"label": new_name})
    project.set_metadata(metadata)


def update_agent_name_in_zone(project: DSSProject, zone_name: str, new_name: str):
    """Rename the tools-using agent associated with a zone.

    The first saved model of type ``TOOLS_USING_AGENT`` whose name contains
    ``zone_name`` is renamed to ``"<new_name> <zone_name>"``, keeping the zone
    suffix. A warning is logged when no such agent exists.

    Args:
        project (DSSProject): project containing the agent.
        zone_name (str): zone name looked for in saved model names.
        new_name (str): new base name for the agent.
    """
    logger.info(f"Renaming agent in zone '{zone_name}' to '{new_name}'")
    for mod in project.list_saved_models():
        if mod["savedModelType"] != "TOOLS_USING_AGENT" or zone_name not in mod["name"]:
            continue
        sm = project.get_saved_model(mod["id"])
        ds = sm.get_settings()
        ds_raw = ds.get_raw()
        ds_raw["name"] = new_name + " " + zone_name
        ds.save()
        logger.info(f"Updated agent name: '{new_name}'")
        return
    # Nothing matched: make the no-op visible in the logs.
    logger.warning(f"No agent found in zone '{zone_name}' to rename")


def current_project_owner_login() -> Optional[str]:
    """Return the ``owner`` entry of the default project's permissions.

    Returns:
        Optional[str]: the owner login, or None when the permissions are empty,
        lack an ``owner`` key, or any error occurs while fetching them.
    """
    try:
        default_project = dataiku.api_client().get_default_project()
        permissions = default_project.get_permissions() or {}
        return permissions.get("owner")
    except Exception:
        return None
