# -*- coding: utf-8 -*-
import dataiku
import pandas as pd
from transformers import AutoProcessor, UdopForConditionalGeneration
from utils import load_image

# Dataiku managed folder holding the JPEG files and the ground truth labels
folder = dataiku.Folder("Rvdjm86o")

# label.json stores the ground truth labels; it lives next to the images
label_path = "/label.json"

# List every path in the folder, then drop label.json so that only the
# image paths remain (list.remove raises ValueError if it is not listed)
files = folder.list_paths_in_partition()
files.remove(label_path)

# Parse the ground truth labels from label.json
ground_truth = folder.read_json(label_path)

# Checkpoint identifier shared by the processor and the model
model_name = "microsoft/udop-large-512-300k"

# Loading options common to both from_pretrained calls: local_files_only=True
# restricts loading to files already available locally (no download attempt)
pretrained_kwargs = {"local_files_only": True}

# Processor for the UDOP model, with OCR enabled (apply_ocr=True)
processor = AutoProcessor.from_pretrained(
    model_name, apply_ocr=True, **pretrained_kwargs
)

# UDOP model used for conditional generation
model = UdopForConditionalGeneration.from_pretrained(model_name, **pretrained_kwargs)

# File names reported in the output: folder paths with the leading '/' stripped
image_name = [file[1:] for file in files]

# Predicted category for each image, in the same order as `files` / `image_name`
predictions = []

# Classification prompt passed to the model alongside every image
prompt = "Document classification."

# Classify the images one at a time. Each image is loaded inside the loop
# rather than pre-loading the whole folder into a list, so only a single
# decoded image is held in memory at any point.
for file in files:
    # load_image is a project helper (utils); its result is converted to RGB
    # mode before being handed to the processor
    image = load_image(folder, file).convert("RGB")

    # Encode the input (text + image) for the model
    encoding = processor(
        text=prompt,
        images=image,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=1024,
    )

    # Generate at most 20 new tokens for the predicted category
    predicted_ids = model.generate(**encoding, max_new_tokens=20)

    # Decode the generated IDs; a single image is processed per call, so the
    # first (only) decoded string is the predicted category
    result = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    predictions.append(result)

# One row per image: its file name and the category predicted for it
predictions_df = pd.DataFrame({"File": image_name, "Category": predictions})

# Turn the ground truth mapping (loaded from JSON) into a two-column
# DataFrame built from its (key, value) pairs
label_rows = list(ground_truth.items())
labels_df = pd.DataFrame(label_rows, columns=["File", "Label"])

# Left join on the file name: every prediction is kept, and files without a
# ground truth entry get a missing Label
merged_df = predictions_df.merge(labels_df, on="File", how="left")

# Write the merged DataFrame to the Dataiku output dataset
output_dataset = dataiku.Dataset("results_classification_UDOP")
output_dataset.write_with_schema(merged_df)
