import transformers
import pandas as pd
import dataiku
from utils import load_image
from qwen_vl_utils import process_vision_info

# Managed folder containing the document images plus a ground-truth label.json file.
folder = dataiku.Folder("Rvdjm86o")

# Path of the labels file inside the folder; it is excluded from the images to classify.
LABELS_PATH = "/label.json"

# Every remaining path in the folder is treated as an image to classify.
# list.remove raises ValueError if the labels file is not present.
files = folder.list_paths_in_partition()
files.remove(LABELS_PATH)

# Ground-truth mapping read from label.json; its items() are turned into
# File/Label rows further down and joined against the predictions.
ground_truth = folder.read_json(LABELS_PATH)

# Hugging Face checkpoint of the instruction-tuned Qwen2-VL 7B vision-language model.
model_name = "Qwen/Qwen2-VL-7B-Instruct"

# bitsandbytes 8-bit quantization, used to reduce memory needed for the weights.
bnb_config = transformers.BitsAndBytesConfig(load_in_8bit=True)

# The processor provides the chat template, tokenization and image preprocessing.
processor = transformers.AutoProcessor.from_pretrained(model_name)

# Let transformers pick the dtype and distribute layers over the available devices.
model_load_kwargs = dict(
    torch_dtype="auto",
    device_map="auto",
    quantization_config=bnb_config,
)
model = transformers.Qwen2VLForConditionalGeneration.from_pretrained(
    model_name, **model_load_kwargs
)

# Per-image results, filled in lockstep: index i of both lists refers to the same file.
image_names = []
predictions = []

# Classification instruction sent alongside every image.
prompt = """Classify this picture with only one of the 16 following categories: email, handwritten, advertisement, scientific report, scientific publication, letter, form, specification, file folder, news article, budget, invoice, presentation, questionnaire, resume, memo."""

for filename in files:
    # Folder paths are listed with a leading "/"; drop it (and only it) so the
    # name can be joined against label.json. Presumably the label keys are bare
    # file names -- TODO confirm against label.json.
    image_name = filename.lstrip("/")
    image_names.append(image_name)

    # Single-turn chat: one image followed by the classification prompt.
    message = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": load_image(folder, filename),
                },
                {
                    "type": "text",
                    "text": prompt,
                },
            ],
        }
    ]

    # Render the chat template to text, ending with the assistant turn marker.
    text_inputs = processor.apply_chat_template(
        message, tokenize=False, add_generation_prompt=True
    )

    # process_vision_info returns (image_inputs, video_inputs); only images are used here.
    image_inputs = process_vision_info(message)[0]

    # Tokenize text and preprocess the image. Move tensors to the model's device
    # rather than a hard-coded "cuda", since device_map="auto" decides placement.
    inputs = processor(
        text=text_inputs,
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    # Greedy decoding (deterministic) with a short budget: labels are a few words at most.
    generated_ids = model.generate(**inputs, max_new_tokens=10, do_sample=False)

    # generate() returns prompt + continuation; keep only the newly generated tokens.
    prompt_length = inputs["input_ids"].shape[1]
    generated_ids_sequence = generated_ids[0][prompt_length:]

    # Decode the answer and trim surrounding whitespace so the stored category is
    # clean when compared with the ground-truth label. Casing is left untouched.
    generated_text = processor.decode(generated_ids_sequence, skip_special_tokens=True)
    predictions.append(generated_text.strip())

# One row per processed image: its file name and the model's answer.
df = pd.DataFrame({"File": image_names, "Category": predictions})

# Turn the label.json mapping's (key, value) pairs into File/Label rows.
ground_truth_df = pd.DataFrame(list(ground_truth.items()), columns=["File", "Label"])

# Left join keeps every prediction; images absent from label.json get a missing Label.
merged_df = df.merge(ground_truth_df, how="left", on="File")

# Persist predictions next to ground truth; the schema is written from the frame.
output_dataset = dataiku.Dataset("results_classification_QwenVL2")
output_dataset.write_with_schema(merged_df)
