# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # 1) Packages

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import dataiku
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from dataiku.langchain.dku_llm import DKUChatLLM
from langchain.retrievers import EnsembleRetriever
import pickle
import io
import json

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # 2) Recipe Parameters

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# ## Recipe inputs

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Identifier of the LLM to query, read from the "LLM_id" custom variable.
LLM_ID = dataiku.get_custom_variables()["LLM_id"]

# Knowledge bank: look it up on the default project through the API client,
# then convert the handle to its core knowledge bank counterpart.
_project = dataiku.api_client().get_default_project()
_kb_handle = _project.get_knowledge_bank("YK6IMhfU")
kb = _kb_handle.as_core_knowledge_bank()

# Managed folder holding the pickled BM25 retriever index.
folder = dataiku.Folder("qkyfR2rB")

# Questions to answer, loaded as a dataframe.
questions_dataset = dataiku.Dataset("questions")
df = questions_dataset.get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # 3) Set up the retrieval

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# ### BM25 retriever

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# NOTE(review): unpickling can execute arbitrary code embedded in the payload;
# only load this file from a folder whose contents are trusted.
with folder.get_download_stream("/bm25result.pkl") as stream:
    # Read the full pickle payload and deserialize it in one step.
    sparse_retriever = pickle.loads(stream.read())

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# ### Dense retriever (from the knowledge bank)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Retriever backed by the knowledge bank. "k": 5 is forwarded as the search
# setting -- presumably the number of chunks returned per query; TODO confirm.
dense_search_kwargs = {"k": 5}
dense_retriever = kb.as_langchain_retriever(search_kwargs=dense_search_kwargs)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# ### Ensemble retriever

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Hybrid retrieval: combine the dense and sparse retrievers with equal weights.
hybrid_members = [dense_retriever, sparse_retriever]
hybrid_weights = [0.5, 0.5]
ensemble_retriever = EnsembleRetriever(
    retrievers=hybrid_members,
    weights=hybrid_weights,
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # 4) Build question answering chain

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# ### Prompt

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Prompt sent to the LLM: instructions first, then the retrieved context, then
# the question. The lines are joined with newlines into a single template.
template = "\n".join(
    [
        "Use the following pieces of context to answer the question at the end.",
        "If you don't know the answer, just say that you don't know, don't try to make up an answer.",
        "{context}",
        "Question: {question}",
    ]
)
qa_prompt = PromptTemplate.from_template(template)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# ### Question answering chain

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Chat model wrapper for the configured LLM, with the temperature set to 0.
llm_settings = {"llm_id": LLM_ID, "temperature": 0}
llm = DKUChatLLM(**llm_settings)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Retrieval-augmented QA chain: documents come from the ensemble retriever and
# are injected into the custom prompt. The retrieved source documents are
# returned alongside the answer so they can be stored with it.
prompt_overrides = {"prompt": qa_prompt}
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type_kwargs=prompt_overrides,
    retriever=ensemble_retriever,
    return_source_documents=True,
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # 5) Generate answers

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Create the output columns up front so that they always exist (with object
# dtype) in the dataframe, even when the questions dataset is empty. Without
# this, the columns would only appear as a side effect of the first loop
# iteration and an empty input would yield an output without them.
for output_column in ("generated_answer", "context_with_metadata", "context"):
    df[output_column] = None

# Answer each question with the QA chain and store, per row:
#   - generated_answer: the text produced by the chain;
#   - context_with_metadata: JSON list of the retrieved chunks with their
#     content, chunk id and URL;
#   - context: JSON list of the retrieved chunks' content only.
for i in df.index:
    answer = qa_chain.invoke({"query": df.at[i, "question"]})
    sources = answer["source_documents"]

    df.at[i, "generated_answer"] = answer["result"]
    # NOTE(review): assumes every retrieved document carries "chunk_id" and
    # "url" metadata (a missing key raises KeyError) -- confirm this holds for
    # documents returned by the BM25 retriever too.
    df.at[i, "context_with_metadata"] = json.dumps(
        [
            {
                "content": doc.page_content,
                "chunk_id": str(doc.metadata["chunk_id"]),
                "url": doc.metadata["url"],
            }
            for doc in sources
        ]
    )
    df.at[i, "context"] = json.dumps([doc.page_content for doc in sources])

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write the dataframe (questions plus the added answer and context columns) to
# the output dataset, setting the dataset schema from the dataframe.
answers_dataset = dataiku.Dataset("answers_hybrid")
answers_dataset.write_with_schema(df)
