# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import dataiku
from dataiku.langchain.dku_embeddings import DKUEmbeddings
import io
import os
import logging
import collections
import tempfile
from langchain.document_loaders import PyPDFLoader
from langchain.document_loaders import UnstructuredFileLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS

# Managed folders: `folder` is read for the source documents,
# `output_folder` receives the serialized FAISS index files.
folder = dataiku.Folder("TyR7HVoz")
output_folder = dataiku.Folder("4pbb7xKD")

# The embedding model id is read from the "embedding_model_id" custom variable.
EMBEDDING_MODEL_ID = dataiku.get_custom_variables()["embedding_model_id"]
embeddings = DKUEmbeddings(llm_id=EMBEDDING_MODEL_ID)

# Chunking parameters, measured with `len` (i.e. in characters).
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 100

# Documents are split on newlines into overlapping chunks.
text_splitter = CharacterTextSplitter(
    separator="\n",
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    length_function=len,
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Paths of all source documents and the set of their base names.
paths = folder.list_paths_in_partition()
filenames = {os.path.basename(f) for f in paths}

# Base names of documents that are already indexed and still in the source folder.
to_keep = set()


def _source_to_filename(source):
    """Return the file name embedded in a chunk's "source" metadata.

    The indexing step labels chunks as "<filename>", "<filename> (i/n)",
    "<filename>, page N" or "<filename>, page N (i/n)". Only these generated
    suffixes are stripped, so file names containing spaces or commas are
    recovered intact.
    """
    name = source
    # Strip a trailing " (i/n)" chunk counter, if any.
    if name.endswith(")"):
        head, sep, counter = name[:-1].rpartition(" (")
        position, slash, total = counter.partition("/")
        if sep and slash and position.isdigit() and total.isdigit():
            name = head
    # Strip a trailing ", page N" marker, if any.
    head, sep, page = name.rpartition(", page ")
    if sep and page.isdigit():
        name = head
    return name


# Load the index of documents (if it has already been built)
index_paths = output_folder.list_paths_in_partition()
if index_paths:
    with tempfile.TemporaryDirectory() as temp_dir:
        # Copy the index files locally so that FAISS can read them from disk.
        for f in index_paths:
            with output_folder.get_download_stream(f) as stream:
                with open(os.path.join(temp_dir, os.path.basename(f)), "wb") as f2:
                    f2.write(stream.read())
        # NOTE(review): allow_dangerous_deserialization=True means the stored
        # index is deserialized without safety checks; the output folder must
        # only contain files written by this recipe.
        index = FAISS.load_local(temp_dir, embeddings, allow_dangerous_deserialization=True)
        logging.info(f"{len(index.docstore._dict)} vectors loaded")

        to_remove = []
        for idx, doc in index.docstore._dict.items():
            source = _source_to_filename(doc.metadata["source"])
            if source in filenames:
                # Identify documents already indexed and still present in the source folder
                to_keep.add(source)
            else:
                # Identify documents removed from the source folder
                to_remove.append(idx)
        if to_remove:
            index.delete(to_remove)
        logging.info(f"{len(to_remove)} vectors removed")
else:
    # No index has been persisted yet; it is built from scratch later on.
    index = None

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Chunks to add to the index during this run.
docs = []

# Only process files whose base name is not already present in the index.
to_add = [f for f in paths if os.path.basename(f) not in to_keep]

for path in to_add:
    filename = os.path.basename(path)
    with folder.get_download_stream(path) as stream:
        # Keep the original extension on the temporary file so that loaders
        # relying on the file name can still recognize the file type.
        with tempfile.NamedTemporaryFile(suffix=os.path.splitext(filename)[1]) as temp_file:
            temp_file.write(stream.read())
            # The loaders reopen the file by name: make sure all buffered
            # bytes have reached the disk before they read it.
            temp_file.flush()
            # Select the appropriate document loader
            if path.lower().endswith(".pdf"):
                loader = PyPDFLoader(temp_file.name)
            else:
                loader = UnstructuredFileLoader(temp_file.name)
            # Load the documents and split them in chunks
            chunks = loader.load_and_split(text_splitter=text_splitter)

    if not chunks:
        # Nothing could be extracted (e.g. empty document): skip the file
        # instead of failing on chunks[0] below.
        logging.warning(f"No content extracted from {filename}, skipping")
        continue

    # Define a unique id for each chunk
    if "page" in chunks[0].metadata:
        # Total number of chunks per page, then a running position within each page.
        chunks_per_page = collections.Counter(chunk.metadata['page'] for chunk in chunks)
        seen_on_page = collections.Counter()
        for chunk in chunks:
            page = chunk.metadata['page']
            seen_on_page[page] += 1
            # Loader pages are 0-based; labels are 1-based.
            source = f"{filename}, page {page + 1}"
            if chunks_per_page[page] > 1:
                source += f" ({seen_on_page[page]}/{chunks_per_page[page]})"
            chunk.metadata['source'] = source
    elif len(chunks) == 1:
        chunks[0].metadata['source'] = filename
    else:
        for position, chunk in enumerate(chunks, start=1):
            chunk.metadata['source'] = f"{filename} ({position}/{len(chunks)})"
    docs += chunks

if docs:
    if index is not None:
        index.add_documents(documents=docs)
    else:
        index = FAISS.from_documents(docs, embeddings)
logging.info(f"{len(docs)} vectors added")

# The index is still None when there was no persisted index and nothing to add.
if index is not None:
    logging.info(f"{len(index.docstore._dict)} vectors in the index")
else:
    logging.info("0 vectors in the index")

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Persist the index: serialize it to a local temporary directory, then upload
# every produced file to the output folder under the same name.
if index is not None:
    with tempfile.TemporaryDirectory() as temp_dir:
        index.save_local(temp_dir)
        for f in os.listdir(temp_dir):
            output_folder.upload_file(f, os.path.join(temp_dir, f))
else:
    # No index exists (no persisted index and no document to add): nothing to save.
    logging.info("No index to save")