import io
import numpy as np
import torch
import base64
from PIL import Image

def load_image(folder, filepath):
    """
    Read a file from a folder and open it as a PIL image.

    The file content is fully read into memory and the image is opened from
    an in-memory buffer.
    """
    with folder.get_download_stream(filepath) as stream:
        raw_bytes = stream.read()
        in_memory_file = io.BytesIO(raw_bytes)
        return Image.open(in_memory_file)

def save_image(folder, filepath, img):
    """
    Serialize an image as JPEG and upload it to a folder.

    The image is written to an in-memory buffer whose content is then
    uploaded under `filepath`.
    """
    buffer = io.BytesIO()
    try:
        img.save(buffer, format="JPEG")
        jpeg_bytes = buffer.getvalue()
        folder.upload_data(filepath, jpeg_bytes)
    finally:
        # Release the buffer whether or not encoding/upload succeeded.
        buffer.close()

def encode_image(folder, image_path):
    """
    Return the content of a file in a folder as a base 64 string.

    The raw bytes of the file are base 64 encoded and the result is decoded
    to a UTF-8 string.
    """
    with folder.get_download_stream(image_path) as stream:
        raw_bytes = stream.read()
        encoded_bytes = base64.b64encode(raw_bytes)
        return encoded_bytes.decode("utf-8")
        
def normalize(x):
    """
    Normalize a 2D array along the second dimension.

    Each row is divided by its Euclidean (L2) norm. Rows whose norm is zero
    are returned as zeros instead of being turned into NaN by a division by
    zero.

    Args:
        x: 2D array-like (a NumPy array or anything `np.asarray` accepts).

    Returns:
        A NumPy array with the same shape as `x` whose non-zero rows have
        unit L2 norm.
    """
    x = np.asarray(x)
    norms = np.sqrt(np.sum(x**2, axis=1, keepdims=True))
    # `norms` is a freshly computed array, so patching it in place never
    # touches the caller's data. A zero row has no direction: dividing it by
    # 1 keeps it at zero rather than producing NaN.
    norms[norms == 0] = 1
    return np.divide(x, norms)

def compute_image_embeddings(model, processor, list_of_images):
    """
    Embed a list of images and return the result as a NumPy array.

    The images are preprocessed with `processor` (PyTorch tensors, padding
    enabled) and passed to `model.get_image_features` without gradient
    tracking.
    """
    with torch.no_grad():
        model_inputs = processor(
            images=list_of_images, return_tensors="pt", padding=True
        )
        image_features = model.get_image_features(**model_inputs)
        return image_features.numpy()

def compute_text_embeddings(model, processor, list_of_strings):
    """
    Embed a list of strings and return the result as a NumPy array.

    The strings are tokenized with `processor` (PyTorch tensors, padded to
    the maximum length) and passed to `model.get_text_features` without
    gradient tracking.
    """
    tokenized = processor(
        text=list_of_strings,
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():
        text_features = model.get_text_features(**tokenized)
        return text_features.numpy()
    
def retrieve_chunks(retriever, question, metadata, NUM_CHUNKS):
    """
    Get NUM_CHUNKS chunks corresponding to a query with a LangChain retriever.

    Retrieved documents are walked in the retriever's order; only the first
    document for each `metadata["index"]` value is kept, and at most
    NUM_CHUNKS chunks are returned.

    Args:
        retriever: object exposing `get_relevant_documents(question)`; the
            returned documents must provide `page_content` and a `metadata`
            mapping with an "index" key.
        question: query passed to the retriever.
        metadata: container indexed by a document's "index" whose entries are
            dict-like and hold at least a "type" key. It is not modified.
        NUM_CHUNKS: maximum number of chunks to return.

    Returns:
        A list of copies of the matching metadata entries, each with the
        document text under "content" and, for entries whose "type" is not
        "text", the document index under "image_url".
    """
    documents = retriever.get_relevant_documents(question)
    chunks_with_metadata, already_seen = [], set()
    for document in documents:
        if len(chunks_with_metadata) >= NUM_CHUNKS:
            break
        index = document.metadata["index"]
        if index in already_seen:
            # A higher-ranked document already covers this entry.
            continue
        already_seen.add(index)
        # Work on a shallow copy so the query-specific "content" and
        # "image_url" keys never leak into the caller's metadata.
        chunk = metadata[index].copy()
        chunk["content"] = document.page_content
        if chunk["type"] != "text":
            chunk["image_url"] = index
        chunks_with_metadata.append(chunk)
    return chunks_with_metadata