# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import os
import tempfile
import time
from datetime import datetime
import json

import dataiku
from dataiku.langchain.dku_llm import DKUChatLLM
from dataiku.langchain.dku_embeddings import DKUEmbeddings

from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnableParallel
from langchain.prompts.prompt import PromptTemplate

NUM_CHUNKS = 5  # Number of chunks the retriever should return for each question

# Managed folder whose files are downloaded and loaded as the FAISS vector
# store in the next cell.
folder = dataiku.Folder("4pbb7xKD")
# Questions to answer; the answering loop reads this dataframe's "question" column.
df = dataiku.Dataset("questions").get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# LLM to generate the answer.
# Both model ids come from the project's custom variables; read them once.
custom_variables = dataiku.get_custom_variables()
LLM_ID = custom_variables["LLM_id"]
llm = DKUChatLLM(
    llm_id=LLM_ID,
    temperature=0
)

# Embedding model and vector store to retrieve relevant chunks
EMBEDDING_MODEL_ID = custom_variables["embedding_model_id"]
EXPERIMENT_NAME = "question_answering_code"
embeddings = DKUEmbeddings(llm_id=EMBEDDING_MODEL_ID)
with tempfile.TemporaryDirectory() as temp_dir:
    # Copy every file of the managed folder into one flat local directory
    # (basenames only) so FAISS.load_local can read the index from disk.
    for path in folder.list_paths_in_partition():
        local_path = os.path.join(temp_dir, os.path.basename(path))
        with folder.get_download_stream(path) as stream, open(local_path, "wb") as out:
            out.write(stream.read())
    # SECURITY: allow_dangerous_deserialization lets load_local unpickle the
    # stored docstore. Only acceptable while the folder contents are trusted
    # (presumably written by this project's own indexing step — confirm).
    #
    # as_retriever() has no `num_chunks` option: the number of returned
    # documents is controlled by search_kwargs["k"]. Passing num_chunks left
    # the library default (k=4) in effect instead of NUM_CHUNKS.
    #
    # NOTE: the retriever is used after temp_dir is deleted, which relies on
    # load_local having read the index fully into memory.
    retriever = FAISS.load_local(
        temp_dir,
        embeddings,
        allow_dangerous_deserialization=True,
    ).as_retriever(search_kwargs={"k": NUM_CHUNKS})

# Prompt
# The template has two slots: {context} receives the retrieved chunks rendered
# as text by format_docs, and {question} receives the raw question string.

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="""You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
Question: {question}
Context: {context}
Answer:"""
)

# Retrieval-augmented generation chain
def format_docs(docs):
    """Render retrieved documents as one prompt-ready string.

    Each document's ``page_content`` is kept in order, and consecutive
    documents are separated by a blank line.
    """
    texts = [document.page_content for document in docs]
    return "\n\n".join(texts)

# Generation stage: takes {"context": [documents], "question": str}, replaces
# the documents with their joined text, fills the prompt, calls the LLM and
# parses its reply into a plain string.
_with_text_context = RunnablePassthrough.assign(
    context=lambda inputs: format_docs(inputs["context"])
)
rag_chain_from_docs = _with_text_context | prompt | llm | StrOutputParser()

# Full chain: retrieve documents for the incoming question while passing the
# question through unchanged, then attach the generated "answer". The result
# holds the keys "context", "question" and "answer".
_retrieve_and_pass_question = RunnableParallel(
    context=retriever,
    question=RunnablePassthrough(),
)
rag_chain_with_source = _retrieve_and_pass_question.assign(answer=rag_chain_from_docs)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Answer every question with the RAG chain and keep the retrieved sources next
# to each answer so they can be inspected later.
answers = []
contexts = []
for question in df["question"]:
    response = rag_chain_with_source.invoke(question)
    answers.append(response["answer"])
    # Not every document is guaranteed to carry a "source" metadata key; fall
    # back to a placeholder instead of aborting the whole run with a KeyError.
    contexts.append(
        "\n\n".join(
            f"- {doc.metadata.get('source', 'unknown source')}: {doc.page_content}"
            for doc in response["context"]
        )
    )
# Assign whole columns at once. Positional assignment works for any index
# (including duplicate labels), and it always creates the columns, even for an
# empty dataframe, so the output schema stays stable.
df["answer"] = answers
df["context"] = contexts

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Persist the questions together with the generated "answer" and retrieved
# "context" columns.
# NOTE(review): output dataset "answers2" looks like a versioned copy — confirm
# it is the intended target.
dataiku.Dataset("answers2").write_with_schema(df)