# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import json
from PIL import Image
import numpy as np
import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from transformers.image_transforms import center_to_corners_format
from project_utils import save_array, load_image

# Training annotations. The loop further down reads the `record_id` and
# `label` columns of every row of this dataframe.
df = dataiku.Dataset("train").get_dataframe()
# Managed folder handed to `load_image` when an image is opened.
folder = dataiku.Folder("aKlMxTsk")
# Managed folder that receives one saved embedding array per label.
output_folder = dataiku.Folder("U68shXET")

# The checkpoint identifier is read from the Dataiku custom variables so it
# can be changed without editing this code.
model_name = dataiku.get_custom_variables()["owlvit_model_name"]
# NOTE(review): `torch_dtype` is usually a model-loading argument; here it is
# given to the processor while the model below is loaded without it. Confirm
# that this has the intended effect (it may simply be ignored).
processor = OwlViTProcessor.from_pretrained(
    model_name,
    torch_dtype=torch.float16
)
model = OwlViTForObjectDetection.from_pretrained(model_name)
# Put the model in evaluation mode. The result is bound to `_`, presumably so
# the notebook cell does not echo the model's repr.
_ = model.eval()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def get_iou(box1, box2):
    """
    Calculate the intersection over union (IoU) between two bounding boxes.

    Both boxes are 4-element sequences in corner format
    ``(xmin, ymin, xmax, ymax)`` with ``xmin <= xmax`` and ``ymin <= ymax``,
    expressed in the same coordinate system.

    Returns a value in ``[0, 1]``. Boxes that do not overlap yield 0, and so
    do two degenerate boxes whose union has no area (instead of dividing by
    zero).
    """
    xmin1, ymin1, xmax1, ymax1 = box1
    xmin2, ymin2, xmax2, ymax2 = box2

    # Extent of the overlap along each axis; negative means the boxes are
    # strictly separated on that axis.
    inter_w = min(xmax1, xmax2) - max(xmin1, xmin2)
    inter_h = min(ymax1, ymax2) - max(ymin1, ymin2)
    if inter_w < 0 or inter_h < 0:
        return 0.0

    area_inter = inter_w * inter_h
    area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
    area2 = (xmax2 - xmin2) * (ymax2 - ymin2)
    union = area1 + area2 - area_inter
    # Two zero-area boxes at the same location reach this point with an
    # empty union; report "no overlap" rather than computing 0 / 0.
    if union <= 0:
        return 0.0
    return area_inter / union

def normalize_box(d, w, h):
    """
    Return a copy of an annotation whose box is rescaled by the image size.

    The "bbox" entry of ``d`` is read as four numbers: the first two are the
    top-left corner and the last two are added to that corner to obtain the
    opposite corner (i.e. they act as a width and a height). The copy holds a
    corner-format tuple ``(x_min, y_min, x_max, y_max)``, with x values
    divided by ``w`` and y values divided by ``h``. All other entries are
    carried over unchanged and ``d`` itself is not modified.
    """
    left, top, box_width, box_height = d["bbox"]
    normalized = dict(d)
    normalized["bbox"] = (
        left / w,
        top / h,
        (left + box_width) / w,
        (top + box_height) / h,
    )
    return normalized

def unnormalize_box(b, w, h):
    """
    Scale a 4-element corner box back up by the image dimensions.

    The first and third values of ``b`` are multiplied by ``w``, the second
    and fourth by ``h``. The result is returned as a new list.
    """
    left, top, right, bottom = b
    scaled_pairs = ((left, w), (top, h), (right, w), (bottom, h))
    return [coord * size for coord, size in scaled_pairs]

def open_image(path, boxes_string, target_width=400):
    """
    Open and resize an image and normalize the associated bounding boxes.

    Args:
        path: location of the image, passed to ``load_image`` together with
            the module-level ``folder``.
        boxes_string: JSON-encoded list of annotation dicts; each one is
            passed through ``normalize_box`` using the original image size.
        target_width: width in pixels of the resized image (default 400).
            The height is scaled by the same factor with floor division.

    Returns:
        A ``(resized_image, boxes)`` tuple, where ``boxes`` is the list of
        normalized annotation dicts.
    """
    image = load_image(folder, path)
    w, h = image.size
    # Boxes are normalized against the ORIGINAL size, so they remain valid
    # for the resized image.
    boxes = [normalize_box(box, w, h) for box in json.loads(boxes_string)]
    # Floor division can produce 0 for extremely wide images; keep at least
    # one pixel row so the resize request stays valid.
    target_height = max(1, h * target_width // w)
    return image.resize((target_width, target_height)), boxes

def get_class_embeds(image, boxes, iou_threshold=0.65):
    """
    Compute the class embeddings corresponding to bounding boxes in an image.

    The image is run through the module-level ``processor`` and ``model`` to
    obtain one predicted box and one L2-normalized class embedding per image
    patch. For every annotated box, the predictions are ranked by IoU with
    the annotation and one prediction is chosen:

    * if at most one prediction reaches ``iou_threshold``, the prediction
      with the highest IoU is used;
    * otherwise, among the predictions reaching the threshold, the one whose
      embedding has the lowest summed dot product with the other candidates
      is used.

    Args:
        image: image accepted by ``processor``.
        boxes: iterable of dicts with a "bbox" entry (corner format; assumed
            to be in the same coordinate system as the model's predicted
            boxes - TODO confirm) and a "category" entry.
        iou_threshold: minimum IoU for a prediction to count as a candidate.

    Returns:
        dict mapping each category to an array with one embedding row per
        annotated box of that category, in the order the boxes were given.
    """
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        feature_map, _ = model.image_embedder(**inputs)
        # Keep the two spatial dimensions apart so that the flattened patch
        # count is rows * cols even if the feature map is not square.
        batch_size, patch_rows, patch_cols, hidden_dim = feature_map.shape
        image_feats = torch.reshape(
            feature_map,
            (batch_size, patch_rows * patch_cols, hidden_dim)
        )
        pred_boxes = model.box_predictor(image_feats, feature_map)
        pred_boxes = center_to_corners_format(pred_boxes)
        _, image_class_embeds = model.class_predictor(image_feats)
        # L2-normalize so that dot products below are cosine similarities.
        image_class_embeds /= 1e-6 + torch.linalg.norm(
            image_class_embeds,
            dim=-1,
            keepdim=True
        )

    # Convert the predicted boxes of the (single) image once, instead of
    # once per prediction and per annotated box.
    candidate_boxes = pred_boxes[0].numpy()

    embeds_by_category = {}
    for box in boxes:
        ious = np.array(
            [get_iou(pred, box["bbox"]) for pred in candidate_boxes]
        )
        # Prediction indices sorted by decreasing IoU.
        order = np.argsort(ious)[::-1]
        num_matches = int(np.count_nonzero(ious >= iou_threshold))

        if num_matches <= 1:
            # Zero or one candidate: fall back to the best-overlapping box.
            best = int(order[0])
        else:
            matches = order[:num_matches].tolist()
            match_embeds = image_class_embeds[0, matches, :]
            similarity = match_embeds @ match_embeds.T
            # Sum of similarities to the OTHER candidates (drop self-term).
            scores = similarity.sum(dim=1) - similarity.diagonal()
            best = matches[int(torch.argmin(scores))]

        # Slicing the batch axis keeps a leading dimension of size 1, so the
        # per-category arrays can be stacked row-wise.
        class_embeds = image_class_embeds[:, best, :].numpy()
        embeds_by_category.setdefault(box["category"], []).append(class_embeds)

    return {
        category: np.concatenate(embeds, axis=0)
        for category, embeds in embeds_by_category.items()
    }

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
query_embeds = {}

# Compute the class embeddings for all bounding boxes of all images and
# group them according to the label of the bounding boxes. The per-image
# arrays are collected in lists and concatenated once per label, which avoids
# re-copying the growing array for every image.
embeds_by_label = {}
for record_id, label in zip(df["record_id"], df["label"]):
    image, boxes = open_image(record_id, label)
    for category, embeds in get_class_embeds(image, boxes).items():
        embeds_by_label.setdefault(category, []).append(embeds)

# Average and normalize the embeddings for each label
for category, chunks in embeds_by_label.items():
    mean_embed = np.mean(np.concatenate(chunks, axis=0), axis=0)
    # The small constant guards against division by zero for a null vector.
    mean_embed /= 1e-6 + np.sqrt(np.sum(mean_embed**2))
    query_embeds[category] = mean_embed

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write one array per label into the output folder, named after the label.
for label_name, embedding in query_embeds.items():
    save_array(output_folder, f"{label_name}.npy", embedding)