from typing import Any, Dict, List, Optional, Union, cast

from common.backend.models.base import RetrieverMode
from common.backend.models.source import AggregatedToolSources, Source
from common.backend.schemas.common_schemas import AggregatedToolSourcesSchema
from common.backend.utils.rag_utils import (
    rm_image_data_from_sources_to_store,
)
from common.llm_assist.logging import logger


def sources_are_old_format(sources: Union[List[Dict], List[Source], List[AggregatedToolSources]]) -> bool:
    """
    Tell whether a source list is still in the old (flat) format.

    An entry is considered to be in the new aggregated format as soon as it
    carries a "toolCallDescription" or an "items" key.

    Args:
        sources: the list of sources to inspect.

    Returns:
        False if at least one entry holds one of the aggregated-format keys,
        True otherwise (including when the list is empty).
    """
    new_format_keys = ("toolCallDescription", "items")
    for source in sources:
        # Empty or None entries carry no format information, and a membership
        # test on None would raise TypeError, so they are skipped.
        if not source:
            continue
        if any(key in source for key in new_format_keys):
            return False
    return True


def format_sources_if_required(sources: Union[List[Dict], List[Source], List[AggregatedToolSources]]) -> List[AggregatedToolSources]:
    """
    Return the sources in the agents-connect (aggregated) format.

    A non-empty list detected as old format is wrapped through
    format_sources; anything else (already aggregated, or empty) is handed
    back untouched.
    """
    needs_wrapping = bool(sources) and sources_are_old_format(sources)
    if needs_wrapping:
        # cast only informs the type checker, it is a no-op at runtime
        return format_sources(cast(List[Source], sources))

    # cast only informs the type checker, it is a no-op at runtime
    return cast(List[AggregatedToolSources], sources)


def get_retrieval_mode(source: Source) -> RetrieverMode:
    """
    Resolve the retriever mode of a source, whichever format it is in.

    Old-format sources are recognised by the presence of a "sample" key (DB)
    or an "excerpt" key (KB). New-format sources are recognised by their
    "type" value: "RECORDS" (DB) or "SIMPLE_DOCUMENT" (KB). Anything else is
    logged and reported as NO_RETRIEVER.
    """
    # old format: the mode is given by which marker key is present
    present_keys = source.keys()
    for marker_key, mode in (("sample", RetrieverMode.DB), ("excerpt", RetrieverMode.KB)):
        if marker_key in present_keys:
            return mode

    # new format: the mode is given by the value stored under "type"
    source_type = source.get("type", "")
    for type_name, mode in (("RECORDS", RetrieverMode.DB), ("SIMPLE_DOCUMENT", RetrieverMode.KB)):
        if source_type == type_name:
            return mode

    # NOTE(review): if logger is a stdlib logging.Logger, warn() is a
    # deprecated alias of warning() - confirm before switching.
    logger.warn(f"source type : {source_type} not supported.")
    return RetrieverMode.NO_RETRIEVER


def format_sources(sources: List[Source]) -> List[AggregatedToolSources]:
    """
    Wrap a flat list of sources into a single aggregated entry following the
    agents-connect layout.

    The retrieval mode is resolved from the first non-empty source and then
    reused for the remaining ones. Empty sources are skipped. Sources whose
    mode is neither KB nor DB produce no item.
    """
    if not sources:
        return []

    # tool_type and tool_name are expected to be identical for every source
    # of the list, so the first usable value is kept
    collected = []
    tool_type: Optional[str] = ""
    tool_name: Optional[str] = ""
    mode = None

    for source in sources:
        if not source:
            continue

        if not mode:
            mode = get_retrieval_mode(source)

        if mode == RetrieverMode.KB:
            tool_type = tool_type or "knowledge bank"
            tool_name = tool_name or source.get("tool_name_used", "")

            # metadata may be None (or any non-dict): fall back to an empty dict
            raw_metadata = source.get("metadata", {})
            metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
            collected.append({
                "type": "SIMPLE_DOCUMENT",
                "metadata": metadata,
                "title": metadata.get("source_title", ""),  # type: ignore
                "url": metadata.get("source_url", ""),  # type: ignore
                "textSnippet": source.get("excerpt"),
                "images": source.get("images"),
                "tool_name_used": source.get("tool_name_used", ""),
            })
        elif mode == RetrieverMode.DB:
            tool_type = tool_type or "database"
            tool_name = tool_name or source.get("tool_name_used", "")

            metadata = source.get("metadata", {})
            sample = source.get("sample", {})
            if not (sample or metadata):
                # nothing to reshape: the source is kept as it is
                collected.append(source)  # type: ignore
                continue

            if sample:
                generated_query = sample.get("query", "")
            else:
                generated_query = ""
            if metadata:
                used_tables = metadata.get("source_title", "")
            else:
                used_tables = ""
            collected.append({
                "type": "RECORDS",
                "records": sample,
                "generatedSqlQuery": generated_query,
                "usedTables": used_tables,
                "tool_name_used": source.get("tool_name_used", ""),
            })

    # tool_name from sources is generated in the kb and db retrieval chains
    # through a call to get_retriever_info function
    aggregated = {
        "toolCallDescription": f"Used {tool_type} {tool_name}",
        "items": collected,  # type: ignore
    }
    return [aggregated]


def serialize_aggregated_sources_for_api(aggregated_sources: List[AggregatedToolSources]) -> List[Dict[str, Any]]:
    """
    Dump every aggregated source through AggregatedToolSourcesSchema.

    Returns an empty list when there is nothing to serialize; otherwise one
    dumped entry per aggregated source, in the same order.
    """
    if not aggregated_sources:
        return []

    # a single schema instance is shared by all the entries
    dump_entry = AggregatedToolSourcesSchema().dump
    return [dump_entry(entry) for entry in aggregated_sources]


def format_sources_to_store(aggregated_sources: List[AggregatedToolSources]) -> List[AggregatedToolSources]:
    """
    Build the version of the aggregated sources that is meant to be stored.

    For now the only transformation is to pass the items of each entry
    through rm_image_data_from_sources_to_store (removing image data while
    keeping image infos); "toolCallDescription" is copied over, defaulting to
    an empty string.

    Args:
        aggregated_sources: the aggregated sources to prepare; may be empty
            or None.

    Returns:
        A new list with one AggregatedToolSources per input entry, or an
        empty list when there is nothing to store.
    """
    # Same convention as the other helpers of this module: no sources (None
    # or empty) yields an empty list instead of failing on iteration.
    if not aggregated_sources:
        return []

    return [
        AggregatedToolSources(
            toolCallDescription=aggregated_source.get("toolCallDescription", ""),
            # a missing or None "items" is treated as an empty list
            items=rm_image_data_from_sources_to_store(aggregated_source.get("items") or []),
        )
        for aggregated_source in aggregated_sources
    ]