# -*- coding: utf-8 -*-
import dataiku
import io
import tempfile
import shutil
import pickle
import transformers
import torch
from torch.utils.data import Dataset
from peft import LoraConfig, get_peft_model
import bitsandbytes as bnb

# Dataiku managed-folder handles, addressed by folder id.
# Source: holds the pre-processed training data.
folder = dataiku.Folder("BM1RzKhd")
# Destination: receives the trained LoRA weights at the end of this script.
folder_destination = dataiku.Folder("iHCLMsKi")

# Prefer the GPU when one is visible.
# NOTE(review): `device` is not referenced anywhere else in this script; the
# Trainer decides model/batch placement on its own.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Ask transformers/bitsandbytes to load the base model's weights in 8-bit.
bnb_config = transformers.BitsAndBytesConfig(load_in_8bit=True)

# Deserialize the pre-processed dataset; the result must be a mapping with the
# keys read by UDOPDatasetClassification (input_ids, attention_mask, bbox,
# pixel_values, labels).
# NOTE(review): despite the ".pt" name this is read with plain `pickle`, not
# `torch.load` -- presumably the producer wrote it with `pickle.dump`; a default
# `torch.save` archive would likely not unpickle this way. Unpickling runs
# arbitrary code, so only point this at a trusted folder.
with folder.get_download_stream("data_OCR_1024.pt") as stream:
    data = pickle.load(stream)


# Define a custom Dataset class for UDOP document classification
class UDOPDatasetClassification(Dataset):
    """Map-style PyTorch dataset serving pre-processed UDOP training samples.

    The constructor keeps five parallel per-sample collections. Indexing returns
    the ``idx``-th entry of each, bundled in a dict with the keys
    ``input_ids``, ``attention_mask``, ``bbox``, ``pixel_values`` and ``labels``
    (in that order).

    Args:
        dataset: Mapping that provides those five keys. ``dataset["pixel_values"]``
            must support ``.to(torch.float16)`` (e.g. a ``torch.Tensor``). The other
            values only need to support ``[idx]``, and ``len(dataset["labels"])``
            defines the dataset length.
    """

    # Keys emitted by __getitem__, in output order.
    _FIELDS = ("input_ids", "attention_mask", "bbox", "pixel_values", "labels")

    def __init__(self, dataset):
        self.input_ids = dataset["input_ids"]
        self.attention_mask = dataset["attention_mask"]
        self.bbox = dataset["bbox"]
        # Down-cast the pixel values once, up front, to reduce memory use.
        self.pixel_values = dataset["pixel_values"].to(torch.float16)
        self.labels = dataset["labels"]

    def __len__(self):
        """Return the number of samples, i.e. the number of labels."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict keyed by ``_FIELDS``.

        Args:
            idx: Position of the sample; forwarded unchanged to each stored collection.

        Returns:
            ``{field: <stored collection for field>[idx]}`` for every field in ``_FIELDS``.
        """
        return {field: getattr(self, field)[idx] for field in self._FIELDS}


# Wrap the deserialized data in the map-style dataset defined above.
mydataset = UDOPDatasetClassification(data)

# LoRA adapter settings: rank 8, lora_alpha equal to the rank, attached to the
# modules named k/q/v/o (presumably UDOP's attention projections -- verify
# against the model's module names if the checkpoint changes).
peft_config = LoraConfig(
    target_modules=["k", "q", "v", "o"],
    r=8,
    lora_alpha=8,
)

# Base UDOP checkpoint, resolved from local files only (no download), loaded
# with the 8-bit quantization settings in `bnb_config`.
base_checkpoint = "microsoft/udop-large-512-300k"
base_model = transformers.UdopForConditionalGeneration.from_pretrained(
    base_checkpoint,
    local_files_only=True,
    quantization_config=bnb_config,
)

# Inject the LoRA adapters; the wrapped model is what the Trainer trains below.
model = get_peft_model(base_model, peft_config)

# Optimize only parameters that still require gradients. After `get_peft_model`
# these are expected to be just the LoRA adapter weights; the frozen (partly
# int8-quantized) base-model tensors are kept out of the optimizer entirely.
trainable_params = [p for p in model.parameters() if p.requires_grad]

# 8-bit Adam from bitsandbytes keeps the optimizer state compact.
optimizer = bnb.optim.Adam8bit(trainable_params, lr=5e-5)

# Use bf16 mixed precision only when the GPU actually supports it.
# TrainingArguments raises at construction time when bf16=True is requested on
# hardware without bf16 support, so an unconditional True would crash there.
# `is_available()` is checked first so `is_bf16_supported()` is never queried
# without CUDA.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

# Trainer checkpoints go to a throwaway directory. The context manager removes
# it even when training raises, unlike a manual rmtree after train().
with tempfile.TemporaryDirectory() as output_dir:
    training_args = transformers.Seq2SeqTrainingArguments(
        per_device_train_batch_size=1,
        num_train_epochs=1,
        output_dir=output_dir,
        bf16=use_bf16,
        logging_dir="./logging/",
        # Keep every key yielded by UDOPDatasetClassification in each batch.
        remove_unused_columns=False,
        dataloader_pin_memory=False,
    )

    # Scheduler slot is None, so the Trainer builds its default scheduler
    # around the optimizer supplied here.
    trainer = transformers.Seq2SeqTrainer(
        model=model,
        train_dataset=mydataset,
        args=training_args,
        optimizers=(optimizer, None),
    )

    trainer.train()

# Collect the LoRA adapter tensors: every parameter whose name contains "lora".
# Detach them and copy to CPU so the saved file does not contain CUDA tensors.
# Otherwise it could only be loaded on a GPU host, or with an explicit map_location.
lora_parameters = {
    name: param.detach().cpu()
    for name, param in model.named_parameters()
    if "lora" in name
}

# Serialize in memory; getvalue() returns the whole buffer regardless of the
# current stream position.
with io.BytesIO() as buf:
    torch.save(lora_parameters, buf)
    binary_data = buf.getvalue()

# Upload the serialized LoRA weights as a .pt file to the destination folder.
folder_destination.upload_data("UDOP_weights_LoRA_1024.pt", binary_data)
