import json
import pathlib
from copy import deepcopy
from typing import Optional

from answers.solutions.knowledge_bank import EmbeddingRecipeType
from common.backend.models.source import Source
from common.backend.services.sources.sources_formatter import sources_are_old_format
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.metas_utils import convert_to_list, is_string_list_representation
from common.llm_assist.logging import logger

# TODO: consider moving these helpers to the sources folder.

def _load_document_info(source_metadata: dict) -> dict:
    """Return the decoded 'DKU_DOCUMENT_INFO' entry of an 'Embed documents' source.

    The entry is expected to be a JSON-encoded string; an already-decoded dict is
    accepted as-is. When the entry is missing, empty, not valid JSON or does not
    decode to a dict, a warning is logged and an empty dict is returned, so one
    malformed source does not abort the mapping of all the others.
    """
    raw_document_info = source_metadata.get("DKU_DOCUMENT_INFO")
    if isinstance(raw_document_info, dict):
        document_info = raw_document_info
    elif isinstance(raw_document_info, (str, bytes, bytearray)) and raw_document_info:
        try:
            document_info = json.loads(raw_document_info)
        except ValueError:  # json.JSONDecodeError / UnicodeDecodeError both subclass ValueError
            logger.warn("The key 'DKU_DOCUMENT_INFO' of 'source_metadata' is not valid JSON")
            return {}
    else:
        document_info = None

    if not document_info:
        logger.warn("The key 'DKU_DOCUMENT_INFO' was not found in 'source_metadata'")
        return {}
    if not isinstance(document_info, dict):
        logger.warn("The key 'DKU_DOCUMENT_INFO' of 'source_metadata' is not a JSON object")
        return {}
    return document_info


def _format_page_range(page_range) -> Optional[str]:
    """Format a {'start': ..., 'end': ...} page range as 'p. N' or 'pp. N-M'.

    Returns None when the range is absent, not a dict, or lacks either bound,
    so that no meaningless tag is displayed.
    """
    if not isinstance(page_range, dict):
        return None
    page_start = page_range.get("start")
    page_end = page_range.get("end")
    if page_start is None or page_end is None:
        return None
    if page_start == page_end:
        return f"p. {page_start}"
    return f"pp. {page_start}-{page_end}"


def _extract_source_filename(document_info: dict) -> Optional[str]:
    """Return the file name (last path component) of document_info['source_file']['path'], or None."""
    source_file = document_info.get("source_file") or {}
    path = source_file.get("path") if isinstance(source_file, dict) else None
    if path is None:
        return None
    return pathlib.Path(path).name


def _get_displayed_meta_value(meta: str, source_metadata: dict, embedding_recipe_type):
    """Return the raw value to display for `meta`, or None when there is nothing to display."""
    if embedding_recipe_type == EmbeddingRecipeType.NLP_LLM_RAG_EMBEDDING:
        return source_metadata.get(meta, None)
    if embedding_recipe_type == EmbeddingRecipeType.EMBED_DOCUMENTS:
        if meta == "source_pages":
            # The 'page_range' key holds the pages metadata of the chunk.
            return _format_page_range(source_metadata.get("page_range", None))
        if meta == "source_file":
            # The file name is computed from the file 'path'.
            return _extract_source_filename(source_metadata)
        logger.warn(f"Displaying the meta '{meta}' is not implemented")
        return None
    # Unknown recipe type: nothing can be displayed (previously left `value` unbound).
    return None


def map_rag_sources(sources, embedding_recipe_type):
    """Convert raw RAG sources into `Source` objects ready to be displayed.

    Title, URL and thumbnail are read from the source-metadata keys configured in
    the webapp settings ('knowledge_source_title', 'knowledge_source_url',
    'knowledge_source_thumbnail'); a missing key yields an empty string. Each meta
    listed in 'knowledge_sources_displayed_metas' that yields a value is appended
    to `metadata["tags"]` as {"name", "value", "type"}, where "type" is "list" when
    the value is a string representation of a list (then converted) and "string"
    otherwise.

    :param sources: iterable of dicts; each must provide "textSnippet", "images"
        and "tool_name_used", and may provide "metadata".
    :param embedding_recipe_type: EmbeddingRecipeType. For EMBED_DOCUMENTS the
        metadata are read from the JSON-encoded "DKU_DOCUMENT_INFO" entry, and only
        the "source_pages" and "source_file" metas are supported.
    :returns: list of `Source`, one per input source, in input order.
    """
    webapp_config = dataiku_api.webapp_config
    knowledge_source_title = webapp_config.get("knowledge_source_title", "")
    knowledge_source_url = webapp_config.get("knowledge_source_url", "")
    knowledge_source_thumbnail = webapp_config.get("knowledge_source_thumbnail", "")
    # Guard against an explicit null in the webapp settings.
    knowledge_sources_displayed_metas = webapp_config.get("knowledge_sources_displayed_metas", []) or []

    new_sources = []
    for source in sources:
        # `or {}` also covers a "metadata" key explicitly set to None.
        source_metadata = source.get("metadata") or {}
        if embedding_recipe_type == EmbeddingRecipeType.EMBED_DOCUMENTS:
            source_metadata = _load_document_info(source_metadata)

        source_title = source_metadata.get(knowledge_source_title, "")
        source_url = source_metadata.get(knowledge_source_url, "")
        source_thumbnail_url = source_metadata.get(knowledge_source_thumbnail, "")

        tags: list = []
        for meta in knowledge_sources_displayed_metas:
            value = _get_displayed_meta_value(meta, source_metadata, embedding_recipe_type)
            if value is None:
                continue
            if is_string_list_representation(value):
                value = convert_to_list(value)
                value_type = "list"
            else:
                value_type = "string"
            tags.append({"name": meta, "value": value, "type": value_type})

        new_sources.append(
            Source(
                type="SIMPLE_DOCUMENT",
                textSnippet=source["textSnippet"],
                metadata={
                    "source_title": source_title,
                    "source_url": source_url,
                    "source_thumbnail_url": source_thumbnail_url,
                    "tags": tags,
                },
                title=source_title,
                url=source_url,
                images=source["images"],
                tool_name_used=source["tool_name_used"],
            )
        )

    return new_sources


def _count_chat_logs_sources(sources: list, old_format: bool) -> int:
    """Count individual sources: one per entry in the old format, one per 'items' element otherwise."""
    if old_format:
        return len(sources)
    return sum(len(aggregated_tool_source.get("items", [])) for aggregated_tool_source in sources)


def filter_chat_logs_rag_sources(
    initial_sources: list, n_sources_to_keep: int = -1, varchar_limit: Optional[int] = None
):
    """Filter the sources written to the chat logs dataset, to avoid database varchar limit issues and/or save storage.

    Two filters are applied in order:
      1. Truncation to `n_sources_to_keep`. In the old (flat) format this caps the
         total number of sources; in the aggregated format it caps the 'items' of
         each aggregated tool entry separately (the input is deep-copied first).
      2. When `varchar_limit` is an int and the serialized sources column reaches
         it, top-level entries are kept greedily, in order, while the serialized
         size stays strictly below the limit. In the aggregated format this drops
         whole tool entries.

    :param initial_sources: list: The initial sources (None is treated as an empty list).
    :param n_sources_to_keep: int: The number of sources to keep. Set '-1' to keep all the sources.
    :param varchar_limit: int: The varchar limit of the connection used in the context of the webapp.

    :returns: final_sources: list: The sources kept from the initial_sources.
    :returns: sources_has_been_filtered: bool: Whether any source was removed by either filter.
    """

    def get_sources_column_size_in_chat_logs_dataset(sources_to_format):
        return len(str({"sources": sources_to_format}))

    if initial_sources is None:
        initial_sources = []
    # For an empty list both format branches give identical results, so the format
    # check is only needed (and only performed) when there is something to inspect.
    old_format = sources_are_old_format(initial_sources) if initial_sources else True
    n_initial_sources = _count_chat_logs_sources(initial_sources, old_format)

    if n_sources_to_keep == -1:
        final_sources = initial_sources
    else:
        # Any other negative value keeps nothing, as before.
        n_to_keep = max(n_sources_to_keep, 0)
        if old_format:
            final_sources = list(initial_sources[:n_to_keep])
        else:
            # Deep copy so that truncating 'items' never mutates the caller's data.
            final_sources = deepcopy(initial_sources)
            for aggregated_tool_source in final_sources:
                aggregated_tool_source["items"] = list(aggregated_tool_source.get("items", []))[:n_to_keep]

    if isinstance(varchar_limit, int):
        if get_sources_column_size_in_chat_logs_dataset(final_sources) >= varchar_limit:
            final_sources_respecting_characters_limit: list = []
            for source in final_sources:
                candidate_size = get_sources_column_size_in_chat_logs_dataset(
                    final_sources_respecting_characters_limit + [source]
                )
                if candidate_size < varchar_limit:
                    final_sources_respecting_characters_limit.append(source)
            final_sources = final_sources_respecting_characters_limit

    # Count after BOTH filters so that sources dropped by the varchar limit are reported.
    n_final_sources = _count_chat_logs_sources(final_sources, old_format)
    n_sources_filtered = n_initial_sources - n_final_sources
    sources_has_been_filtered = n_sources_filtered > 0
    if sources_has_been_filtered:
        logger.debug(f"{n_sources_filtered}/{n_initial_sources} sources has been filtered for the chat logs dataset.")
    else:
        logger.debug("No sources has been filtered for the chat logs dataset.")

    return final_sources, sources_has_been_filtered
