# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import json
from PIL import Image
import numpy as np
import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from transformers.models.owlvit.modeling_owlvit import OwlViTObjectDetectionOutput
from transformers.image_transforms import center_to_corners_format
from project_utils import load_array, load_image

# Detection settings read by the scoring loop below: `threshold` is the minimum
# score a predicted box must reach to be kept, and `k` is the maximum number
# of boxes kept per image.
# NOTE(review): `k` is also used as a comprehension variable name further down
# in this file; under Python 3 scoping that does not rebind this global.
threshold = 0.95
k = 3

# The checkpoint name comes from the Dataiku custom variables.
model_name = dataiku.get_custom_variables()["owlvit_model_name"]
# NOTE(review): `torch_dtype` is passed to the processor, not to the model;
# presumably it has no effect on preprocessing — confirm whether it was meant
# for the model loading call instead.
processor = OwlViTProcessor.from_pretrained(
    model_name,
    torch_dtype=torch.float16
)
model = OwlViTForObjectDetection.from_pretrained(model_name)
# Put the model in evaluation mode; the return value is bound to `_`,
# presumably so that it is not echoed as notebook cell output.
_ = model.eval()

# Input dataset and the two managed folders used by the rest of the script.
df = dataiku.Dataset("test").get_dataframe()
# Folder handed to load_image() by open_image().
folder = dataiku.Folder("aKlMxTsk")
# Folder whose files are loaded as query embeddings.
embeddings_folder = dataiku.Folder("U68shXET")

def normalize_box(d, W, H):
    """
    Return a copy of annotation `d` whose "bbox" is expressed as normalized corners.

    The incoming "bbox" is (x1, y1, w, h); the returned one is the tuple
    (x1/W, y1/H, x2/W, y2/H), where W and H are the image width and height.
    Every other key is copied unchanged and `d` itself is not modified.
    """
    left, top, box_w, box_h = d["bbox"]
    corners = (left / W, top / H, (left + box_w) / W, (top + box_h) / H)
    # Re-assigning an existing key keeps its position, so key order is preserved.
    return {**d, "bbox": corners}

def unnormalize_box(b, ratio):
    """
    Convert a corner box (x1, y1, x2, y2) into [x1, y1, w, h] and scale it.

    Each of the four returned values is multiplied by `ratio`; the result is a list.
    """
    left, top, right, bottom = b
    width = right - left
    height = bottom - top
    return [left * ratio, top * ratio, width * ratio, height * ratio]

def open_image(path, boxes_string):
    """
    Load an image, normalize its annotated boxes and resize it to a width of 400.

    `boxes_string` is a JSON list of annotations, each passed through
    normalize_box() with the original image size. The resized height is
    computed with integer division so the aspect ratio is (approximately) kept.

    Returns a 4-tuple: (resized image, normalized boxes, original width,
    original height).
    """
    original = load_image(folder, path)
    orig_w, orig_h = original.size
    annotations = [
        normalize_box(annotation, orig_w, orig_h)
        for annotation in json.loads(boxes_string)
    ]
    resized = original.resize((400, orig_h * 400 // orig_w))
    return resized, annotations, orig_w, orig_h

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Build the label -> query embedding mapping from the files of the embeddings folder.
query_embeds = {}
for path in embeddings_folder.list_paths_in_partition():
    # The label is the second "/"-separated component of the path, stripped of
    # everything from its last "." onwards.
    file_name = path.split("/")[1]
    label = file_name.rpartition(".")[0]
    query_embeds[label] = load_array(embeddings_folder, path)

# Labels in the dictionary's iteration order, so that a position in this list
# matches the position of the corresponding embedding when iterating the dict.
names = list(query_embeds)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE

# Score every row of `df`: detect the boxes that best match the query
# embeddings and store them, JSON-encoded, in the "prediction" column.

# The query tensor does not depend on the image, so it is built once here
# rather than at every iteration. Going through a single NumPy array avoids
# the slow tensor construction from a list of arrays; the dtype matches what
# `torch.Tensor(...)` would produce.
query_tensor = torch.as_tensor(
    np.asarray(list(query_embeds.values())),
    dtype=torch.get_default_dtype(),
)

# Predictions are collected positionally and assigned in one go after the
# loop, so the code does not rely on `df` having a default integer index and
# the column is created even when `df` is empty.
predictions = []
for record_id, boxes_string in zip(df["record_id"], df["label"]):
    # The normalized ground-truth boxes and the original height are not needed here.
    image, _gt_boxes, w, _orig_h = open_image(record_id, boxes_string)
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        # Compute the predicted boxes of the image and the associated box embeddings
        feature_map, _ = model.image_embedder(**inputs)
        # Keep the two grid dimensions separate so the flattening is also
        # correct if the patch grid is not square.
        batch_size, grid_h, grid_w, hidden_dim = feature_map.shape
        image_feats = torch.reshape(feature_map, (batch_size, grid_h * grid_w, hidden_dim))
        pred_boxes = model.box_predictor(image_feats, feature_map)
        # Compute the logits assessing how the box embeddings align with the query embeddings
        pred_logits, _ = model.class_predictor(image_feats, query_embeds=query_tensor)

    outputs = OwlViTObjectDetectionOutput(
        pred_boxes=pred_boxes,
        logits=pred_logits,
    )
    # target_sizes expects (height, width), hence the reversed PIL size.
    results = processor.post_process(
        outputs=outputs,
        target_sizes=torch.Tensor([image.size[::-1]]),
    )
    detections = results[0]

    # Select the predicted boxes most aligned with the queries: at most `k`
    # boxes, all of them scoring at least `threshold`.
    scores = detections["scores"]
    max_results = min(k, int((scores >= threshold).sum()))
    idx = np.argsort(-scores.numpy())[:max_results]
    top_boxes = detections["boxes"][idx].tolist()
    top_labels = detections["labels"][idx].tolist()

    # Boxes are expressed in the resized image (width 400); scale them back
    # to the original image.
    # NOTE(review): the same ratio is applied vertically although the resized
    # height is floored, so heights can be off by a small rounding error.
    ratio = w / 400
    predictions.append(json.dumps([
        {
            "bbox": unnormalize_box(box, ratio),
            "category": names[label_idx],
        }
        for box, label_idx in zip(top_boxes, top_labels)
    ]))

df["prediction"] = predictions

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write the dataframe filled in by the scoring loop to the "test_scored"
# dataset, letting Dataiku derive the schema from the dataframe.
dataiku.Dataset("test_scored").write_with_schema(df)