# -*- coding: utf-8 -*-
import dataiku
import json
import os
import tempfile
import pickle
import faiss
from transformers import AutoProcessor, AutoModel
from project_utils import encode_image, compute_text_embeddings

# Checkpoint shared by the embedding model and its processor
SIGLIP_CHECKPOINT = "google/siglip-so400m-patch14-384"

# Number of nearest neighbours requested from each index (text and image)
NUM_CHUNKS_PER_MODALITY = 3

# Input questions and the two managed folders used by this recipe
df = dataiku.Dataset("questions").get_dataframe()
folder = dataiku.Folder("vOjkXoGz")  # handed to encode_image when image messages are built
index_folder = dataiku.Folder("MQNdVKza")  # source of the index files and content lists

# Resolve the LLM whose id is stored in the "LLM_ID" custom variable
LLM_ID = dataiku.get_custom_variables()["LLM_ID"]
client = dataiku.api_client()
project = client.get_default_project()
llm = project.get_llm(LLM_ID)

# Pretrained model and processor used to embed the questions
model = AutoModel.from_pretrained(SIGLIP_CHECKPOINT)
# NOTE(review): only the processor is restricted to local files; the model
# load above is not. Confirm this asymmetry is intentional.
processor = AutoProcessor.from_pretrained(SIGLIP_CHECKPOINT, local_files_only=True)

# Stage every file of index_folder in a temporary directory so that FAISS and
# pickle can read them from local disk. All loading happens inside the `with`
# block; the directory is removed when the block exits.
with tempfile.TemporaryDirectory() as tmp_dir:
    for path in index_folder.list_paths_in_partition():
        # Folder paths are expected to start with "/"; strip it so that the
        # join below stays inside tmp_dir.
        tmp_path = os.path.join(tmp_dir, path.lstrip("/"))
        # A file located in a sub-folder needs its parent directory to exist
        # before it can be opened for writing.
        os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
        with index_folder.get_download_stream(path) as stream:
            with open(tmp_path, "wb") as f:
                f.write(stream.read())

    # Load the image and text indexes and their corresponding content lists
    image_index = faiss.read_index(os.path.join(tmp_dir, "image_index.bin"))
    text_index = faiss.read_index(os.path.join(tmp_dir, "text_index.bin"))
    # NOTE: pickle.load can execute arbitrary code while deserializing. This is
    # only acceptable as long as the content of index_folder is trusted.
    with open(os.path.join(tmp_dir, "image"), "rb") as fp:
        images = pickle.load(fp)  # Image content list
    with open(os.path.join(tmp_dir, "text"), "rb") as fp:
        texts = pickle.load(fp)  # Text content list

# Define a function to retrieve chunks based on a given question
def retrieve_chunks(question):
    """
    Retrieve the text and image chunks closest to a question.

    The question is embedded once and the same embedding is used to query both
    the image index and the text index for NUM_CHUNKS_PER_MODALITY neighbours.
    Results are interleaved by rank: text hit, image hit, text hit, ...

    Args:
    - question (str): The question to retrieve chunks for.

    Returns:
    - chunks (list): Dicts of the form {"type": "text", "content": ...} or
      {"type": "image", "image_url": ...}. Fewer than
      2 * NUM_CHUNKS_PER_MODALITY entries are returned when an index reports
      fewer matches than requested.
    """
    prompt_embedding = compute_text_embeddings(model, processor, question)

    # search() returns (distances, labels); only the labels of the single
    # query row are needed.
    image_ids = image_index.search(prompt_embedding, NUM_CHUNKS_PER_MODALITY)[1][0]
    text_ids = text_index.search(prompt_embedding, NUM_CHUNKS_PER_MODALITY)[1][0]

    chunks = []
    for text_id, image_id in zip(text_ids, image_ids):
        # FAISS pads missing neighbours with the label -1. Used as a Python
        # index, -1 would silently pick the last element, so skip those slots.
        if text_id >= 0:
            chunks.append({"type": "text", "content": texts[int(text_id)]})
        if image_id >= 0:
            chunks.append({"type": "image", "image_url": images[int(image_id)]})
    return chunks


def get_messages(chunks_with_metadata, question):
    """
    Assemble the conversation payload for the multimodal LLM.

    The result always starts with a system instruction followed by a user
    message carrying the question. One further user message is then added per
    chunk, in the order the chunks are given.

    Args:
    - chunks_with_metadata (list): Chunk dicts. A chunk whose "type" is "text"
      must provide "content". Any other chunk is handled as an image and must
      provide "image_url", with an optional "caption".
    - question (str): The question posed by the user.

    Returns:
    - messages (list): Dicts holding a "role" and a list of "parts".
    """
    def text_part(text):
        # All textual parts share this shape
        return {"type": "TEXT", "text": text}

    def user_message(parts):
        return {"role": "user", "parts": parts}

    # Fixed preamble: system instruction, then the question itself
    messages = [
        {
            "role": "system",
            "parts": [
                text_part("You are a helpful assistant. Concisely answer the user's question based on the provided facts. If you don't know, just say you don't know.")
            ],
        },
        user_message([
            text_part(f"Answer the following question: {question}. Use the following facts.")
        ]),
    ]

    for chunk in chunks_with_metadata:
        # Text chunks become a single "Fact" message
        if chunk["type"] == "text":
            messages.append(user_message([text_part(f"Fact: {chunk['content']}")]))
            continue

        # Everything else is an image: optional caption first, then the image
        parts = []
        if "caption" in chunk:
            parts.append(text_part(f"Fact: {chunk['caption']}"))
        parts.append({
            "type": "IMAGE_INLINE",
            "inlineImage": encode_image(folder, chunk['image_url']),
        })
        messages.append(user_message(parts))

    return messages
      
def answer_question(question):
    """
    Run one question through the multimodal RAG pipeline.

    Chunks are retrieved for the question, turned into messages and sent to
    the LLM as a single completion capped at 300 output tokens with a
    temperature of 0.

    Args:
    - question (str): The question to be answered.

    Returns:
    - resp_text (str): The text of the LLM response.
    - sources_json (str): The retrieved chunks serialized as a JSON string.
    """
    retrieved = retrieve_chunks(question)

    # Prepare the completion request: messages first, then generation settings
    request = llm.new_completion()
    request.cq["messages"] = get_messages(retrieved, question)
    request.settings["maxOutputTokens"] = 300
    request.settings["temperature"] = 0

    response = request.execute()

    resp_text = response.text
    sources_json = json.dumps(retrieved)
    return resp_text, sources_json

# Answer every question and collect the results column-wise. Assigning whole
# columns at the end avoids creating them cell by cell with df.at.
generated_answers = []
chunks_json = []
for question in df["question"]:
    answer, sources = answer_question(question)
    generated_answers.append(answer)
    # `sources` is already a JSON string (see answer_question), so it is stored
    # as is. Serializing it a second time would double-encode it.
    chunks_json.append(sources)

df["generated_answer"] = generated_answers
df["chunks"] = chunks_json

dataiku.Dataset("answers_siglip").write_with_schema(df)
