# -*- coding: utf-8 -*-
import logging
from collections import defaultdict

import dataiku
import pandas as pd

# Managed folder holding the question images together with the "data.json" label file.
folder = dataiku.Folder("JRpzJC6D")

# Every path in the folder except the label file is treated as an image to query.
files = folder.list_paths_in_partition()
files.remove("/data.json")

# Label file; its "data" list provides, per entry, the image path, question and answers.
labels = folder.read_json("data.json")

# Look up the LLM whose ID is configured in the project's custom variables.
_project_variables = dataiku.get_custom_variables()
LLM_ID = _project_variables["LLM_ID"]
_project = dataiku.api_client().get_default_project()
llm = _project.get_llm(LLM_ID)

# Ask the LLM every labelled question about every image and collect the answers.
# Each record holds the image file name, the question and the cleaned answer;
# pairs for which the LLM returned no usable text are logged and skipped.
responses = []

for file in files:
    # Label entries whose image path ends with this folder path (folder paths
    # start with "/", so the suffix match is anchored on a path separator).
    matching_entries = [
        entry for entry in labels["data"] if entry["image"].endswith(file)
    ]
    if not matching_entries:
        continue

    # Download the image once and reuse the bytes for every question about it.
    with folder.get_download_stream(file) as stream:
        image_bytes = stream.read()

    for entry in matching_entries:
        question = entry["question"]
        file_name = entry["image"].split("/")[-1]

        # Short, deterministic answers: the model should return a bare value.
        completion = llm.new_completion()
        completion.settings["maxOutputTokens"] = 30
        completion.settings["temperature"] = 0

        mp_message = completion.new_multipart_message()
        mp_message.with_text(
            "Answer the question. Do not write a full sentence, just provide a value. "
            f"Question: {question}"
        )
        mp_message.with_inline_image(image_bytes)
        mp_message.add()

        resp = completion.execute()
        answer = resp.text.strip() if resp.success and resp.text else ""
        if not answer:
            logging.warning(
                "No usable LLM answer for file %s, question %r", file, question
            )
            continue

        # Strip whitespace before the trailing period so that e.g. "42.\n"
        # becomes "42" (rstrip(".") alone would leave it untouched).
        responses.append(
            {
                "File": file_name,
                "Question": question,
                "Response": answer.rstrip(".").rstrip(),
            }
        )

# Explicit columns keep the merge keys present even when no answer was collected.
df_responses = pd.DataFrame(responses, columns=["File", "Question", "Response"])

# Build the ground truth: one row per (image file name, question) pair listing each
# accepted answer exactly once, in first-seen order. Label entries that repeat the
# same image/question pair are merged into a single row.
ground_truth_entries = defaultdict(dict)  # inner dict keys act as an ordered set

for entry in labels["data"]:
    image_name = entry["image"].split("/")[-1]  # image file name without directories
    question = entry["question"]
    # dict.fromkeys de-duplicates across *all* entries for this pair (extend-ing a
    # per-entry set did not), and unlike set() it yields the same Label order on
    # every run, since str hashing is randomized per process.
    ground_truth_entries[(image_name, question)].update(
        dict.fromkeys(entry["answers"])
    )

# Flatten into records; each Label is the de-duplicated answer list.
ground_truth_data = [
    {"File": image_name, "Question": question, "Label": list(answers)}
    for (image_name, question), answers in ground_truth_entries.items()
]

# Explicit columns keep the merge keys present even when there are no labels.
df_ground_truth = pd.DataFrame(
    ground_truth_data, columns=["File", "Question", "Label"]
)

# Attach the ground-truth labels to each LLM answer. The left merge keeps every
# answer even when no label row matches (Label is then NaN); validate= asserts
# that each (File, Question) pair has at most one ground-truth row.
merged_df = pd.merge(
    df_responses,
    df_ground_truth,
    on=["File", "Question"],
    how="left",
    validate="many_to_one",
)

# TODO: ANLS scoring is disabled; `anls_score` is not defined or imported in this
# script. Once it is available, add the score column with:
# merged_df["Score_ANLS"] = merged_df.apply(
#     lambda row: anls_score(
#         prediction=row["Response"], gold_labels=row["Label"], threshold=0.5
#     ),
#     axis=1,
# )

# Write the LLM responses alongside their ground-truth labels to the output dataset.
dataiku.Dataset("results_VQA_GPT").write_with_schema(merged_df)
