from enum import Enum
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

import dataiku
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.dataiku_utils import find_recipe
from common.llm_assist.logging import logger
from dataiku.core.dataset import Schema
from dataiku.core.knowledge_bank import KnowledgeBank
from dataikuapi.dss.knowledgebank import DSSKnowledgeBankListItem

# Webapp configuration exposed by ``dataiku_api``, bound once at import time.
# NOTE(review): not referenced by the code in this section — presumably
# imported by other modules; confirm before removing.
webapp_config = dataiku_api.webapp_config


class EmbeddingRecipeType(str, Enum):
    """Known ``subType`` values of the recipes that build a knowledge bank.

    Members also subclass ``str``, so they compare equal to the raw
    ``subType`` strings read from flow-graph recipe JSON.
    """

    NLP_LLM_RAG_EMBEDDING = "nlp_llm_rag_embedding"
    EMBED_DOCUMENTS = "embed_documents"


@lru_cache(maxsize=None)
def get_knowledge_bank_full_name(knowledge_bank_id):
    """Return the ``full_name`` of a knowledge bank in the default project.

    Results are memoized per ``knowledge_bank_id`` for the life of the process.

    :param knowledge_bank_id: Knowledge bank identifier. Any falsy value
        (``None``, ``""``) short-circuits without touching Dataiku.
    :returns: ``KnowledgeBank(...).full_name``, or ``None`` when no id is given.
    """
    if not knowledge_bank_id:
        return None
    kb_handle = KnowledgeBank(knowledge_bank_id, project_key=dataiku_api.default_project_key)
    return kb_handle.full_name


@lru_cache(maxsize=None)
def get_core_knowledge_bank(knowledge_bank_id):
    """Return the core (``dataiku``) view of a knowledge bank in the default project.

    Results are memoized per ``knowledge_bank_id`` for the life of the process.

    :param knowledge_bank_id: Knowledge bank identifier. Any falsy value
        short-circuits and yields ``None``.
    :returns: The object returned by ``as_core_knowledge_bank()``, or ``None``.
    """
    if not knowledge_bank_id:
        return None
    dss_knowledge_bank = dataiku_api.default_project.get_knowledge_bank(knowledge_bank_id)
    return dss_knowledge_bank.as_core_knowledge_bank()


@lru_cache(maxsize=None)
def get_knowledge_bank_name(id: Optional[str]) -> Optional[str]:
    """Look up the display name of a knowledge bank in the default project.

    ``id`` may be a bare knowledge bank id or a ``"PROJECT_KEY.kb_id"``
    reference. Only the part after the first ``.`` is matched against the
    ids listed by the default project.

    :param id: Knowledge bank id (optionally project-qualified), or ``None``.
    :returns: The ``name`` of the first matching knowledge bank, or ``None``
        when ``id`` is ``None`` or nothing in the default project matches.
    """
    if id is None:
        return None
    project = dataiku_api.default_project

    # Drop an optional "PROJECT_KEY." prefix (split on the first dot only).
    # NOTE(review): the prefix is discarded and the default project is always
    # searched, even if the key names another project — confirm intended.
    short_id = id.split(".", 1)[1] if "." in id else id

    return next(
        (listed.name for listed in project.list_knowledge_banks() if listed.id == short_id),
        None,
    )


def get_knowledge_bank_retriever(
    knowledge_bank_id: Optional[str] = None,
    filters: Optional[Dict[str, List[Any]]] = None,
    search_type: Optional[str] = "",
) -> Any:
    """Build a LangChain retriever over a knowledge bank of the default project.

    Retrieval settings come from ``load_knowledge_bank_parameters()``. They are
    combined with the knowledge bank's vector store type and ``filters`` by
    ``get_retriever_search_kwargs``. For ``AZURE_AI_SEARCH`` stores, the
    configured ``k`` is also passed as an extra keyword argument to
    ``as_langchain_retriever``.

    :param knowledge_bank_id: Id of the knowledge bank to query.
    :param filters: Optional filters forwarded to ``get_retriever_search_kwargs``.
    :param search_type: Forwarded verbatim to ``as_langchain_retriever``.
    :returns: Whatever ``KnowledgeBank.as_langchain_retriever`` returns.
    """
    # Imported inside the function — presumably to avoid a circular import; TODO confirm.
    from answers.backend.utils.parameter_helpers import (
        get_retriever_search_kwargs,
        load_knowledge_bank_parameters,
    )

    kb_params, _ = load_knowledge_bank_parameters()
    retrieval_params = kb_params["retrieval_parameters"]
    vector_db_type = get_vector_db_type(dataiku_api.default_project, knowledge_bank_id)
    search_kwargs = get_retriever_search_kwargs(retrieval_params, vector_db_type, filters)

    logger.info(f"""
    Vector DB type: {vector_db_type}
    Search type: {search_type}
    Retriever search kwargs: {search_kwargs}
    """)
    knowledge_bank = KnowledgeBank(knowledge_bank_id, project_key=dataiku_api.default_project_key)

    # Only Azure AI Search gets ``k`` as an extra retriever kwarg (presumably
    # because it is not carried by the search kwargs for that store — confirm).
    extra_retriever_kwargs: Dict[str, Any] = {}
    if vector_db_type == "AZURE_AI_SEARCH":
        logger.debug(f"""Using Azure AI Search k value of {retrieval_params["k"]}""")
        extra_retriever_kwargs["k"] = retrieval_params["k"]

    return knowledge_bank.as_langchain_retriever(
        search_type=search_type,
        search_kwargs=search_kwargs,
        **extra_retriever_kwargs,
    )


@lru_cache(maxsize=None)
def get_vector_db_type(project, knowledge_bank_id: Optional[str] = None):
    """Return the ``vectorStoreType`` of a knowledge bank, or ``None`` without an id.

    Memoized with ``lru_cache``. The cache key includes ``project``.
    NOTE(review): if callers pass a freshly constructed project object on each
    call, entries never hit and accumulate. Confirm that the project passed in
    (e.g. ``dataiku_api.default_project``) is a stable instance.

    :param project: Project object exposing ``get_knowledge_bank(id)``.
    :param knowledge_bank_id: Knowledge bank id. Any falsy value yields ``None``.
    :returns: The ``"vectorStoreType"`` entry of the core knowledge bank's
        settings payload.
    """
    if not knowledge_bank_id:
        return None
    core_kb = project.get_knowledge_bank(knowledge_bank_id).as_core_knowledge_bank()
    # ``_get()`` is a private accessor; its raw payload is read directly.
    settings_payload = core_kb._get()
    return settings_payload["vectorStoreType"]


@lru_cache(maxsize=None)
def get_knowledge_bank_info(knowledge_bank_id: str) -> Tuple[str, Optional[Schema]]:
    """Return the recipe sub-type that builds a knowledge bank, plus its input schema.

    The recipe is located in the default project's flow graph via ``find_recipe``.
    For ``nlp_llm_rag_embedding`` recipes only, the schema of the first
    predecessor dataset is read. For any other sub-type (e.g. ``embed_documents``)
    the schema is ``None``. Results are memoized for the life of the process.

    NOTE(review):
    - ``recipe_type`` comes from ``.get("subType")``, so it may be ``None``
      despite the ``str`` annotation.
    - A missing or empty ``predecessors`` list raises (TypeError/IndexError).
    - Because of caching, later schema changes are not picked up.

    :param knowledge_bank_id: Id of the knowledge bank whose recipe is looked up.
    :returns: ``(recipe_type, schema_or_None)``.
    """
    flow_graph = dataiku_api.default_project.get_flow().get_graph()
    recipe_json = find_recipe(flow_graph.data, knowledge_bank_id)
    recipe_type = recipe_json.get("subType")

    if recipe_type != EmbeddingRecipeType.NLP_LLM_RAG_EMBEDDING:
        return recipe_type, None

    input_dataset_name = recipe_json.get("predecessors")[0]
    input_dataset = dataiku.Dataset(project_key=dataiku_api.default_project_key, name=input_dataset_name)
    return recipe_type, input_dataset.read_schema()
