# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import dataiku
import os
import tempfile
from byaldi import RAGMultiModalModel

# Retrieval settings: how many search results are requested per question, and
# the name of the index (also used as the local directory name it is loaded from).
NUM_CHUNKS = 3
index_name = "docs"

# Inputs: the questions to answer and the managed folder holding the index files.
df = dataiku.Dataset("questions").get_dataframe()
index_folder = dataiku.Folder("8jv9ZTHM")

# LLM handle: the LLM id is read from the "LLM_ID" custom variable and resolved
# against the default project.
LLM_ID = dataiku.get_custom_variables()["LLM_ID"]
project = dataiku.api_client().get_default_project()
llm = project.get_llm(LLM_ID)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Copy the index files from the managed folder into a temporary directory and
# load the index from there. The temporary directory is removed when the
# ``with`` block exits; ``RAG`` is used afterwards without it.
with tempfile.TemporaryDirectory() as temp_dir:
    index_directory = os.path.join(temp_dir, index_name)
    os.mkdir(index_directory)
    for remote_path in index_folder.list_paths_in_partition():
        # Drop the first "/"-separated component of the folder path (empty for
        # a path starting with "/") and mirror the rest under index_directory.
        relative_parts = remote_path.split("/")[1:]
        filepath = os.path.join(index_directory, *relative_parts)
        # Create all missing parent directories so that files nested at any
        # depth can be written, not only those one level below the root.
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        with index_folder.get_download_stream(remote_path) as stream:
            with open(filepath, "wb") as local_file:
                local_file.write(stream.read())

    # Load the index
    RAG = RAGMultiModalModel.from_index(index_name, index_root=temp_dir)

# Map each document id to the last "/"-separated component of its file name.
id2filename = {
    doc_id: file_name.split("/")[-1]
    for doc_id, file_name in RAG.get_doc_ids_to_file_names().items()
}

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def answer_question(question):
    """
    Answer one question with the ColPali-based multimodal RAG pipeline.

    The ``NUM_CHUNKS`` best search results for the question are fetched from
    ``RAG`` and each one is attached to a completion query on ``llm`` as a
    labelled inline image.

    Returns a tuple ``(answer, sources)``: the text of the completion response
    and a comma-separated string of "<file name> (page <n>)" labels, one per
    search result.
    """
    hits = RAG.search(question, k=NUM_CHUNKS)

    query = llm.new_completion()

    # The instruction comes first; every search result follows as its own
    # multipart message made of a source label and the page image.
    query.with_message(f"Concisely answer the following question: {question}. Use the documents attached below.")
    labels = []
    for hit in hits:
        source_file = id2filename[hit["doc_id"]]
        label = f'{source_file} (page {hit["page_num"]})'
        labels.append(label)
        part = query.new_multipart_message()
        part.with_text(label)
        part.with_inline_image(hit["base64"])
        part.add()

    query.settings["maxOutputTokens"] = 300
    query.settings["temperature"] = 0
    response = query.execute()
    return response.text, ", ".join(labels)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Answer every question in row order, then attach the results as whole columns.
# Assigning full columns (instead of filling cells one by one) guarantees that
# "generated_answer" and "chunks" exist even when there are no questions, and
# does not rely on the index labels being unique.
generated_answers = []
cited_sources = []
for question in df["question"]:
    answer, sources = answer_question(question)
    generated_answers.append(answer)
    cited_sources.append(sources)
df["generated_answer"] = generated_answers
df["chunks"] = cited_sources

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write the dataframe, including the added answer columns, to the output
# dataset together with its schema.
answers_dataset = dataiku.Dataset("answers_colpali")
answers_dataset.write_with_schema(df)