# -*- coding: utf-8 -*-
import dataiku
import io
import pandas as pd
from utils import load_image
import torch
from transformers import AutoProcessor, UdopForConditionalGeneration
from peft import LoraConfig, get_peft_model

# Managed folders used by this recipe:
# 1. Folder containing the fine-tuned LoRA weights for the UDOP model
model_UDOP_FT_PEFT = dataiku.Folder("iHCLMsKi")

# 2. Folder containing the images to classify and the ground truth labels
folder_images = dataiku.Folder("Rvdjm86o")

# Path, inside the images folder, of the JSON file holding the ground truth
label_path = "/label.json"

# List everything in the images folder and fail early, with an explicit
# message, when the ground-truth file is absent (a bare list.remove() would
# only report "x not in list").
all_paths = folder_images.list_paths_in_partition()
if label_path not in all_paths:
    raise ValueError(
        f"Ground truth file {label_path!r} not found in the images folder"
    )

# Every other path is treated as an image to classify
files = [path for path in all_paths if path != label_path]

# Load the ground truth labels from 'label.json'
ground_truth = folder_images.read_json(label_path)

# Pull the serialized LoRA checkpoint out of the managed folder into an
# in-memory buffer; the download stream is closed as soon as it is drained.
with model_UDOP_FT_PEFT.get_download_stream("UDOP_weights_LoRA_1024.pt") as weights_stream:
    buffer = io.BytesIO(weights_stream.read())

# Deserialize the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file content, so this must only be
# run on checkpoints from a trusted folder. If the file holds nothing but
# tensors, consider passing weights_only=True - TODO confirm the checkpoint
# format and the installed torch version before changing it.
lora_weights = torch.load(buffer, map_location="cpu")

# Identifier of the base pre-trained UDOP checkpoint
model_name = "microsoft/udop-large-512-300k"

# Both the processor and the model are resolved with local_files_only=True;
# the processor is additionally created with apply_ocr=True.
pretrained_kwargs = {"local_files_only": True}
processor = AutoProcessor.from_pretrained(model_name, apply_ocr=True, **pretrained_kwargs)
model = UdopForConditionalGeneration.from_pretrained(model_name, **pretrained_kwargs)

# LoRA settings for PEFT (Parameter-Efficient Fine-Tuning): rank 12, scaling
# factor 12, adapters attached to the modules named "k", "q", "v" and "o".
# NOTE(review): presumably these must match the settings used when the
# weights were fine-tuned - confirm against the training recipe.
lora_settings = {
    "r": 12,
    "lora_alpha": 12,
    "target_modules": ["k", "q", "v", "o"],
}
peft_config = LoraConfig(**lora_settings)

# Wrap the base UDOP model with the LoRA configuration
model = get_peft_model(model, peft_config)

# Apply the fine-tuned LoRA weights to the adapter parameters of the model.
# Only parameters whose name contains "lora" are touched.
lora_param_names = [name for name, _ in model.named_parameters() if "lora" in name]

# Fail with an explicit message when the checkpoint lacks an expected tensor,
# instead of a bare KeyError in the middle of the loading loop.
missing_names = [name for name in lora_param_names if name not in lora_weights]
if missing_names:
    raise KeyError(
        f"LoRA checkpoint is missing {len(missing_names)} parameter(s), "
        f"e.g. {missing_names[0]!r}"
    )

# Copy in place (rather than rebinding param.data) so that each parameter
# keeps its own storage and dtype, and reject shape mismatches explicitly
# because copy_ would otherwise broadcast silently.
with torch.no_grad():
    for name, param in model.named_parameters():
        if "lora" not in name:
            continue
        saved_tensor = lora_weights[name]
        if saved_tensor.shape != param.shape:
            raise ValueError(
                f"Shape mismatch for {name!r}: checkpoint has "
                f"{tuple(saved_tensor.shape)}, model expects {tuple(param.shape)}"
            )
        param.copy_(saved_tensor)

# Set the model to evaluation mode
model.eval()

# File names with their first character (the leading '/') dropped
image_name = [path[1:] for path in files]

# Load every image from the folder up front, converted to RGB
images = []
for path in files:
    images.append(load_image(folder_images, path).convert("RGB"))

# Prompt sent to the processor together with each image
prompt = "document classification."


def _predict_category(image):
    """Return the decoded text generated by the model for one image."""
    # Build the model inputs from the image and the prompt, padded/truncated
    # to 1024 tokens and returned as PyTorch tensors
    encoding = processor(
        images=image,
        text=prompt,
        max_length=1024,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Generate at most 20 new tokens with the fine-tuned model
    predicted_ids = model.generate(**encoding, max_new_tokens=20)

    # Decode the generated IDs to text and keep the first (only) sequence
    decoded = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    return decoded[0]


# One prediction per image, in the same order as `images`
predictions = [_predict_category(image) for image in images]

# Ground truth as a two-column table: one (File, Label) row per JSON entry
ground_truth_rows = list(ground_truth.items())
ground_truth_df = pd.DataFrame(ground_truth_rows, columns=["File", "Label"])

# Predictions as a table: one row per classified image
df = pd.DataFrame({"File": image_name, "Category": predictions})

# Left join on the file name so that every prediction is kept, even when no
# ground truth entry exists for it
merged_df = df.merge(ground_truth_df, on="File", how="left")

# Write the merged table to the output dataset
output_dataset = dataiku.Dataset("Results_classification_UDOP_FT")
output_dataset.write_with_schema(merged_df)
