from copy import deepcopy
from functools import lru_cache
from typing import Any, Dict, List, Set, Union

import dataiku
import pandas as pd
from answers.backend.models.base import (
    DISPLAY_ONLY_MULTIMODAL_METADATAS,
    METADATA_REPLACEMENTS,
    SOURCE_FILE_METADATA,
    EmbeddingRecipeType,
    KnowledgeBankFilterConfig,
    VectorStoreSupportingMetadataFetch,
    VectorStoreType,
)
from answers.solutions.knowledge_bank import get_knowledge_bank_info
from answers.solutions.vector_search.vector_query import VectorSearch
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.metas_utils import convert_to_list, is_string_list_representation, is_valid_enum_value
from common.backend.utils.sql_timing import log_query_time
from common.llm_assist.logging import logger
from dataikuapi.dss.project import DSSProject
from pandas.api.types import infer_dtype, is_string_dtype


def flatten(list_of_lists):
    """Concatenate the items of every sub-iterable into a single flat list.

    Only one level of nesting is removed; items that are themselves
    containers are kept as they are.
    """
    flat_items = []
    for inner in list_of_lists:
        flat_items.extend(inner)
    return flat_items


@log_query_time
def compute_dataset_filter_options(dataset: dataiku.Dataset, columns: List[str]) -> Dict[str, List[Any]]:
    """Collect the distinct non-null values of each requested string column.

    The requested columns are loaded into a dataframe; for every column whose
    non-null values are detected as strings, the unique values are returned
    in order of first appearance.

    Args:
        dataset: Dataset to read the columns from.
        columns: Names of the columns to load.

    Returns:
        A mapping ``column name -> unique values``. Columns that are not of
        string type (a warning is logged for each) and columns without any
        non-null value are left out of the mapping.
    """
    # List flattening is switched off for now: values are exposed exactly as configured.
    MANAGE_LISTS = False
    frame: pd.DataFrame = dataset.get_dataframe(
        columns=columns,
        parse_dates=False,
        infer_with_pandas=False,
        use_nullable_integers=True,
    )
    options: Dict[str, List[Any]] = {}
    for name in frame.columns:
        present_values = frame[name].dropna()
        if not is_string_dtype(present_values):
            logger.warn(
                f"The column '{name}' is not of type 'string' "
                f"(detected type is '{infer_dtype(present_values)}'). It can't be used to filter metadatas."
            )
            continue

        if MANAGE_LISTS:
            # Dormant path: parse string representations of lists, then flatten list cells.
            if frame[name].apply(lambda x: isinstance(x, str) and is_string_list_representation(x)).any():
                frame[name] = frame[name].apply(convert_to_list)
                present_values = frame[name].dropna()
            if frame[name].apply(lambda x: isinstance(x, list)).any():
                present_values = pd.Series([item for sublist in present_values.tolist() for item in sublist])

        distinct_values = pd.unique(present_values).tolist()
        if distinct_values:
            options[name] = distinct_values
    return options


@log_query_time
def compute_knowledge_bank_filter_options(project: DSSProject, knowledge_bank_id: str, vector_store_type: str, metadata_list: List[str]) -> Dict[str, List[Any]]:
    """Fetch, from the vector store itself, the values indexed for each metadata.

    Args:
        project: Project that owns the knowledge bank.
        knowledge_bank_id: Identifier of the knowledge bank to inspect.
        vector_store_type: Vector store type value; only the types listed in
            ``VectorStoreSupportingMetadataFetch`` are inspected.
        metadata_list: Metadata names selected by the user.

    Returns:
        On success, a mapping ``stored metadata name -> list of distinct truthy
        values`` (names are translated through ``METADATA_REPLACEMENTS``).
        When the vector store type is not supported, or when any exception is
        raised while fetching, a mapping ``user-selected name -> []`` is
        returned instead (note that the keys are then NOT translated).
    """
    EMPTY_OPTIONS: Dict[str, List[Any]] = {metadata: [] for metadata in metadata_list}
    if not is_valid_enum_value(vector_store_type, VectorStoreSupportingMetadataFetch):
        return EMPTY_OPTIONS

    # 'metadata_list' holds what the user selected and may contain UX-friendly labels
    # (e.g. 'source_file', stored as 'dku_file_path'); `METADATA_REPLACEMENTS` maps such
    # labels to the names actually stored in the vector store.
    stored_names = [METADATA_REPLACEMENTS.get(label) or label for label in metadata_list]

    try:
        vector_store = project.get_knowledge_bank(knowledge_bank_id).as_core_knowledge_bank().as_langchain_vectorstore()
        collected: Dict[str, Set[Any]] = {name: set() for name in stored_names}

        # Pick the per-document metadata records according to the vector store type.
        if vector_store_type == VectorStoreType.AZURE_AI_SEARCH.value:
            records = vector_store.client.search(
                search_text="*",
                select=stored_names,
                include_total_count=True,
            )
        elif vector_store_type == VectorStoreType.CHROMA.value:
            records = vector_store.get(include=["metadatas"])["metadatas"]
        else:
            records = []

        # Falsy values (None, empty string, 0, ...) are not kept as options.
        for record in records:
            for name in stored_names:
                if value := record.get(name):
                    collected[name].add(value)

    except Exception as e:
        logger.debug(f"Exception occured when fetching the metadata values (metadatas={metadata_list},  knowledge_bank_id='{knowledge_bank_id}') : '{e}'")
        return EMPTY_OPTIONS

    return {name: list(values) for name, values in collected.items()}


@log_query_time
@lru_cache(maxsize=None)
def get_knowledge_bank_filtering_settings(knowledge_bank_id: str, with_options: bool = True):
    """Build the metadata filtering configuration of a knowledge bank.

    Results are memoized by ``lru_cache`` per ``(knowledge_bank_id, with_options)``,
    so repeated calls with the same arguments return the same object.

    Args:
        knowledge_bank_id: Identifier of the knowledge bank.
        with_options: When True, only the metadata listed under the webapp
            setting ``knowledge_sources_filters`` are kept as filters and their
            selectable values are computed. When False, every candidate
            metadata is kept as a filter and no values are computed.

    Returns:
        A ``KnowledgeBankFilterConfig`` describing the informational metadata,
        the filterable metadata and their options.

    Raises:
        Exception: If the embedding recipe type is neither
            ``NLP_LLM_RAG_EMBEDDING`` nor ``EMBED_DOCUMENTS``.
    """
    project = dataiku_api.default_project
    config = dataiku_api.webapp_config

    # Knowledge Bank information:
    kb_info = get_knowledge_bank_info(project, knowledge_bank_id)
    embedding_recipe_type = kb_info["embedding_recipe_type"]
    vector_store_type = kb_info["vector_store_type"]
    embedding_recipe_name = kb_info["embedding_recipe_name"]
    metadata_dataset_id = kb_info["metadata_dataset_id"]
    documents_folder_id = kb_info["documents_folder_id"]

    is_text_embedding = embedding_recipe_type == EmbeddingRecipeType.NLP_LLM_RAG_EMBEDDING.value
    is_document_embedding = embedding_recipe_type == EmbeddingRecipeType.EMBED_DOCUMENTS.value
    if not (is_text_embedding or is_document_embedding):
        raise Exception(f"The embedding recipe '{embedding_recipe_name}' is of type '{embedding_recipe_type}' (Not implemented)!")

    recipe_json_payload = project.get_recipe(embedding_recipe_name).get_settings().get_json_payload()

    # Metadata potentially available to apply filters.
    candidate_filter_metadata: List[str] = []
    # Metadata available for information purposes only (to include in the LLM context
    # or to display along with sources).
    info_metadata: List[str] = []

    if is_text_embedding:
        payload_metadata_key = "metadataColumns"
    else:
        info_metadata = deepcopy(DISPLAY_ONLY_MULTIMODAL_METADATAS)
        if is_valid_enum_value(vector_store_type, VectorStoreSupportingMetadataFetch):
            # 'source_file' (in the end, 'dku_file_path') is offered only when the vector
            # store has methods to fetch all the indexed metadatas.
            candidate_filter_metadata.append(SOURCE_FILE_METADATA)
        payload_metadata_key = "userDefinedMetadataColumns"

    candidate_filter_metadata.extend(
        column_info["column"] for column_info in recipe_json_payload.get(payload_metadata_key, [])
    )

    # Every candidate is informational; only some of them end up as actual filters.
    info_metadata.extend(candidate_filter_metadata)

    # Final list of metadata available to apply filters.
    filter_metadata: List[str] = []
    if not with_options:
        filter_metadata = list(candidate_filter_metadata)
    elif candidate_filter_metadata:
        configured_filters = list(config.get("knowledge_sources_filters", []))
        filter_metadata = [
            metadata_name
            for metadata_name in candidate_filter_metadata
            if metadata_name in configured_filters
            and (is_text_embedding or metadata_name not in DISPLAY_ONLY_MULTIMODAL_METADATAS)
        ]

    filter_options: Dict[str, List[Any]] = {}
    if with_options and filter_metadata:
        if is_text_embedding and metadata_dataset_id:
            # Values come from the dataset feeding the embedding recipe.
            filter_options = compute_dataset_filter_options(dataiku.Dataset(metadata_dataset_id), columns=filter_metadata)
        elif is_document_embedding:
            # Values are read back from the vector store.
            filter_options = compute_knowledge_bank_filter_options(project, knowledge_bank_id, vector_store_type, filter_metadata)

    return KnowledgeBankFilterConfig(
        metadata_dataset_id=metadata_dataset_id,
        embedding_recipe_type=embedding_recipe_type,
        documents_folder_id=documents_folder_id,
        info_metadata=info_metadata,
        filter_metadata=filter_metadata,
        filter_options=filter_options,
        vector_db_type=vector_store_type,
    )


def filters_to_dsl(filters: Dict[str, List[Any]]) -> Dict[str, Any]:
    """Translate ``{column: accepted values}`` filters into the filter DSL.

    A single filter yields ``{"filter": {column: {"in": values}}}``; several
    filters are combined as ``{"filter": {"and": [<one clause per column>]}}``.

    Raises:
        ValueError: If ``filters`` is empty.
    """
    filter_count = len(filters)
    if filter_count < 1:
        raise ValueError("No filters provided")

    clauses = [{column: {"in": accepted}} for column, accepted in filters.items()]
    # A lone clause is used directly, without the "and" wrapper.
    body = clauses[0] if filter_count == 1 else {"and": clauses}
    return {"filter": body}


def process_filters_for_db(filters: Union[Dict[str, List[Any]], None], vector_db_type: str):
    """Convert filters into the form expected by the given vector store type.

    Args:
        filters: Mapping ``column -> accepted values``, or None when no
            filtering is requested.
        vector_db_type: Vector store type value; must be a valid
            ``VectorStoreType`` value.

    Returns:
        None when ``filters`` is None; otherwise whatever the vector store
        object obtained from ``VectorSearch`` returns when invoked with the
        DSL form of the filters.

    Raises:
        ValueError: If ``vector_db_type`` is not a valid ``VectorStoreType``
            value, or (from ``filters_to_dsl``) if ``filters`` is empty.
    """
    if filters is None:
        return None

    logger.debug(f"processing filters for {vector_db_type}")
    type_is_known = is_valid_enum_value(vector_db_type, VectorStoreType)
    if not type_is_known:
        raise ValueError(f"Vector DB type {vector_db_type} is unknown")

    # The store is resolved before the DSL is built, then invoked with it.
    vector_store = VectorSearch(vector_store_type=vector_db_type).get_vector_store()
    return vector_store.invoke(filters_to_dsl(filters))