# -*- coding: utf-8 -*-
import dataiku
import numpy as np
from transformers import AutoProcessor, AutoModel
from project_utils import load_image, normalize, compute_image_embeddings, compute_text_embeddings

# Inputs: the "figures" table and the managed folder the images are read from.
df = dataiku.Dataset("figures").get_dataframe()
folder = dataiku.Folder("vOjkXoGz")

# Single source of truth for the SigLIP checkpoint used by both model and processor.
# NOTE(review): the model is fetched without local_files_only=True while the
# processor requires a local cache — confirm this asymmetry is intentional.
SIGLIP_CHECKPOINT = "google/siglip-so400m-patch14-384"
model = AutoModel.from_pretrained(SIGLIP_CHECKPOINT)
processor = AutoProcessor.from_pretrained(SIGLIP_CHECKPOINT, local_files_only=True)

# Candidate image-type labels; each image is assigned the best-matching one.
queries = [
    "chart",
    "map",
    "diagram",
    "photograph",
    "illustration",
]

# Text embeddings for the labels, normalized via project_utils.normalize so they
# can be compared with normalized image embeddings by dot product.
raw_query_embeddings = compute_text_embeddings(model, processor, queries)
queries_embeddings = normalize(raw_query_embeddings)

# Label every figure with the query whose embedding scores the highest dot
# product against the image embedding (both normalized via project_utils).
# Labels are collected first and assigned as one column, so "image_type"
# exists even when `df` is empty; the original per-cell `df.at` enlargement
# never created the column in that case, and the filter then raised KeyError.
# Iterating the "index" column directly also avoids `df.at` lookups that break
# when the DataFrame index has duplicate labels.
image_types = []
for image_ref in df["index"]:
    image = load_image(folder, image_ref)
    image_embeddings = normalize(compute_image_embeddings(model, processor, [image]))
    dot_product = image_embeddings @ queries_embeddings.T
    # A single image was embedded, so take row 0 of the per-row argmax over queries.
    best_query = int(np.argmax(dot_product, axis=1)[0])
    image_types.append(queries[best_query])
df["image_type"] = image_types

# Drop figures classified as 'illustration' or 'photograph', then write the rest
# (including the new "image_type" column) to the output dataset.
EXCLUDED_IMAGE_TYPES = ["illustration", "photograph"]
keep_mask = ~df["image_type"].isin(EXCLUDED_IMAGE_TYPES)
df = df[keep_mask]

output_dataset = dataiku.Dataset("figures_annotated")
output_dataset.write_with_schema(df)
