# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Imports, constants and inputs

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
from llama_index.core import Settings
from llama_index.core.schema import TextNode
from llama_index.core import VectorStoreIndex
from dataiku.langchain.dku_llm import DKUChatLLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import IndexNode
from dataiku.langchain.dku_llm import DKUChatLLM
from llama_index.llms.langchain import LangChainLLM
import os
import faiss
from llama_index.core import StorageContext
import tempfile
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.embeddings.langchain import LangchainEmbedding
from dataiku.langchain.dku_embeddings import DKUEmbeddings

# Embedding model identifier, read from the "embed_id" custom variable.
custom_variables = dataiku.get_custom_variables()
EMBED_ID = custom_variables["embed_id"]

# Input: the "chunks" dataset, loaded as a DataFrame.
chunks_dataset = dataiku.Dataset("chunks")
df = chunks_dataset.get_dataframe()

# Output: the folder that the persisted index files are uploaded to.
folder = dataiku.Folder("eK0eluEH")

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Set-up

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# ## Embedding model

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Wrap the Dataiku embedding connection so that llama-index can use it as its
# embedding model (texts are embedded in batches of 10).
dku_embeddings = DKUEmbeddings(llm_id=EMBED_ID)
Settings.embed_model = LangchainEmbedding(dku_embeddings, embed_batch_size=10)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Create parent nodes

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Build one parent node per row of the chunks DataFrame. The node id embeds the
# row label ("parent_<label>") so that child nodes can point back to it.
parent_nodes = []
for row_label in df.index:
    # Metadata values are stored as strings, whatever the column dtype is.
    row_metadata = {
        column: str(df.at[row_label, column]) for column in ["chunk_id", "url"]
    }
    parent_nodes.append(
        TextNode(
            text=df.at[row_label, "chunk"],
            id_=f"parent_{row_label}",
            metadata=row_metadata,
        )
    )

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Create child nodes

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Splitter used to cut every parent chunk into smaller, overlapping sub-chunks.
sub_chunk_size = 256
sub_chunk_overlap = 30
sub_node_parser = SentenceSplitter(
    chunk_size=sub_chunk_size, chunk_overlap=sub_chunk_overlap
)

# For each parent: first its children (in split order), then the parent itself.
# Every entry is an IndexNode whose index_id is the parent's node id.
all_nodes = []
for parent in parent_nodes:
    parent_id = parent.node_id
    fragments = sub_node_parser.get_nodes_from_documents([parent])
    for position, fragment in enumerate(fragments):
        child = IndexNode.from_text_node(node=fragment, index_id=parent_id)
        # Replace the generated id with a readable one that names the parent.
        child.id_ = "child_{}_from_{}".format(position, parent_id)
        all_nodes.append(child)
    # The parent chunk is indexed as well, referencing its own id.
    all_nodes.append(IndexNode.from_text_node(parent, parent_id))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Save as Faiss index

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Size the FAISS index from the embedding model itself rather than hard-coding
# 1536: FAISS rejects vectors whose length differs from the index dimension.
embedding_dim = len(Settings.embed_model.get_text_embedding("dimension probe"))
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(embedding_dim))
storage_context = StorageContext.from_defaults(vector_store=vector_store)

# The storage context has to be passed explicitly: without it the index is built
# on llama-index's default vector store and the FAISS store above is never used,
# so nothing FAISS-backed would be persisted afterwards.
# NOTE(review): code that reloads the persisted index must load the vector store
# as a FAISS store — confirm against the consumer of the output folder.
vector_index = VectorStoreIndex(
    all_nodes,
    storage_context=storage_context,
    embed_model=Settings.embed_model,
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# NOTE(review): folder_name is not referenced by the upload loop below; files
# are uploaded under their bare file names — confirm whether a prefix was meant.
folder_name = "index/"

# Persist the index to a throwaway local directory, then copy every file it
# produced into the output folder, keeping the file names unchanged.
with tempfile.TemporaryDirectory() as staging_dir:
    vector_index.storage_context.persist(persist_dir=staging_dir)

    for entry_name in os.listdir(staging_dir):
        with open(os.path.join(staging_dir, entry_name), "rb") as stream:
            folder.upload_stream(entry_name, stream)
