import io
import colorsys
import numpy as np
import torch
from PIL import Image

def load_array(folder, path, allow_pickle=True):
    """
    Load a numpy array (or ``.npz`` archive) from a Dataiku folder.

    Args:
        folder: object exposing ``get_download_stream(path)``, which must
            return a context manager whose value has a ``read()`` method.
        path: location of the file inside the folder.
        allow_pickle: forwarded to ``np.load``. Defaults to ``True`` to stay
            compatible with existing callers; pass ``False`` when the file
            does not come from a trusted source.

    Returns:
        Whatever ``np.load`` returns for the downloaded bytes.
    """
    with folder.get_download_stream(path) as stream:
        payload = stream.read()
    # SECURITY: with allow_pickle=True, np.load may unpickle object arrays,
    # which can execute arbitrary code. Only keep the default for files
    # written by trusted code.
    return np.load(io.BytesIO(payload), allow_pickle=allow_pickle)

def save_array(folder, path, *args, **kwargs):
    """
    Serialize one or several numpy arrays and upload them to a Dataiku folder.

    A single positional array is written with ``np.save``. Several positional
    arrays, or any keyword arrays, are written together with ``np.savez``.
    The resulting bytes are sent through ``folder.upload_data(path, ...)``.
    """
    buffer = io.BytesIO()
    with buffer:
        # Keyword arrays or more than one array require the archive writer.
        needs_archive = bool(kwargs) or len(args) > 1
        writer = np.savez if needs_archive else np.save
        # kwargs is necessarily empty whenever np.save is selected.
        writer(buffer, *args, **kwargs)
        folder.upload_data(path, buffer.getvalue())

def load_image(folder, path):
    """
    Download a file from a Dataiku folder and open it as an image.

    The file is read fully into memory through
    ``folder.get_download_stream(path)`` and the in-memory copy is handed to
    ``Image.open``, whose result is returned.
    """
    with folder.get_download_stream(path) as stream:
        raw_bytes = stream.read()
    return Image.open(io.BytesIO(raw_bytes))

def save_image(folder, path, image, image_format="JPEG"):
    """
    Encode an image and upload it to a Dataiku folder.

    ``image.save`` writes the image into an in-memory buffer using
    ``image_format``; the encoded bytes are then sent through
    ``folder.upload_data(path, ...)``.
    """
    encoded = io.BytesIO()
    image.save(encoded, image_format)
    payload = encoded.getvalue()
    folder.upload_data(path, payload)

def convert_box(box):
    """
    Turn a (x1, y1, width, height) bounding box into a [x1, y1, x2, y2] list.

    The second corner is obtained by adding the width and height to the
    first corner's coordinates.
    """
    left, top, width, height = box
    right = left + width
    bottom = top + height
    return [left, top, right, bottom]
    
def compute_activations(model, processor, image):
    """
    Run the CLIPSeg vision model on an image and collect selected activations.

    The image is preprocessed with ``processor`` and passed to
    ``model.clip.vision_model`` with gradient tracking disabled. One entry is
    returned per index listed in ``model.extract_layers``.
    """
    with torch.no_grad():
        inputs = processor(images=image, padding=True, return_tensors="pt")
        vision_outputs = model.clip.vision_model(**inputs, output_hidden_states=True)
        # Element 2 of the output is treated as the per-layer hidden states
        # (presumably available because output_hidden_states=True -- confirm
        # against the model's output class).
        hidden_states = vision_outputs[2]
        # Indices are shifted by one; presumably entry 0 is the embedding
        # output rather than a layer output -- TODO confirm.
        return [hidden_states[layer + 1] for layer in model.extract_layers]

def get_color_map(labels, saturation=0.7, value=0.7, alpha=0.8):
    """
    Get a category-to-color mapping for several categories.

    Each distinct label gets its own hue: the k-th of n distinct labels (in
    order of first appearance) uses hue ``k / n`` with the given saturation
    and value.

    Args:
        labels: iterable of hashable labels; duplicates are ignored.
        saturation: HSV saturation shared by all colors.
        value: HSV value shared by all colors.
        alpha: opacity, scaled by 255 like the color channels.

    Returns:
        A dict mapping each distinct label to an ``(r, g, b, a)`` tuple of
        ints obtained by truncating ``255 * channel``. Empty when ``labels``
        is empty.
    """
    # dict.fromkeys de-duplicates in linear time and keeps first-seen order.
    unique_labels = list(dict.fromkeys(labels))
    n = len(unique_labels)
    color_map = {}
    for k, label in enumerate(unique_labels):
        rgb = colorsys.hsv_to_rgb(k / n, saturation, value)
        color_map[label] = tuple(int(255 * x) for x in (*rgb, alpha))
    return color_map