# -*- coding: utf-8 -*-
import logging
import re

import dataiku
import pandas as pd

from utils import encode_image

# Managed folder holding the images to classify, and every path it contains.
folder = dataiku.Folder("Rvdjm86o")
files = folder.list_paths_in_partition()

# Managed folder holding the few-shot example images and their label file.
few_shot_folder = dataiku.Folder("55DQRQlh")

# Resolve the LLM handle from the "LLM_ID" custom variable of the project.
_custom_variables = dataiku.get_custom_variables()
LLM_ID = _custom_variables["LLM_ID"]
_default_project = dataiku.api_client().get_default_project()
llm = _default_project.get_llm(LLM_ID)

PROMPT = """Your task is to classify the pictures given with only one of the 16 following categories: 
email, handwritten, advertisement, scientific report, scientific publication, letter, form, specification, 
file folder, news article, budget, invoice, presentation, questionnaire, resume, memo."""

# Few-shot labels read from the few-shot folder. They are looked up by folder
# path with its first character (the leading "/") removed.
classes_few_shot = few_shot_folder.read_json("/label_fs.json")

def _image_message(image_folder, image_path):
    """Return a user message whose single part is the inline-encoded image at *image_path*."""
    return {
        "role": "user",
        "parts": [
            {
                "type": "IMAGE_INLINE",
                "inlineImage": encode_image(image_folder, image_path),
            }
        ],
    }


def message_few_shots(
    few_shot_folder, prompt, folder_prediction, file_prediction, few_shot_labels=None
):
    """
    Build the chat messages for few-shot image classification with a multimodal LLM.

    The conversation is made of:
      1. a system message carrying ``prompt``;
      2. for each few-shot example (sorted by path), a user message with the
         example image followed by an assistant message with its label;
      3. a final user message with the image to classify.

    Args:
        few_shot_folder: Dataiku folder containing the few-shot example images
            and the "/label_fs.json" label file (which is not sent as an example).
        prompt: Text describing the classification task; sent as the system message.
        folder_prediction: Dataiku folder containing the image to classify.
        file_prediction: Path, within ``folder_prediction``, of the image to classify.
        few_shot_labels: Optional mapping from a few-shot path without its leading
            "/" to that example's label. Defaults to the module-level
            ``classes_few_shot`` loaded from "/label_fs.json".

    Returns:
        list[dict]: Messages, each a dict with "role" and "parts" keys, ready to
        be assigned to a completion query's "messages".

    Raises:
        KeyError: if a few-shot file has no entry in the label mapping.
    """
    if few_shot_labels is None:
        few_shot_labels = classes_few_shot

    # Every file in the few-shot folder except the label file is an example image.
    # Sorting makes the example order in the prompt independent of listing order.
    few_shot_paths = sorted(
        path
        for path in few_shot_folder.list_paths_in_partition()
        if path != "/label_fs.json"
    )

    # Task description first, as a system message built from the caller's prompt.
    messages = [{"role": "system", "parts": [{"type": "TEXT", "text": prompt}]}]

    for path in few_shot_paths:
        # Label keys omit the leading "/" that folder paths carry.
        label = few_shot_labels[path[1:]]
        messages.append(_image_message(few_shot_folder, path))
        messages.append(
            {"role": "assistant", "parts": [{"type": "TEXT", "text": label}]}
        )

    # The image to classify goes last so the model answers for it.
    messages.append(_image_message(folder_prediction, file_prediction))
    return messages


# One row per successfully classified image: its path without the leading "/"
# and the model's answer.
data = {"File": [], "Response": []}

# Classify every ".jpg" file of the input folder.
for file in files:
    if not file.endswith(".jpg"):
        continue

    # Full few-shot conversation ending with the image to classify.
    messages = message_few_shots(few_shot_folder, PROMPT, folder, file)

    completion = llm.new_completion()
    completion.cq["messages"] = messages
    completion.settings["maxOutputTokens"] = 30  # answers are a short category name
    completion.settings["temperature"] = 0  # minimize sampling randomness

    resp = completion.execute()
    if not resp.success:
        # Make failures visible: otherwise the file silently disappears from the
        # results dataset and any downstream accuracy is computed on a subset.
        logging.warning("LLM completion failed for %s; it is excluded from results", file)
        continue

    data["File"].append(file[1:])
    # Strip surrounding whitespace/newlines so the answer can be compared
    # directly with the ground-truth label.
    data["Response"].append((resp.text or "").strip())

# Convert the response data into a DataFrame
df = pd.DataFrame(data)

# Ground-truth labels stored as a {file: label} JSON object in the input folder.
ground_truth = folder.read_json("/label.json")
ground_truth_df = pd.DataFrame(list(ground_truth.items()), columns=["File", "Label"])

# Left merge: keep every classified file; files missing from the ground truth get NaN labels.
merged_df = pd.merge(df, ground_truth_df, on="File", how="left")

# Write the merged DataFrame to a Dataiku dataset
dataiku.Dataset("results_classification_GPT_few_shots").write_with_schema(merged_df)
