import json
from typing import Dict, List

from answers.backend.models.base import VectorstoreDocument
from answers.backend.utils.config_utils import get_retriever_info
from answers.backend.utils.rag_sources import map_rag_sources
from answers.solutions.knowledge_bank import (
    get_core_knowledge_bank,
    get_knowledge_bank_info,
)
from common.backend.constants import KEYS_TO_REMOVE_FROM_LOGS
from common.backend.models.source import AggregatedToolSources, RAGImage, Source
from common.backend.services.sources.sources_builder import build_dss_filepath
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.json_utils import mask_keys_in_json
from common.llm_assist.logging import logger

# Hard-coded instead of imported from `dataiku.langchain.metadata_generator`:
# importing that module pulls in heavy langchain dependencies at backend start.
# NOTE(review): the value is assumed to match the constant defined in
# `dataiku.langchain.metadata_generator` — keep both in sync if it ever changes.
DKU_MULTIMODAL_CONTENT = "DKU_MULTIMODAL_CONTENT"
logger.debug("`DKU_MULTIMODAL_CONTENT` hard coded value is used")


def generate_sources_from_sample(sample: Dict, tables_used_str: str) -> List[AggregatedToolSources]:
    """
    Build the sources for a database-retrieval answer, in the new
    AggregatedToolSources format.

    Used in the DBRetrievalChain to generate sources.

    Args:
        sample: The retrieved records. Its ``"query"`` entry, when present,
            is exposed as ``generatedSqlQuery``.
        tables_used_str: Description of the tables used, forwarded as
            ``usedTables``.

    Returns:
        An empty list when ``sample`` is empty; otherwise a single
        AggregatedToolSources holding one ``"RECORDS"`` Source.
    """
    logger.info("Building sources from sample")
    if not sample:
        return []

    configuration = get_retriever_info(dataiku_api.webapp_config)
    # A configuration without a "name" would otherwise yield None here and a
    # "Used None" description; fall back to "" as for a missing configuration.
    tool_name = ""
    if configuration:
        tool_name = configuration.get("name") or ""

    formatted_sources = [
        AggregatedToolSources(
            toolCallDescription=f"Used {tool_name}",
            items=[
                Source(
                    type="RECORDS",
                    records=sample,  # type: ignore
                    generatedSqlQuery=sample.get("query", ""),
                    usedTables=tables_used_str,
                    tool_name_used=tool_name,
                )
            ],
        )
    ]

    logger.info(f"Sources built: {formatted_sources}")
    return formatted_sources


def _compute_dss_document_file_url(metadata: Dict, webapp_config: dict) -> str:
    """
    Construct a DSS preview URL from the document source metadata.

    Args:
        metadata: Vector store metadata of the document. Its ``"dku_file_path"``
            entry locates the file inside the knowledge bank's documents folder.
        webapp_config: Webapp configuration. Its ``"knowledge_bank_id"`` entry
            selects the knowledge bank whose ``documents_folder_id`` is used.

    Returns:
        The URL produced by ``build_dss_filepath``, or "" when the file path,
        the project key or the documents folder id is missing.
    """
    filepath = (metadata or {}).get("dku_file_path", "")
    if not filepath:
        # Nothing to link to: skip the knowledge bank lookup entirely.
        return ""

    project_key = dataiku_api.default_project_key
    if not project_key:
        return ""

    kb_id = webapp_config.get("knowledge_bank_id", "")
    kb_info = get_knowledge_bank_info(dataiku_api.default_project, kb_id) or {}
    odb_id = kb_info.get("documents_folder_id")
    if not odb_id:
        return ""

    # Pass "" as page_range start for now as we don't have access
    # to this information with the results of the vector store
    return build_dss_filepath(project_key, odb_id, filepath, "")


def generate_sources_from_source_documents(source_documents: List[VectorstoreDocument]) -> List[AggregatedToolSources]:
    """
    Build the sources for a knowledge-bank retrieval answer, in the new
    AggregatedToolSources format.

    Used in KBRetrievalChain to generate sources.

    Each retrieved document becomes one Source:
    - if its metadata lists ``parts``, the text snippet is left empty and one
      RAGImage (base64 PNG data URL) is built per part;
    - otherwise its ``page_content`` is used as the text snippet.
    The resulting sources are then passed through ``map_rag_sources`` with the
    knowledge bank's embedding recipe type.

    Args:
        source_documents: Documents returned by the vector store.

    Returns:
        A single AggregatedToolSources grouping every Source.

    Raises:
        KeyError: if the knowledge bank info has no ``"embedding_recipe_type"``.
    """
    webapp_config = dataiku_api.webapp_config
    configuration = get_retriever_info(webapp_config)
    # A configuration without a "name" must not yield a "Used None" description.
    tool_name = ""
    if configuration:
        tool_name = configuration.get("name") or ""

    kb_id = webapp_config.get("knowledge_bank_id")
    kb = get_core_knowledge_bank(kb_id)
    kb_info = get_knowledge_bank_info(project=dataiku_api.default_project, knowledge_bank_id=kb_id)
    embedding_recipe_type = kb_info["embedding_recipe_type"]

    # Fetched lazily (at most once when found) since only documents with image
    # parts need it and `kb._get()` queries the knowledge bank definition.
    managed_folder_id = None

    sources: List[Source] = []
    for document_index, document in enumerate(source_documents):
        # Metadata may be missing: normalise it so the lookups below cannot fail.
        metadata = document.metadata or {}
        doc_parts = metadata.get("parts") or []

        # Computed per document so a document without metadata never inherits
        # the URL of the previous one.
        url = _compute_dss_document_file_url(metadata, webapp_config) if metadata else ""

        has_parts = bool(doc_parts)
        source_content = "" if has_parts else document.page_content
        images: List[RAGImage] = []
        if has_parts:
            content_json = json.loads(metadata.get(DKU_MULTIMODAL_CONTENT) or "{}")
            # "content" should hold a single file path (possibly wrapped in a
            # list); it does not depend on the part, so it is shared by all parts.
            file_path = content_json.get("content") or []
            if isinstance(file_path, list):  # just to be sure and robust to changes
                file_path = file_path[0] if file_path else ""
            if managed_folder_id is None:
                managed_folder_id = kb._get().get("managedFolderId")
            images = [
                RAGImage(
                    full_folder_id=managed_folder_id,
                    file_path=file_path,
                    index=document_index,
                    file_data=f"data:image/png;base64,{part.inline_image}",
                )
                for part in doc_parts
            ]

        sources.append(
            Source(
                textSnippet=source_content,
                metadata=metadata,
                images=images,
                tool_name_used=tool_name,
                url=url,
            )
        )

    sources = map_rag_sources(sources, embedding_recipe_type)
    formatted_sources = [AggregatedToolSources(toolCallDescription=f"Used {tool_name}", items=sources)]
    logger.info(f"Sources built: {mask_keys_in_json(formatted_sources, set(KEYS_TO_REMOVE_FROM_LOGS))}")

    return formatted_sources