# -*- coding: utf-8 -*-
import dataiku
import os
import tempfile
import pickle
from transformers import AutoProcessor, AutoModel
from langchain_text_splitters import RecursiveCharacterTextSplitter
import faiss
from project_utils import load_image, compute_image_embeddings, compute_text_embeddings, normalize

# Hugging Face Hub id of the SigLIP checkpoint. Kept in a single constant so the
# model and its processor are always loaded from the same checkpoint.
MODEL_ID = "google/siglip-so400m-patch14-384"

# Load SigLIP and its processor. Only the processor is restricted to the local
# cache (local_files_only=True); the model load is not.
# NOTE(review): confirm whether the model should also pass local_files_only=True.
model = AutoModel.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained(MODEL_ID, local_files_only=True)

# Text splitter that measures chunk length with SigLIP's own tokenizer, so
# chunk_size/chunk_overlap are counted in the tokens the text encoder sees.
# Splits happen only on spaces. chunk_size=64 presumably matches SigLIP's text
# context length — TODO confirm; overlap of 13 tokens is ~20% of a chunk.
splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    processor.tokenizer,
    chunk_size=64,
    chunk_overlap=13,
    separators=[" "],
)

# Dataiku managed folders: source images (input) and index artifacts (output).
folder = dataiku.Folder("vOjkXoGz")
output_folder = dataiku.Folder("MQNdVKza")

# DataFrame of the 'texts' dataset; its "text" column is what gets embedded.
df = dataiku.Dataset("texts").get_dataframe()

# Content lists; entry k of each list describes row k of the FAISS index of the
# same modality.
# NOTE(review): multimodal_content is never filled or saved in this script and
# has no matching index — kept only in case other code references it.
multimodal_content, image_content, text_content = [], [], []

# Width of the embedding vectors. Must equal the size of the vectors returned by
# compute_image_embeddings / compute_text_embeddings (1152 for SigLIP so400m —
# confirm if the checkpoint changes).
EMBEDDING_DIM = 1152

# Exact (brute-force) L2 FAISS indexes, one per modality.
index_image = faiss.IndexFlatL2(EMBEDDING_DIM)
index_text = faiss.IndexFlatL2(EMBEDDING_DIM)

# Pairs each modality's index with its aligned content list; this mapping
# drives what gets saved to the output folder.
index_dict = {
    "image": {
        "index": index_image,
        "list": image_content,
    },
    "text": {
        "index": index_text,
        "list": text_content,
    },
}

# Embed every file of the input folder as an image. Each file contributes one
# vector to index_image and its path to image_content; both are appended
# together so FAISS row k maps back to image_content[k].
for path in folder.list_paths_in_partition():
    image = load_image(folder, path)
    vector = normalize(compute_image_embeddings(model, processor, image))
    index_image.add(vector)
    image_content.append(path)

# Chunk and embed every usable text in the dataset. Each chunk adds one row to
# index_text and one entry to text_content, keeping the two aligned.
# Iterating the column directly (instead of df.index + df.at) yields one value
# per row even when the DataFrame index has duplicate labels.
for text in df["text"]:
    # Skip missing values and very short texts. isinstance(..., str) rejects
    # NaN, None and pd.NA alike; a `text == text` test lets None through (which
    # then fails in len()) and raises on pd.NA when truth-tested.
    if not isinstance(text, str) or len(text) <= 10:
        continue
    for part in splitter.split_text(text):
        embedding = normalize(compute_text_embeddings(model, processor, part))
        index_text.add(embedding)
        # The FULL source text (not just `part`) is stored for each chunk, so a
        # hit on any chunk maps back to the whole document.
        # NOTE(review): confirm this is intended rather than storing `part`.
        text_content.append(text)

# Persist every modality to the output folder as two files:
#   "<modal>_index.bin" - the FAISS index, serialized with faiss.write_index
#   "<modal>"           - the pickled content list aligned with that index
# Each pair is staged in its own scratch directory, deleted after upload.
for modal, value in index_dict.items():
    faiss_index, contents = value["index"], value["list"]
    with tempfile.TemporaryDirectory() as tmp_dir:
        index_name = f"{modal}_index.bin"
        index_path = os.path.join(tmp_dir, index_name)
        faiss.write_index(faiss_index, index_path)
        output_folder.upload_file(index_name, index_path)

        list_name = f"{modal}"
        list_path = os.path.join(tmp_dir, list_name)
        with open(list_path, "wb") as handle:
            pickle.dump(contents, handle)
        output_folder.upload_file(list_name, list_path)
