from typing import Any, List, Union

from common.backend.models.source import AggregatedToolSources, RAGImage, Source
from common.backend.utils.llm_utils import get_llm_friendly_name
from common.backend.utils.picture_utils import b64encode_image_from_full_folder_id
from common.llm_assist.logging import logger
from dataikuapi.dss.llm import DSSLLM, DSSLLMCompletionResponse, DSSLLMStreamedCompletionFooter


def _convert_records_to_rows_and_columns(columns: List[str], data: List[Any]):
    rows = [dict(zip(columns, row)) for row in data]

    columns_formatted = [{"name": col, "label": col, "field": col, "align": "left"} for col in columns]

    return {"rows": rows, "columns": columns_formatted}


def _enrich_metadata_with_tags(metadata: dict) -> dict:
    """
    We don't display metadata directly but we display tags which are
    built into metadata.
    """
    tags = []
    for key, value in metadata.items():
        try:
            value_as_str = str(value)

            tags.append({"name": key, "value": value_as_str, "type": "string"})

        except Exception:
            continue

    metadata["tags"] = tags
    return metadata


def _build_file_based_document_source(source: Source) -> Source:
    """
    Convert a DSS source of type FILE_BASED_DOCUMENT into the source shape
    consumed by the frontend (emitted with type "SIMPLE_DOCUMENT").

    Display rules applied here:
    - when the source carries image references, only the images are returned
      and the three snippets stay empty;
    - otherwise the text / html / markdown snippets are copied over;
    - a missing title is replaced by one built from the file reference
      (path and page range);
    - metadata that is a dict is enriched with display tags.
    """
    text_snippet, html_snippet, markdown_snippet = "", "", ""
    images: List[RAGImage] = []

    image_refs = source.get("imageRefs")
    if image_refs:
        # Images take precedence over snippets: displaying both would be too much.
        for ref in image_refs:
            file_path = ref.get("path") or ""
            folder_id = ref.get("folderId") or ""
            encoded_image = b64encode_image_from_full_folder_id(file_path, folder_id)
            # NOTE(review): every image gets index 0 — confirm this is intended
            # rather than the position of the image in the list.
            images.append(
                RAGImage(
                    full_folder_id=folder_id,
                    file_path=file_path,
                    index=0,
                    file_data=encoded_image,
                )
            )
    else:
        text_snippet = source.get("textSnippet") or ""
        html_snippet = source.get("htmlSnippet") or ""
        markdown_snippet = source.get("markdownSnippet") or ""

    display_title = source.get("title")
    file_ref = source.get("fileRef")
    if isinstance(file_ref, dict):
        pages = file_ref.get("pageRange") or {}
        first_page = pages.get("start") or ""
        last_page = pages.get("end") or ""
        path_in_folder = file_ref.get("path") or ""
        # The title can come from the settings (from a metadata), but not always.
        # NOTE(review): without a page range this yields "<path> (p.  - )" — confirm
        # that rendering is acceptable.
        if not display_title:
            display_title = f"{path_in_folder} (p. {first_page} - {last_page})"

    metadata = source.get("metadata")
    if isinstance(metadata, dict):
        metadata = _enrich_metadata_with_tags(metadata)

    return Source(
        type="SIMPLE_DOCUMENT",
        title=display_title,
        url=source.get("url"),
        textSnippet=text_snippet,
        htmlSnippet=html_snippet,
        markdownSnippet=markdown_snippet,
        images=images,
        metadata=metadata,
    )


def filter_sources_items(sources: List[Source], sources_type: List[str]) -> List[Source]:
    """
    Drop the error items of a tool call that eventually produced records.

    When the agent's SQL tool succeeds in generating records, the errors it
    encountered during its retry mechanism should not be displayed. As soon as
    at least one "RECORDS" entry is present, every "ERROR" entry is removed
    (not only the first one, so several failed retries are all hidden).
    Without a "RECORDS" entry nothing is filtered.

    Both lists are updated in place so that they stay aligned for the caller.

    Args:
        sources (List[Source]): The sources list to filter.
        sources_type (List[str]): A list indicating the type of each source, in matching order with sources.

    Returns:
        List[Source]: The same ``sources`` list object, filtered.
    """
    if "ERROR" not in sources_type or "RECORDS" not in sources_type:
        return sources

    # Filter by position: the type list is the reference, the items themselves are not inspected.
    error_positions = {position for position, source_type in enumerate(sources_type) if source_type == "ERROR"}

    # Slice assignment keeps the caller's list objects (the original also mutated in place).
    sources[:] = [source for position, source in enumerate(sources) if position not in error_positions]
    sources_type[:] = [source_type for source_type in sources_type if source_type != "ERROR"]

    return sources


def build_source(input_source: Source) -> Source:
    """
    Build a source usable by Answers or Agent-connect from a 'raw' source
    (e.g. one coming from an agent).

    Dispatch on the raw source "type":
    - "FILE_BASED_DOCUMENT": delegated to the file-based document builder;
    - "GENERATED_SQL_QUERY": the performed query becomes the text snippet;
    - "RECORDS" (only when a "records" key is present): the records are
      converted to rows and columns;
    - anything else: title, images, snippets, url and tag-enriched metadata
      are copied over; a debug line is logged unless the type is
      "SIMPLE_DOCUMENT".
    """
    source_type = input_source.get("type", "")

    if source_type == "FILE_BASED_DOCUMENT":
        return _build_file_based_document_source(input_source)

    if source_type == "GENERATED_SQL_QUERY":
        return Source(textSnippet=input_source.get("performedQuery"), type=source_type)

    if source_type == "RECORDS" and "records" in input_source:
        raw_records = input_source.get("records") or {}
        column_names = raw_records.get("columns") or []
        row_values = raw_records.get("data") or []
        table = _convert_records_to_rows_and_columns(column_names, row_values)
        return Source(records=table, type=source_type)

    # Fallback: plain documents, plus any type not handled above
    # (including "RECORDS" sources without a "records" key).
    if source_type != "SIMPLE_DOCUMENT":
        logger.debug(f"handling unknown source_type {source_type}")

    return Source(
        title=input_source.get("title"),
        type=source_type,
        images=input_source.get("images"),
        textSnippet=input_source.get("textSnippet"),
        htmlSnippet=input_source.get("htmlSnippet"),
        markdownSnippet=input_source.get("markdownSnippet"),
        url=input_source.get("url", ""),
        metadata=_enrich_metadata_with_tags(input_source.get("metadata") or {}),
    )


def build_augmented_llm_or_agent_sources(
    response: Union[DSSLLMCompletionResponse, DSSLLMStreamedCompletionFooter], llm_used: DSSLLM
) -> List[AggregatedToolSources]:
    """
    Extract and format the sources carried by a completion response.

    For augmented llm and agents with kb tool the sources are not
    stored in the same way so we need to extract them from the completion response:
    they are read from ``additionalInformation.sources`` of the raw payload
    (``response._raw`` for a completion response, ``response.data`` for a
    streamed completion footer).

    Args:
        response: The completion response, or the footer of a streamed completion.
        llm_used: The LLM handle that produced the response; its friendly name is
            used as tool call description when the response does not provide one.

    Returns:
        List[AggregatedToolSources]: One entry per aggregated source found in the
        response, each with its items built and filtered. An empty list when the
        response carries no sources (or is of neither supported type).
    """
    llm_id_used = llm_used.llm_id
    logger.debug(f"Building sources from agent or augmented llm : {llm_id_used}")
    llm_used_name = get_llm_friendly_name(llm_id_used, llm_used.project_key)
    completion_response_source: List[AggregatedToolSources] = []

    # "additionalInformation" may be absent or explicitly null: `or {}` covers both,
    # whereas a `.get(key, {})` default only covers the missing-key case.
    if isinstance(response, DSSLLMCompletionResponse):
        logger.debug("Building source from DSSLLMCompletionResponse")
        additional_information = (response._raw or {}).get("additionalInformation") or {}
        completion_response_source = additional_information.get("sources") or []
    if isinstance(response, DSSLLMStreamedCompletionFooter):
        logger.debug("Building source from DSSLLMStreamedCompletionFooter")
        additional_information = (response.data or {}).get("additionalInformation") or {}
        completion_response_source = additional_information.get("sources") or []

    if not completion_response_source:
        # NOTE(review): `warn` is a deprecated alias on standard-library loggers; kept
        # as-is because the API of the project logger is not visible from here.
        logger.warn("Try to extract sources from augmented llm or agent but no sources found.")
        return []

    aggregated_sources: List[AggregatedToolSources] = []
    for aggregated_source in completion_response_source:
        items: List[Source] = []
        items_type: List[str] = []

        # Keep the raw type of each item alongside the built item: the filter below
        # works on the raw types, in matching order with the items.
        for source in aggregated_source.get("items") or []:
            items_type.append(source.get("type") or "")
            items.append(build_source(source))

        aggregated_sources.append(
            AggregatedToolSources(
                toolCallDescription=aggregated_source.get("toolCallDescription") or f"Used: {llm_used_name}",
                items=filter_sources_items(items, items_type),
            )
        )

    return aggregated_sources
