# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Imports, constants and inputs

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import dataiku
import json
import os
import tempfile
import shutil
from llama_index.core import (
    load_index_from_storage,
    StorageContext,
    PromptTemplate,
)
from llama_index.core.retrievers import RecursiveRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from dataiku.langchain.dku_llm import DKUChatLLM
from dataiku.langchain.dku_embeddings import DKUEmbeddings
from llama_index.llms.langchain import LangChainLLM
from llama_index.embeddings.langchain import LangchainEmbedding

# Model identifiers are read from the Dataiku custom variables.
LLM_ID = dataiku.get_custom_variables()["LLM_id"]
EMBED_ID = dataiku.get_custom_variables()["embed_id"]

# Inputs: the questions to answer and the managed folder that is later
# downloaded and used as the index persist directory.
questions_dataset = dataiku.Dataset("questions")
df = questions_dataset.get_dataframe()
folder = dataiku.Folder("eK0eluEH")

# Wrap the Dataiku chat model (temperature 0) so llama-index can drive it.
dku_chat_model = DKUChatLLM(llm_id=LLM_ID, temperature=0)
llm = LangChainLLM(llm=dku_chat_model)

# Wrap the Dataiku embedding model the same way.
dku_embeddings = DKUEmbeddings(llm_id=EMBED_ID)
embed_model = LangchainEmbedding(dku_embeddings)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Set-up

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Load the retriever

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Mirror the managed folder into a throwaway local directory and load the
# persisted index from it. `loaded_index` is used by later cells after the
# temporary directory has been removed.
with tempfile.TemporaryDirectory() as temp_dir_name:

    for filename in folder.list_paths_in_partition():

        # Folder paths are expected to start with "/"; strip it so that
        # os.path.join keeps the file under temp_dir_name instead of treating
        # the name as an absolute path.
        local_file_path = os.path.join(temp_dir_name, filename.lstrip("/"))
        os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

        with folder.get_download_stream(filename) as f_remote, open(
            local_file_path, "wb"
        ) as f_local:
            shutil.copyfileobj(f_remote, f_local)

        # Report the destination path (the original printed the file object).
        print("File {} copied to {}".format(filename, local_file_path))

    # rebuild storage context from the downloaded files
    storage_context = StorageContext.from_defaults(persist_dir=temp_dir_name)

    # load index while the temporary directory still exists
    loaded_index = load_index_from_storage(storage_context, embed_model=embed_model)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# All nodes held by the index docstore; handed to the recursive retriever
# as its node_dict.
all_nodes_dict = loaded_index.docstore.docs

# Base vector retriever returning the 10 most similar nodes.
loaded_retriever = loaded_index.as_retriever(similarity_top_k=10)

# Wrap the vector retriever in a recursive retriever rooted at "vector".
retrievers_by_id = {"vector": loaded_retriever}
loaded_rec_retriever = RecursiveRetriever(
    root_id="vector",
    retriever_dict=retrievers_by_id,
    node_dict=all_nodes_dict,
    verbose=False,
)

# Query engine combining the recursive retriever with the LLM.
loaded_query_engine = RetrieverQueryEngine.from_args(
    retriever=loaded_rec_retriever,
    llm=llm,
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Answer the questions

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# ## Adjust the prompt

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Question-answering prompt: instructions, retrieved context, then the question.
qa_template_text = "\n".join(
    [
        "Use the following pieces of context to answer the question at the end.",
        "If you don't know the answer, just say that you don't know, don't try to make up an answer.",
        "{context_str}",
        "Question: {query_str}",
        "Helpful Answer:",
    ]
)
qa_temp = PromptTemplate(qa_template_text)

# Replace the response synthesizer's text QA template with the custom prompt.
prompt_overrides = {"response_synthesizer:text_qa_template": qa_temp}
loaded_query_engine.update_prompts(prompt_overrides)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# ## Generate answers

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Query the engine once per question and collect, for each row:
#   - the generated answer,
#   - the retrieved chunks with their metadata (JSON list of objects),
#   - the retrieved chunk texts only (JSON list of strings).
# Results are gathered in lists and assigned as whole columns afterwards, so
# the three output columns exist even when there are no questions and the
# assignment does not depend on the index labels being unique.
generated_answers = []
contexts_with_metadata = []
contexts = []

for question in df["question"]:
    result = loaded_query_engine.query(question)
    source_nodes = result.source_nodes

    generated_answers.append(result.response)
    contexts_with_metadata.append(
        json.dumps(
            [
                {
                    "content": x.text,
                    "chunk_id": str(x.metadata["chunk_id"]),
                    "url": x.metadata["url"],
                }
                for x in source_nodes
            ]
        )
    )
    contexts.append(json.dumps([x.text for x in source_nodes]))

# Same column names and creation order as before.
df["generated_answer"] = generated_answers
df["context_with_metadata"] = contexts_with_metadata
df["context"] = contexts

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write the questions dataframe, now carrying the answer and context columns,
# to the output dataset along with its schema.
answers_dataset = dataiku.Dataset("answers_parent_children")
answers_dataset.write_with_schema(df)
