
from typing import Any, Dict, List, Optional

import dataiku
from answers.backend.utils.knowledge_filters import get_knowledge_bank_filtering_settings
from answers.solutions.knowledge_bank import get_vector_db_type
from common.backend.models.base import (
    KnowledgeBankParameters,
    KnowledgeBankRetrievalParams,
    PerKnowledgeBankParameters,
)
from common.backend.utils.dataiku_api import dataiku_api
from common.llm_assist.logging import logger

# Module-level handles read once, when this module is imported, from the shared dataiku_api helper.
# `project` is passed to get_vector_db_type() below; `webapp_config` is the webapp configuration
# mapping that the load_* helpers in this module read with .get().
project = dataiku_api.default_project
webapp_config = dataiku_api.webapp_config


def get_knowledge_bank_metas_choices(config, is_multi_select=False):
    """
    Builds dropdown choices from the filterable metadata of the configured knowledge bank.

    :param config: dict: Webapp configuration; its "knowledge_bank_id" entry selects the knowledge bank.
    :param is_multi_select: bool: When False, an extra empty-valued choice labelled "None" is appended.

    :returns: dict: {"choices": [{"value": meta, "label": meta}, ...]}.
    """
    # NOTE(review): the meaning of the positional False flag is defined in
    # answers.backend.utils.knowledge_filters, which is not visible here.
    filtering_settings = get_knowledge_bank_filtering_settings(config.get("knowledge_bank_id"), False)
    options = [{"value": meta, "label": meta} for meta in filtering_settings['filter_metadata']]
    if not is_multi_select:
        # Single-select parameters get an explicit empty-valued "None" option.
        options.append({"value": "", "label": "None"})
    return {"choices": options}


def do(payload, config, plugin_config, inputs):
    """
    Returns the dropdown choices for a dynamic webapp parameter.

    The Dataiku API client is only created in the branches that list project objects,
    so metadata lookups and unknown parameter names do not open a client needlessly.

    :param payload: dict: Request payload; its "parameterName" entry names the parameter to populate.
    :param config: dict: Current webapp configuration, forwarded to the knowledge-bank metadata lookups.
    :param plugin_config: dict: Plugin configuration (unused).
    :param inputs: Webapp inputs (unused).

    :returns: dict: A mapping of the form {"choices": [{"value": ..., "label": ...}, ...]}.

    :raises KeyError: If the payload has no "parameterName" entry.
    """
    parameter_name = payload["parameterName"]

    if parameter_name == "llm_id":
        current_project = dataiku.api_client().get_default_project()
        return {
            "choices": [
                {"value": llm.get("id"), "label": llm.get("friendlyName")}
                for llm in current_project.list_llms()
                # Retrieval-augmented LLMs are not offered as selectable models.
                if llm.get("type") != "RETRIEVAL_AUGMENTED"
            ]
        }
    if parameter_name == "knowledge_bank_id":
        current_project = dataiku.api_client().get_default_project()
        return {
            # The empty value labelled "None" lets the user leave the knowledge bank unset.
            "choices": [{"value": "", "label": "None"}] + [
                {"value": kb.get("id"), "label": kb.get("name")}
                for kb in current_project.list_knowledge_banks()
            ]
        }
    if parameter_name in ("knowledge_sources_filters", "knowledge_sources_displayed_metas"):
        # Multi-select parameters: no explicit "None" choice is added.
        return get_knowledge_bank_metas_choices(config, is_multi_select=True)
    if parameter_name in ("knowledge_source_url", "knowledge_source_title", "knowledge_source_thumbnail"):
        return get_knowledge_bank_metas_choices(config)

    # Unknown parameter name: return a single placeholder choice instead of raising.
    return {
        "choices": [
            {
                "value": "wrong",
                "label": "Problem getting the name of the parameter.",
            }
        ]
    }


def load_n_top_sources_to_log(config: Optional[Dict[str, Any]] = None) -> int:
    """
    Loads the parameter 'n_top_sources_to_log' based on the webapp configuration.

    :param config: dict: Configuration to read from. Defaults to the module-level webapp
        configuration, so existing zero-argument calls behave as before.

    :returns: n_top_sources_to_log: int: The configured number of sources to log, or the
        default of -1 when "filter_logged_sources" is falsy or no value is configured.
    """
    DEFAULT_N_SOURCES_TO_LOG = -1
    if config is None:
        config = webapp_config

    if not config.get("filter_logged_sources", True):
        return DEFAULT_N_SOURCES_TO_LOG

    n_top_sources_to_log = config.get("n_top_sources_to_log", DEFAULT_N_SOURCES_TO_LOG)
    # A key that is present but null (e.g. an unset webapp parameter) falls back to the default
    # instead of leaking None to callers.
    if n_top_sources_to_log is None:
        return DEFAULT_N_SOURCES_TO_LOG
    return n_top_sources_to_log

def load_knowledge_bank_parameters():
    """
    Loads the parameters associated with the knowledge bank.

    Only a single knowledge bank is currently supported. It is read from the webapp configuration
    keys "knowledge_bank_id", "knowledge_bank_custom_name" and "knowledge_bank_description".
    Retrieval settings are read from "knowledge_retrieval_<field>" for every field declared on
    KnowledgeBankRetrievalParams (missing keys yield None).

    :returns: all_knowledge_bank_parameters: dict: The parameters of the knowledge banks.
    :returns: per_knowledge_bank_parameters: dict: The parameters associated with each knowledge bank, keyed by id.
    """
    # TODO: Adapt when several knowledge banks will be connected
    per_knowledge_bank_parameters: PerKnowledgeBankParameters = {}
    knowledge_bank_ids = []
    knowledge_bank_vector_db_types = []
    knowledge_bank_descriptions = []
    knowledge_bank_weights = []
    knowledge_bank_custom_names = []
    knowledge_bank_search_parameters = {}

    # Hard coded params when only 1 knowledge bank is available
    knowledge_bank_id = webapp_config.get("knowledge_bank_id")
    # The knowledge_bank_id dropdown offers an empty-string value labelled "None", so "" must be
    # treated as "no knowledge bank" just like None and the literal string "None".
    if knowledge_bank_id not in ("None", None, ""):
        knowledge_bank_ids.append(knowledge_bank_id)
        knowledge_bank_weights.append(None)
        # Null/blank entries fall back to sensible defaults rather than propagating None/"".
        knowledge_bank_descriptions.append(webapp_config.get("knowledge_bank_description") or "")
        knowledge_bank_custom_names.append(webapp_config.get("knowledge_bank_custom_name") or knowledge_bank_id)

    if knowledge_bank_ids:
        PARAMETER_PREFIX = "knowledge_retrieval_"
        for kb_id, kb_weight, kb_description, kb_name in zip(
            knowledge_bank_ids, knowledge_bank_weights, knowledge_bank_descriptions, knowledge_bank_custom_names
        ):
            kb_vector_db_type = get_vector_db_type(project, kb_id)
            knowledge_bank_vector_db_types.append(kb_vector_db_type)
            per_knowledge_bank_parameters[kb_id] = KnowledgeBankParameters(
                id=kb_id,
                description=kb_description,
                name=kb_name,
                weight=kb_weight,
                vector_db_type=kb_vector_db_type,
            )

        # Each declared retrieval field is looked up under the "knowledge_retrieval_" prefix.
        knowledge_bank_search_parameters = {
            parameter_name: webapp_config.get(PARAMETER_PREFIX + parameter_name)
            for parameter_name in KnowledgeBankRetrievalParams.__annotations__
        }
        logger.debug(knowledge_bank_search_parameters)

    all_knowledge_bank_parameters = {
        "knowledge_bank_ids": knowledge_bank_ids,
        "knowledge_bank_custom_names": knowledge_bank_custom_names,
        "knowledge_bank_descriptions": knowledge_bank_descriptions,
        "knowledge_bank_weights": knowledge_bank_weights,
        "knowledge_bank_vector_db_types": knowledge_bank_vector_db_types,
        "retrieval_parameters": KnowledgeBankRetrievalParams(**knowledge_bank_search_parameters)
    }

    return all_knowledge_bank_parameters, per_knowledge_bank_parameters


# Formerly named 'get_search_kwargs'
def get_retriever_search_kwargs(
    retrieval_parameters: Dict[str, Any], vector_db_type: str, filters: Optional[Dict[str, List[Any]]] = None
) -> Dict[str, Any]:
    """
    Builds the keyword arguments handed to a vector-store retriever.

    - "k" is included unless the vector store is "AZURE_AI_SEARCH".
    - search_type "mmr" adds "fetch_k" (from "mmr_k") and "lambda_mult" (from "mmr_diversity").
    - search_type "similarity_score_threshold" adds "score_threshold".
    - Non-empty `filters` are merged in last, so they override any key computed above.

    :param retrieval_parameters: dict: Retrieval settings; "search_type" is always required.
    :param vector_db_type: str: Vector store type identifier.
    :param filters: dict, optional: Extra backend-specific kwargs to merge in.

    :returns: dict: The retriever search kwargs.

    :raises KeyError: If a setting required by the chosen branch is missing.
    """
    mode = retrieval_parameters["search_type"]
    kwargs: Dict[str, Any] = {} if vector_db_type == "AZURE_AI_SEARCH" else {"k": retrieval_parameters["k"]}

    if mode == "mmr":
        kwargs["fetch_k"] = retrieval_parameters["mmr_k"]
        kwargs["lambda_mult"] = retrieval_parameters["mmr_diversity"]
    elif mode == "similarity_score_threshold":
        kwargs["score_threshold"] = retrieval_parameters["score_threshold"]

    # Caution: filter key names differ between vector stores
    # (e.g. Azure AI Search expects 'filters' rather than 'filter').
    if filters:
        kwargs.update(filters)
    return kwargs


