# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import os
import io
from PIL import Image
import numpy as np
from skimage.transform import resize
from project_utils import (
    load_image,
    save_image
)

# Managed-folder handles, in order: the source images, their segmentation
# masks, and the destination for the blended results.
folder, mask_folder, output_folder = (
    dataiku.Folder(folder_id)
    for folder_id in ("Tmf77vDr", "1HRv3kle", "btc9F5cm")
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Blend every PNG segmentation mask over its source image and store the
# result as a JPEG in the output folder.
for path in mask_folder.list_paths_in_partition():
    if not path.endswith(".png"):
        continue

    # The source image is looked up under the mask path minus its ".png"
    # suffix; the same stem is reused for the output name.
    image_path = path[:-4]
    image = load_image(folder, image_path)
    # NOTE: the reshape below expects a 4-channel (RGBA) mask.
    mask_arr = np.array(load_image(mask_folder, path))

    # List the unique colors, dropping all-zero pixels (treated as background).
    colors = np.unique(mask_arr.reshape((-1, 4)), axis=0)
    colors = colors[np.sum(colors, axis=1) > 0, :]

    # A mask without any labelled pixel has nothing to overlay. Guarding here
    # avoids np.max() raising on an empty array and aborting the whole run;
    # the image is simply saved unchanged.
    if colors.shape[0] > 0:
        # One-hot encode the mask: one (height, width) plane per color.
        preds = np.stack(
            [(mask_arr == color).all(axis=-1) for color in colors]
        ).astype(int)

        # Resize every plane to the image size; PIL reports (width, height),
        # hence the reversal.
        preds = resize(preds, (preds.shape[0], *image.size[::-1]))
        # resize() converts integer input to float following the
        # img_as_float conventions (preserve_range is not set), which
        # rescales the values. Dividing by the maximum brings the peak back
        # to 1 so the 0.5 threshold below is meaningful.
        preds = preds / np.max(preds)

        # Per pixel, keep the strongest class, or -1 when no class reaches
        # the 0.5 threshold.
        confidence = np.max(preds, axis=0)
        best_class = np.argmax(preds, axis=0)
        segmentation_map = np.where(confidence < 0.5, -1, best_class)

        # Paint each labelled pixel with its class color; unlabelled pixels
        # stay all-zero (fully transparent).
        rgba_array = np.zeros((*segmentation_map.shape, 4), dtype=np.uint8)
        labelled = segmentation_map >= 0
        rgba_array[labelled] = colors[segmentation_map[labelled]]

        # Add the mask on top of the image, using the mask itself as the
        # paste mask.
        mask = Image.fromarray(rgba_array)
        image.paste(mask, (0, 0), mask)

    # Save the (possibly modified) image.
    save_image(output_folder, image_path + ".jpg", image)