# -*- coding: utf-8 -*-

# Import necessary libraries and modules
import dataiku
from PIL import Image
import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor, BitsAndBytesConfig
import os
import re

# Identifier of the pre-trained IDEFICS checkpoint to load
checkpoint = "HuggingFaceM4/idefics-9b-instruct"

# Input dataset holding the questions, read into a dataframe
questions = dataiku.Dataset("questions")
questions_df = questions.get_dataframe()

# Handle on the folder that stores the images
images = dataiku.Folder("TVEwE7rl")

# Cache directory for the downloaded artifacts, taken from the HF_HOME
# environment variable (None when the variable is not set)
hf_home_dir = os.environ.get("HF_HOME")

# Processor associated with the checkpoint
processor = AutoProcessor.from_pretrained(
    checkpoint,
    cache_dir=hf_home_dir,
)

# 4-bit quantization with float16 compute, to shrink the memory footprint
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

# Run inference on the GPU when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the quantized model from the checkpoint
model = IdeficsForVisionText2Text.from_pretrained(
    checkpoint,
    cache_dir=hf_home_dir,
    device_map="auto",
    quantization_config=quantization_config,
)

# Regex capturing the text that follows "\nAssistant:" up to the next
# " \nUser" marker or the end of the text; DOTALL lets the capture span
# several lines
pattern = re.compile(r'\nAssistant:(.*?)(?: \nUser|$)', flags=re.DOTALL)

# Collect one generated answer per row of the questions dataframe
answers = []

# Token ids the model is forbidden to generate. They do not depend on the
# current row, so they are computed once rather than on every iteration.
bad_words_ids = processor.tokenizer(
    ["<image>", "<fake_token_around_image>"], add_special_tokens=False
).input_ids

# Iterate through each row in the questions dataset
for _, image_question in questions_df.iterrows():
    # Retrieve image name and question
    image_name = image_question["image"]
    question = image_question["question"]

    # Open the image from the folder. Image.open is lazy: it only reads the
    # header, so the pixel data is loaded explicitly while the download
    # stream is still open instead of after it has been closed.
    with images.get_download_stream(path=image_name) as stream:
        raw_image = Image.open(stream)
        raw_image.load()

    # Build the prompt: the user's question, the image, then the cue for
    # the assistant's turn
    prompts = [["User:" + question, raw_image, "\nAssistant:"]]

    # Process inputs using the AutoProcessor
    inputs = processor(prompts, return_tensors="pt").to(device)

    # Generate the answer.
    # NOTE: max_length bounds prompt tokens plus generated tokens together.
    generated_ids = model.generate(**inputs, bad_words_ids=bad_words_ids, max_length=200)

    # Decode the generated text and print it for inspection
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    for idx, text in enumerate(generated_text):
        print(f"{idx}:\n{text}\n")

    # Extract the first assistant answer. When the marker is missing from
    # the decoded text, store an empty answer instead of raising and losing
    # every answer generated so far.
    match = pattern.search(generated_text[0])
    if match is None:
        print(f"No assistant answer found for image {image_name!r}")
        first_answer = ""
    else:
        first_answer = match.group(1)

    # Append the answer to the list of answers
    answers.append(first_answer)

# Add the answers to the original questions dataframe
questions_df["answer"] = answers

# Write the dataframe with answers to the output dataset
idefics_answers = dataiku.Dataset("idefics_answers")
idefics_answers.write_with_schema(questions_df)