# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import logging
import dataiku
import numpy as np
from transformers import AutoTokenizer, AutoModel
from project_utils import compute_embeddings, save, load, normalize

# Number of chunks passed to `compute_embeddings` in a single call
BATCH_SIZE = 16

# Dataiku folder in which "ids.npy" and "embeddings.npy" are stored
embeddings_folder = dataiku.Folder("bwli327B")

# The Hugging Face model identifier comes from the "model_name" custom variable
custom_variables = dataiku.get_custom_variables()
model_name = custom_variables["model_name"]
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
# The model is only used to compute embeddings, so switch it to evaluation mode
model.eval()

# Column holding the chunk identifier (used as index) and column holding the chunk text
id_label = "id"
text_label = "content"
chunks_dataset = dataiku.Dataset("chunks")
df = chunks_dataset.get_dataframe().set_index(id_label)

# Prepend the title of each chunk to its text.
# The whole column is assigned at once: writing through `df.iloc[i].content = ...`
# targets a temporary row Series and is not guaranteed to modify `df` itself.
df[text_label] = df["title"] + df[text_label]

# Include the word "plugin" at the beginning of the chunks corresponding to Dataiku plugins
# (this provides additional context for the retrieval of relevant passages of the documentation)
# The lines below can be removed for other collections of documents
is_plugin = df["href"].str.startswith("https://www.dataiku.com/product/plugins/", na=False)
# `np.where` works position by position, so it does not rely on the index labels being unique
df[text_label] = np.where(is_plugin, "Plugin: " + df[text_label], df[text_label])

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import re

# You can remove the function below if your text doesn't include code
# Matches dotted identifiers such as module.Class.method (word characters joined by dots)
pattern = re.compile(r'\b\w[\.\w]*[\.][\.\w]*\w\b')
def expand_strings_with_dots(s):
    """
    Append to every dotted token its dot-free version.

    Each substring like xxx.yyy.zzz becomes xxx.yyy.zzz xxx yyy zzz, so that
    the individual components of class and method names are also visible
    to the semantic search. Text without such tokens is returned unchanged.
    """
    def _with_components(match):
        dotted = match.group()
        components = dotted.replace(".", " ")
        return f"{dotted} {components}"

    return pattern.sub(_with_components, s)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Compute embeddings
#
# Incremental update of the embeddings stored in `embeddings_folder`:
#   1. if "ids.npy" and "embeddings.npy" already exist, drop the stored rows whose id
#      is no longer in `df` and save the pruned arrays;
#   2. compute embeddings only for the chunks of `df` that have no stored embedding;
#   3. append the new rows after the stored ones and save the result.
list_paths = embeddings_folder.list_paths_in_partition()
has_previous = "/ids.npy" in list_paths and "/embeddings.npy" in list_paths

if has_previous:
    ids = load(embeddings_folder, "ids.npy")
    emb = load(embeddings_folder, "embeddings.npy")
    # Positions of the stored embeddings whose id is still present in the dataset
    still_valid = [pos for pos, chunk_id in enumerate(ids) if chunk_id in df.index]
    logging.info(f"{len(still_valid)} embeddings kept")
    logging.info(f"{len(ids) - len(still_valid)} embeddings discarded")
    ids, emb = ids[still_valid], emb[still_valid, :]
    save(embeddings_folder, "embeddings.npy", emb)
    save(embeddings_folder, "ids.npy", ids)
    # Only the chunks without a stored embedding remain to be processed
    df = df[~df.index.isin(ids)]
else:
    logging.info("No existing embeddings")

num_new = len(df)
if num_new > 0:
    # Embed an empty string once to find out the dimension of the embeddings
    dim_embeddings = int(compute_embeddings(model, tokenizer, [""]).shape[1])
    emb = np.empty((num_new, dim_embeddings), dtype=np.float32)
    # Extract the text column once instead of building a row Series per chunk
    texts = df[text_label].tolist()
    for start in range(0, num_new, BATCH_SIZE):
        # Progress message every 100 batches
        if start % (100 * BATCH_SIZE) == 0:
            logging.info(f"Embedding computation: step {start}")
        end = min(start + BATCH_SIZE, num_new)
        emb[start:end, :] = compute_embeddings(
            model, tokenizer, [expand_strings_with_dots(text) for text in texts[start:end]]
        )

    ids = np.array(df.index)
    if has_previous:
        # Reload the pruned arrays saved above and put the new rows after them
        previous_ids = load(embeddings_folder, "ids.npy")
        ids = np.concatenate((previous_ids, ids))
        previous_emb = load(embeddings_folder, "embeddings.npy")
        emb = np.concatenate((previous_emb, emb), axis=0)
    logging.info(f"{num_new} embeddings computed")
    # NOTE(review): `normalize` is applied to the previously stored rows as well as
    # the new ones; this presumably relies on it being idempotent (e.g. a row-wise
    # L2 normalization) - confirm against project_utils.normalize
    save(embeddings_folder, "embeddings.npy", normalize(emb))
    save(embeddings_folder, "ids.npy", ids)
    logging.info(f"{len(ids)} embeddings in total")
else:
    logging.info("No additional embedding computed")