from functools import lru_cache
from typing import Any, Dict, List, Optional, TypedDict, Union

import dataiku
import pandas as pd
from answers.solutions.knowledge_bank import EmbeddingRecipeType
from answers.solutions.vector_search.vector_query import VectorSearch, VectorStoreTypes
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.dataiku_utils import find_recipe
from common.backend.utils.metas_utils import convert_to_list, is_string_list_representation
from common.backend.utils.sql_timing import log_query_time
from common.llm_assist.logging import logger
from pandas.api.types import infer_dtype, is_string_dtype


class KnowledgeBankFilterConfig(TypedDict):
    """Metadata-filtering description of a knowledge bank.

    Instances are built by ``get_knowledge_bank_filtering_settings``.
    """

    # Id of the first predecessor of the knowledge bank's embedding recipe
    # (read as a dataset for text embedding recipes).
    input_datasource_id: str
    # ``subType`` of the embedding recipe, compared against ``EmbeddingRecipeType``.
    embedding_recipe_type: str
    # Metadata names exposed as filters.
    filter_metadata: List[str]
    # Distinct selectable values per metadata name; may be empty when not computed.
    filter_options: Dict[str, List[Any]]
    # Vector store type of the knowledge bank, or None when it was not looked up.
    vector_db_type: Optional[str]


def flatten(list_of_lists):
    """Concatenate the items of every inner iterable into a single flat list.

    Only one level of nesting is removed; inner items are kept as-is.
    """
    flat = []
    for inner in list_of_lists:
        flat.extend(inner)
    return flat


# Flattening of list-valued metadata is switched off for now: filter options
# show the column values exactly as they were configured.
manage_lists: bool = False


@log_query_time
def compute_dataset_filter_options(dataset: dataiku.Dataset, columns: List[str]) -> Dict[str, List[Any]]:
    """Collect the distinct non-null values of each metadata column usable as a filter.

    Args:
        dataset: Dataset holding the rows the knowledge bank was built from.
        columns: Names of the metadata columns to read.

    Returns:
        Mapping from column name to its distinct non-null values, in order of
        first appearance. Columns whose non-null values are not detected as a
        string dtype are skipped with a warning. Columns without any value are
        omitted.
    """
    df: pd.DataFrame = dataset.get_dataframe(
        columns=columns, parse_dates=False, infer_with_pandas=False, use_nullable_integers=True
    )
    result: Dict[str, List[Any]] = {}
    for column in df.columns:
        values = df[column].dropna()
        if not is_string_dtype(values):
            # Only string-typed metadata can be offered as filter choices.
            logger.warning(
                f"The column '{column}' is not of type 'string' (detected type is '{infer_dtype(values)}'). "
                "It can't be used to filter metadatas."
            )
            continue
        if manage_lists:
            # Values stored as string representations of lists are parsed first.
            if values.apply(lambda x: isinstance(x, str) and is_string_list_representation(x)).any():
                values = values.apply(convert_to_list).dropna()
            if values.apply(lambda x: isinstance(x, list)).any():
                # Wrap scalar values so that plain strings mixed with lists are
                # kept whole instead of being split character by character.
                flat_values = flatten(v if isinstance(v, list) else [v] for v in values)
                unique_values = pd.unique(pd.Series(flat_values)).tolist()
            else:
                unique_values = pd.unique(values).tolist()
        else:
            unique_values = pd.unique(values).tolist()
        if unique_values:
            result[column] = unique_values
    return result


@log_query_time
@lru_cache(maxsize=None)
def get_knowledge_bank_filtering_settings(knowledge_bank_id: str, with_options: bool = True):
    """Describe which metadata of a knowledge bank can be used to filter retrieval.

    The embedding recipe producing the knowledge bank is located in the project
    flow and inspected to list the metadata it stores.

    Args:
        knowledge_bank_id: Id of the knowledge bank in the default project.
        with_options: When True, keep only the metadata listed in the webapp
            setting ``knowledge_sources_filters`` and compute their selectable
            values. When False, keep every metadata of the recipe, skip the
            value computation and look up the knowledge bank's vector store type.

    Returns:
        A ``KnowledgeBankFilterConfig``. ``vector_db_type`` is only filled when
        ``with_options`` is False. ``filter_options`` is only filled when
        ``with_options`` is True and the recipe is a text embedding recipe.

    Raises:
        KeyError: If ``find_recipe`` cannot locate the knowledge bank in the flow.
        Exception: If the embedding recipe type is not supported.

    NOTE: results are memoized by ``lru_cache`` for the process lifetime.
    Repeated calls with the same arguments return the very same dict, so
    callers must not mutate it. Later changes to the webapp config are not
    reflected in already-cached entries.
    """
    DISPLAYED_MULTIMODAL_METADATAS = ["source_file", "source_pages"]
    project = dataiku_api.default_project
    config = dataiku_api.webapp_config

    vector_db_type: Optional[str] = None
    if not with_options and knowledge_bank_id:
        # The vector store type is only resolved in the with_options=False mode.
        vector_db_type = (
            project.get_knowledge_bank(knowledge_bank_id).as_core_knowledge_bank()._get()["vectorStoreType"]
        )

    graph = project.get_flow().get_graph()
    try:
        recipe_json = find_recipe(graph.data, knowledge_bank_id)
    except KeyError as err:
        raise KeyError(knowledge_bank_id) from err

    input_datasource_id: str = recipe_json.get("predecessors")[0]
    recipe_name = recipe_json.get("ref")
    embedding_recipe_type = recipe_json.get("subType")
    recipe_json_payload = project.get_recipe(recipe_name).get_settings().get_json_payload()

    # Metadata stored by the knowledge bank, depending on the embedding recipe type.
    metadata_info: List[str] = []
    if embedding_recipe_type == EmbeddingRecipeType.NLP_LLM_RAG_EMBEDDING:
        metadata_columns_info = recipe_json_payload.get("metadataColumns", [])
        if metadata_columns_info:
            metadata_info = [columns_info["column"] for columns_info in metadata_columns_info]
    elif embedding_recipe_type == EmbeddingRecipeType.EMBED_DOCUMENTS:
        metadata_info = DISPLAYED_MULTIMODAL_METADATAS
    else:
        raise Exception(f"The embedding recipe '{recipe_name}' is of type '{embedding_recipe_type}' (Not implemented)!")

    if with_options:
        # Read the admin selection once; tolerate a missing or null setting.
        selected_filters = list(config.get("knowledge_sources_filters") or [])
        filter_metadata = [name for name in metadata_info if name in selected_filters]
    else:
        filter_metadata = list(metadata_info)

    filter_options: Dict[str, List[Any]] = {}
    if filter_metadata and with_options:
        if embedding_recipe_type == EmbeddingRecipeType.NLP_LLM_RAG_EMBEDDING:
            dataset = dataiku.Dataset(input_datasource_id, project_key=project.project_key)
            filter_options = compute_dataset_filter_options(dataset, columns=filter_metadata)
        elif embedding_recipe_type == EmbeddingRecipeType.EMBED_DOCUMENTS:
            # TODO: compute options from the input folder (e.g. a
            # compute_folder_filter_options helper) once multimodal knowledge
            # banks expose metadata.
            pass

    return KnowledgeBankFilterConfig(
        input_datasource_id=input_datasource_id,
        embedding_recipe_type=embedding_recipe_type,
        filter_metadata=filter_metadata,
        filter_options=filter_options,
        vector_db_type=vector_db_type,
    )


def get_current_filter_config():
    """Return the filter settings of the webapp's configured knowledge bank.

    Returns None when no ``knowledge_bank_id`` is configured, or when resolving
    the settings raises ``KeyError``.
    """
    # TODO: Adapt when several knowledge banks will be connected
    kb_id = dataiku_api.webapp_config.get("knowledge_bank_id", None)
    if not kb_id:
        return None
    try:
        return get_knowledge_bank_filtering_settings(kb_id)
    except KeyError:
        return None


def filters_to_dsl(filters: Dict[str, List[Any]]) -> Dict[str, Any]:
    """Build the filter DSL for a mapping of column -> accepted values.

    A single column produces ``{"filter": {col: {"in": values}}}``. Several
    columns are combined under an ``"and"`` list, in the mapping's order.

    Raises:
        ValueError: If ``filters`` is empty.
    """
    if not filters:
        raise ValueError("No filters provided")
    clauses = [{col: {"in": values}} for col, values in filters.items()]
    if len(clauses) == 1:
        return {"filter": clauses[0]}
    return {"filter": {"and": clauses}}


def process_filters_for_db(filters: Union[Dict[str, List[Any]], None], vector_db_type):
    """Translate metadata filters into the query filter of a vector store.

    Args:
        filters: Mapping from metadata column to accepted values, or None for
            no filtering.
        vector_db_type: Vector store type name; must be a public attribute
            name of ``VectorStoreTypes``.

    Returns:
        None when ``filters`` is None. Otherwise, the result of the vector
        store's ``invoke`` on the DSL built by ``filters_to_dsl``.

    Raises:
        ValueError: If ``vector_db_type`` is not a known vector store type, or
            (from ``filters_to_dsl``) if ``filters`` is empty.
    """
    if filters is None:
        return None
    logger.debug(f"processing filters for {vector_db_type}")
    # dir() also lists dunder/private attributes (e.g. "__module__"); only
    # public attribute names are actual vector store types.
    known_types = [name for name in dir(VectorStoreTypes) if not name.startswith("_")]
    if vector_db_type not in known_types:
        raise ValueError(f"Vector DB type {vector_db_type} is unknown")
    vector_store = VectorSearch(vector_store_type=vector_db_type).get_vector_store()
    return vector_store.invoke(filters_to_dsl(filters))
