# -*- coding: utf-8 -*-
import io
import dataiku
import numpy as np
import faiss

from project_utils import load, normalize

# Managed folders: the input folder holding the precomputed corpus embeddings,
# and the output folder that receives the serialized FAISS index.
embeddings = dataiku.Folder("P4SttKJS")
faiss_index = dataiku.Folder("FpWcIx1Z")

# Load the embedding matrix and normalize it (presumably row-wise L2
# normalization so that L2 distance ranks like cosine similarity -- confirm in
# project_utils.normalize).
# FAISS operates on C-contiguous float32 arrays. Convert explicitly so that a
# float64 or non-contiguous result from the helpers does not break older FAISS
# Python wrappers. This is a no-op (no copy) when the array already qualifies.
corpus_embeddings = np.ascontiguousarray(
    normalize(load(embeddings, "embeddings.npy")), dtype=np.float32
)

# Build an IVF-PQ index:
# - an inverted file with `nlist` coarse cells, assigned by a flat L2 quantizer;
# - residuals compressed by a product quantizer with `m` sub-quantizers of
#   `nbits` bits each.
d = corpus_embeddings.shape[1]
n_vectors = corpus_embeddings.shape[0]

# Heuristic: the number of coarse cells grows like 4 * sqrt(N).
nlist = 4 * int(np.sqrt(n_vectors))

# Each PQ sub-quantizer learns 2**nbits centroids via k-means, and FAISS
# k-means rejects fewer training points than centroids. Fail early with a clear
# message instead. For N >= 256, nlist = 4 * floor(sqrt(N)) <= N, so coarse
# quantizer training is satisfied too.
nbits = 8
if n_vectors < 2 ** nbits:
    raise ValueError(
        f"IVF-PQ training needs at least {2 ** nbits} vectors, got {n_vectors}"
    )

# Aim for ~16 dimensions per PQ sub-vector. IndexIVFPQ requires d % m == 0,
# so fall back to the largest divisor of d not exceeding d // 16. For example,
# d=300 gives m=15 instead of the invalid 18, and d < 16 gives m=1 instead of 0.
# When d is a multiple of 16 this is exactly d // 16.
m = max(1, d // 16)
while d % m:
    m -= 1

quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
index.train(corpus_embeddings)
index.add(corpus_embeddings)

# Serialize the trained index entirely in memory (FAISS writes it into a
# uint8 numpy buffer). Store the raw bytes as "index.index" in the output
# managed folder.
serialized_index = faiss.serialize_index(index)
faiss_index.upload_data("index.index", serialized_index.tobytes())
