import json
from typing import Dict, List

from answers.backend.utils.config_utils import get_retriever_info
from answers.backend.utils.rag_sources import map_rag_sources
from answers.solutions.knowledge_bank import (
    get_core_knowledge_bank,
    get_knowledge_bank_info,
)
from common.backend.models.source import AggregatedToolSources, RAGImage, Source
from common.backend.utils.dataiku_api import dataiku_api
from common.llm_assist.logging import logger
from langchain.schema.document import Document

# DKU_MULTIMODAL_CONTENT is the metadata key under which a document's multimodal
# content description is stored. Import it when the running Dataiku version
# ships it; otherwise bind a fallback so that the name is always defined.
try:
    logger.debug("Importing DKU_MULTIMODAL_CONTENT from dataiku.langchain.metadata_generator")
    from dataiku.langchain.metadata_generator import DKU_MULTIMODAL_CONTENT
except ImportError:
    logger.debug(
        "Importing DKU_MULTIMODAL_CONTENT from dataiku.langchain.metadata_generator failed, falling back to default import"
    )
    # Without this binding the name stays undefined and the first lookup of
    # DKU_MULTIMODAL_CONTENT in this module raises NameError.
    # NOTE(review): the literal is assumed to mirror the constant defined in
    # dataiku.langchain.metadata_generator - confirm against the Dataiku release in use.
    DKU_MULTIMODAL_CONTENT = "DKU_MULTIMODAL_CONTENT"


def generate_sources_from_sample(sample: Dict, tables_used_str: str) -> List[AggregatedToolSources]:
    """
    Wrap a records sample and the tables it was read from into a
    single-element list of AggregatedToolSources.

    An empty or missing sample produces an empty list, without looking
    up the retriever configuration.

    Used in the DBRetrievalChain to generate sources
    """
    # Nothing to report when there is no sample.
    if not sample:
        return []

    retriever_info = get_retriever_info(dataiku_api.webapp_config)
    if retriever_info:
        retriever_name = retriever_info.get("name")
    else:
        retriever_name = ""

    # The whole sample is attached as the records; its "query" entry (if any)
    # is surfaced separately as the generated SQL query.
    records_source = Source(
        type="RECORDS",
        records=sample,  # type: ignore
        generatedSqlQuery=sample.get("query", ""),
        usedTables=tables_used_str,
        tool_name_used=retriever_name,
    )
    aggregated = AggregatedToolSources(
        toolCallDescription=f"Used {retriever_name}",
        items=[records_source],
    )
    return [aggregated]


def _multimodal_file_path(metadata: Dict):
    """
    Return the file path recorded in a document's multimodal metadata.

    The value stored under DKU_MULTIMODAL_CONTENT is parsed as JSON and its
    "content" entry is read. A list yields its first element (or "" when
    empty); any other truthy value is returned unchanged. A missing or empty
    multimodal entry yields "".
    """
    # `or "{}"` also covers a key that is present but holds None / "".
    content_json = json.loads(metadata.get(DKU_MULTIMODAL_CONTENT) or "{}")
    file_path = content_json.get("content") or []  # should be an array of one file_path
    if isinstance(file_path, list):  # just to be sure and robust to changes
        file_path = file_path[0] if file_path else ""
    return file_path


def generate_sources_from_source_documents(source_documents: List[Document]) -> List[AggregatedToolSources]:
    """
    Based on source documents it will build and return a list
    of AggregatedToolSources in the new format.

    Each document becomes one Source. A document whose metadata has a
    non-empty "parts" entry gets an empty text snippet and one RAGImage per
    part; otherwise its page content is used as the text snippet. The
    sources are then passed through map_rag_sources and wrapped in a single
    AggregatedToolSources entry.

    Used in KBRetrievalChain to generate sources
    """
    configuration = get_retriever_info(dataiku_api.webapp_config)
    tool_name = configuration.get("name") if configuration else ""

    kb_id = dataiku_api.webapp_config.get("knowledge_bank_id")
    kb = get_core_knowledge_bank(kb_id)
    kb_info = get_knowledge_bank_info(project=dataiku_api.default_project, knowledge_bank_id=kb_id)
    embedding_recipe_type = kb_info["embedding_recipe_type"]

    # The managed folder id is looked up at most once, and only if some
    # document actually carries image parts.
    # NOTE(review): assumes kb._get() returns the same settings for the
    # duration of this call - confirm.
    managed_folder_id = None
    managed_folder_id_loaded = False

    sources: List[Source] = []
    for document_index, document in enumerate(source_documents):
        source_metadata = document.metadata
        # Work on a safe view so that a None metadata does not crash the lookups.
        metadata = source_metadata or {}
        doc_parts = metadata.get("parts") or []

        images: List[RAGImage] = []
        if doc_parts:
            if not managed_folder_id_loaded:
                managed_folder_id = kb._get().get("managedFolderId")
                managed_folder_id_loaded = True
            # The file path is a property of the document, not of the part,
            # so it is computed once per document.
            file_path = _multimodal_file_path(metadata)
            images = [
                RAGImage(
                    full_folder_id=managed_folder_id,
                    file_path=file_path,
                    index=document_index,
                    file_data=f"data:image/png;base64,{part.inline_image}",
                )
                for part in doc_parts
            ]

        sources.append(
            Source(
                # Documents with image parts carry no text snippet.
                textSnippet="" if doc_parts else document.page_content,
                metadata=source_metadata,
                images=images,
                tool_name_used=tool_name,
            )
        )

    sources = map_rag_sources(sources, embedding_recipe_type)
    formatted_sources = [AggregatedToolSources(toolCallDescription=f"Used {tool_name}", items=sources)]

    return formatted_sources