# -*- coding: utf-8 -*-
import transformers
import pandas as pd
import dataiku
from collections import defaultdict
from utils import load_image
from qwen_vl_utils import process_vision_info

# Dataiku managed folder containing the DocVQA sample images and their labels.
sample_docVQA = dataiku.Folder("JRpzJC6D")
_LABELS_PATH = "/data.json"

# Every listed path except the label file is treated as an image.
# NOTE: list.remove raises ValueError if the label file is not listed.
files = sample_docVQA.list_paths_in_partition()
files.remove(_LABELS_PATH)

# Ground-truth entries; the rest of the script reads labels["data"] as a list
# of dicts with "image", "question" and "answers" keys.
labels = sample_docVQA.read_json("data.json")

# Hugging Face checkpoint for the Qwen2-VL instruct model.
model_name = "Qwen/Qwen2-VL-7B-Instruct"

# Load the weights with 8-bit bitsandbytes quantization to reduce GPU memory;
# device_map="auto" lets transformers decide where the weights are placed.
bnb_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
model = transformers.Qwen2VLForConditionalGeneration.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

# Matching processor: used below for chat templating, tokenization and image
# preprocessing for this checkpoint.
processor = transformers.AutoProcessor.from_pretrained(model_name)

# Prefix prepended to every question; asks for a bare value so the output can
# be compared directly with the short ground-truth answers.
_PROMPT_PREFIX = "Answer the question. Do not write a full sentence, just provide a value. "
# Fixed input size, chosen to avoid CUDA out-of-memory errors.
# NOTE(review): a square resize distorts the page aspect ratio; if accuracy
# suffers, consider limiting pixels via the processor instead.
_IMAGE_SIZE = (896, 896)
# Upper bound on answer length, in generated tokens.
_MAX_NEW_TOKENS = 30


def _load_rgb_image(folder, path):
    """Load *path* from *folder*, convert it to RGB and resize it to ``_IMAGE_SIZE``."""
    image = load_image(folder, path)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image.resize(_IMAGE_SIZE)


def _answer_question(image, question):
    """Return the model's greedy-decoded answer to *question* about *image*."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": _PROMPT_PREFIX + question},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs = process_vision_info(messages)[0]
    # Move inputs to the model's device rather than a hard-coded "cuda", so they
    # follow wherever device_map="auto" placed the model.
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)
    generated_ids = model.generate(
        **inputs, max_new_tokens=_MAX_NEW_TOKENS, do_sample=False
    )
    # generate() returns prompt + completion; keep only the new tokens.
    prompt_length = inputs["input_ids"].shape[1]
    return processor.decode(
        generated_ids[0][prompt_length:], skip_special_tokens=True
    )


# Parallel result lists: one entry per answered (image, question) pair.
questions = []
predictions = []
filename = []

for file in files:
    # Gather this image's questions first, so the image is loaded and
    # preprocessed once per file instead of once per question.
    matching_entries = [
        entry for entry in labels["data"] if entry["image"].endswith(file)
    ]
    if not matching_entries:
        continue
    image = _load_rgb_image(sample_docVQA, file)
    name = file.split("/")[-1]
    for entry in matching_entries:
        question = entry["question"]
        answer = _answer_question(image, question)
        # Append only after inference succeeds, so the three lists stay aligned.
        filename.append(name)
        questions.append(question)
        predictions.append(answer)

# Prediction table: one row per answered question (image file, question, model answer).
df = pd.DataFrame(dict(File=filename, Question=questions, Answers=predictions))

# Build the ground-truth table: one row per (image file name, question).
# All accepted answers for that pair are collected into a single list.
dataframe_entries = defaultdict(list)

for entry in labels["data"]:
    image_name = entry["image"].split("/")[-1]  # Same key form as df["File"]
    question = entry["question"]
    answers = entry["answers"]
    # A pair listed more than once contributes its answers to one row.
    dataframe_entries[(image_name, question)].extend(answers)

# Flatten the dictionary into a DataFrame. The columns are pinned explicitly so
# the merge on ["File", "Question"] still works when labels["data"] is empty
# (pd.DataFrame([]) would otherwise have no columns at all).
ground_truth_df = pd.DataFrame(
    [
        {"File": image_name, "Question": question, "Label": answers}
        for (image_name, question), answers in dataframe_entries.items()
    ],
    columns=["File", "Question", "Label"],
)

# Attach ground-truth answers to each prediction. The left join keeps every
# prediction even when no label row matches (its "Label" is then missing).
# validate="many_to_one" raises if a (File, Question) key has several label rows,
# instead of silently duplicating prediction rows.
merged_df = pd.merge(
    df,
    ground_truth_df,
    on=["File", "Question"],
    how="left",
    validate="many_to_one",
)

# Write predictions next to their ground-truth labels to a Dataiku dataset.
# No metric (e.g. ANLS) is computed in this script; scoring, if any, happens elsewhere.
dataiku.Dataset("results_VQA_QwenVL2").write_with_schema(merged_df)
