# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import json
import numpy as np
from skimage.transform import resize
from PIL import Image, ImageDraw
import torch
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor
from project_utils import (
    load_array,
    save_array,
    load_image,
    save_image,
    compute_activations,
    get_color_map,
    convert_box
)

# Pixels whose highest class probability is below this value stay unlabelled.
THRESHOLD = 0.5

# Managed folders: source images, precomputed class embeddings, and the
# destination for the annotated images.
folder = dataiku.Folder("aKlMxTsk")
embeddings_folder = dataiku.Folder("3gKc7IXi")
output_folder = dataiku.Folder("umYWZZ9z")

# Each row references an image ("record_id") and its ground-truth boxes
# ("label", a JSON string).
test_dataset = dataiku.Dataset("test")
df = test_dataset.get_dataframe()

# The CLIPSeg checkpoint name comes from the Dataiku custom variables.
custom_variables = dataiku.get_custom_variables()
model_name = custom_variables["clipseg_model_name"]
model = CLIPSegForImageSegmentation.from_pretrained(model_name)
model.eval()  # inference only; eval() returns the same module object
processor = CLIPSegProcessor.from_pretrained(model_name)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Class embeddings used to condition the decoder, plus the matching class
# names (presumably row j of the embeddings corresponds to labels[j] — the
# segmentation loop relies on that ordering).
with load_array(embeddings_folder, "class_embeddings.npz") as npz_data:
    class_embeddings, labels = npz_data["class_embeddings"], npz_data["labels"]

# Class name -> colour, shared by the predicted masks and the ground-truth boxes.
label2color = get_color_map(labels)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def _to_class_logits(logits, num_classes):
    """Return decoder logits as a ``(num_classes, height, width)`` array.

    The decoder output may or may not keep the class axis when a single class
    is queried. A reshape restores it in either case, and it is a no-op when
    the axis is already present.
    """
    return logits.reshape(num_classes, *logits.shape[-2:])


def _segmentation_to_rgba(segmentation_map, palette):
    """Colour a class-index map into an RGBA overlay.

    Args:
        segmentation_map: 2-D integer array of class indices; -1 marks pixels
            with no confident class, which stay fully transparent (all zeros).
        palette: array with one colour row per class, indexed by class index.

    Returns:
        ``uint8`` array of shape ``(*segmentation_map.shape, 4)``.
    """
    rgba_array = np.zeros((*segmentation_map.shape, 4), dtype=np.uint8)
    foreground = segmentation_map >= 0
    # One fancy-indexing pass instead of one full-image boolean mask per class.
    rgba_array[foreground] = palette[segmentation_map[foreground]]
    return rgba_array


# Loop-invariant inputs: the conditioning tensor and the per-class colour table.
conditional_embeddings = torch.Tensor(class_embeddings)
# Row j is the colour of labels[j]; assumes get_color_map yields RGBA 4-tuples
# (they are written into a 4-channel array) — TODO confirm.
palette = np.array([label2color[label] for label in labels], dtype=np.uint8)

# Iterate the columns directly so a non-default DataFrame index cannot
# misalign rows (df.loc[i] is label-based, not positional).
for path, label_json in zip(df["record_id"], df["label"]):
    boxes = json.loads(label_json)
    image = load_image(folder, path)

    # Compute and add segmentation masks
    with torch.no_grad():
        logits = model.decoder(
            compute_activations(model, processor, image),
            conditional_embeddings,
        )[0].numpy()
    logits = _to_class_logits(logits, len(labels))

    # Upsample each class map to the image resolution (PIL size is (w, h)),
    # then apply a sigmoid to get per-class probabilities.
    width, height = image.size
    upsampled = resize(logits, (logits.shape[0], height, width))
    preds = 1 / (1 + np.exp(-upsampled))

    # Most likely class per pixel; -1 where even the best class is unconfident.
    best_scores = preds.max(axis=0)
    best_classes = preds.argmax(axis=0)
    segmentation_map = np.where(best_scores < THRESHOLD, -1, best_classes)

    mask = Image.fromarray(_segmentation_to_rgba(segmentation_map, palette))
    # The RGBA overlay is its own paste mask, so its alpha channel drives blending.
    image.paste(mask, (0, 0), mask)

    # Add ground truth bounding boxes
    draw = ImageDraw.Draw(image)
    for box in boxes:
        draw.rectangle(
            convert_box(box["bbox"]),
            outline=label2color[box["category"]],
        )

    save_image(output_folder, path, image)