# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import os
import re
from ast import literal_eval
from PIL import Image
import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor, BitsAndBytesConfig, AutoModelForVision2Seq

df = dataiku.Dataset("questions_augmented").get_dataframe()
folder = dataiku.Folder("vOjkXoGz")

# Convert the 'chunks' column from string representations of lists to actual lists
df["chunks"] = df["chunks"].apply(literal_eval)

# For every question, split its retrieved chunks into text contents and image paths.
similar_images, similar_texts = [], []
for chunks in df["chunks"]:
    # The column was already parsed above; calling literal_eval on a list raises
    # ValueError, so only parse again if a value is still a (doubly encoded) string.
    if isinstance(chunks, str):
        chunks = literal_eval(chunks)
    # Chunks tagged "text" carry their payload in "content"; every other chunk is
    # treated as an image referenced by "image_url".
    similar_texts.append(
        [chunk["content"] for chunk in chunks if chunk["type"] == "text"]
    )
    similar_images.append(
        [chunk["image_url"] for chunk in chunks if chunk["type"] != "text"]
    )

# Add the similar images and texts to the DataFrame
df["similar_images"] = similar_images
df["similar_texts"] = similar_texts

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Device that the processed model inputs are moved to before generation
DEVICE = "cuda:0"

# Name of the pretrained checkpoint shared by the model and its processor
checkpoint = "HuggingFaceM4/idefics2-8b"

# Quantize the model to 8 bits at load time for reduced memory usage.
# NOTE(review): bnb_4bit_compute_dtype is named after the 4-bit mode and presumably
# has no effect together with load_in_8bit — confirm which mode is intended.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    load_in_8bit=True,
)

# Instantiate the pretrained model in float16 with the quantization settings above
model = AutoModelForVision2Seq.from_pretrained(
    checkpoint,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
)

# Instantiate the matching processor with image splitting turned off
processor = AutoProcessor.from_pretrained(checkpoint, do_image_splitting=False)

def augment_prompt(question, texts, images, folder):
    """Build the chat messages and the image list for one augmented question.

    The prompt is a list of "user" messages: one instruction message, one
    message per text fact, one message (with an image placeholder) per image
    fact, and finally the question itself.

    Args:
        question: The user question to answer.
        texts: Iterable of text facts to include in the prompt.
        images: Iterable of image paths inside ``folder``.
        folder: Object exposing ``get_download_stream(path=...)`` as a context
            manager yielding a binary stream (here a ``dataiku.Folder``).

    Returns:
        A ``(prompt, images_raw)`` tuple where ``prompt`` is the list of chat
        messages and ``images_raw`` the opened images, in the same order as
        their ``{"type": "image"}`` placeholders.
    """
    # Define an instruction for the assistant
    instruction = """You are a helpful assistant.
    Concisely answer the question of the user based on the facts provided.
    If you don't know, just say you don't know."""

    # Initialize the prompt with the instruction
    prompt = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": instruction},
            ]
        }
    ]

    # Add text facts to the prompt
    prompt.extend(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Fact: {}".format(text)},
            ]
        }
        for text in texts
    )

    # Add image facts to the prompt and collect the matching raw images
    images_raw = []
    for image in images:
        with folder.get_download_stream(path=image) as stream:
            raw_image = Image.open(stream)
            # Image.open is lazy: force the pixel data to be read while the
            # stream is still open, otherwise later use may hit a closed stream.
            raw_image.load()
        prompt.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Useful facts are in the following images:"},
                    {"type": "image"},
                ]
            }
        )
        images_raw.append(raw_image)

    # Add the question to the prompt
    prompt.append(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Answer the following question: {}.".format(question)},
            ]
        }
    )

    return prompt, images_raw

# Regular expression capturing the text that follows the last "\nAssistant:" marker
# of the decoded output (the leading greedy ".*" skips any earlier occurrence).
pattern = r'.*\nAssistant:(.*)'

answers = []

# Generate one answer per row of the DataFrame
for _, row in df.iterrows():
    question, texts, images = row["question"], row["similar_texts"], row["similar_images"]
    messages, images_raw = augment_prompt(question, texts, images, folder)

    # Prepare the input prompt using the processor
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    processor_kwargs = {"text": prompt, "return_tensors": "pt"}
    # Only pass images when there are some: the prompt then contains no image placeholder
    if images_raw:
        processor_kwargs["images"] = images_raw
    inputs = processor(**processor_kwargs)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    # Generate the model's response
    generated_ids = model.generate(**inputs, max_new_tokens=200)
    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

    # Extract the answer from the generated text. If the "Assistant:" marker is
    # missing, keep the whole decoded text instead of crashing on a None match
    # and losing every answer generated so far.
    decoded = generated_texts[0]
    match = re.search(pattern, decoded, re.DOTALL)
    answer = (match.group(1) if match else decoded).strip()

    answers.append(answer)

# Add the answers to the DataFrame
df["answers"] = answers

dataiku.Dataset("idefics2_answers").write_with_schema(df)