import os
import time
import functools
import dataiku
import numpy as np
import pandas as pd
from datetime import datetime
import json

from transformers import AutoTokenizer, AutoModel, GPT2TokenizerFast

from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc

from project_utils import load, compute_embeddings, normalize

N_RESULTS = 10 # Number of top-ranked chunks considered when building the prompt (see get_answer)
N_SOURCES = 5 # Max number of sources displayed to the user
TOLERANCE_SCORE_SOURCE = 0.1 # A chunk is kept as a source only if its score is within this margin of the best score
MAX_TOKENS = 2000 # Token budget for the whole prompt (template + question + excerpts)
LOG_ALL_ANSWERS = False # If False, only answers with a positive/negative user feedback are logged
VERSION = dataiku.get_custom_variables()["version"] # Version of the app, read from the "version" custom variable (logged with the answers)

# The OpenAI API key is read from the user secrets of the current user
# (secret named "openai_key") and exposed through the OPENAI_API_KEY variable.
# Cf. https://doc.dataiku.com/dss/latest/security/user-secrets.html
auth_info = dataiku.api_client().get_auth_info(with_secrets=True)
_openai_keys = [
    secret["value"]
    for secret in auth_info["secrets"]
    if secret["key"] == "openai_key"
]
openai_key_provided = len(_openai_keys) > 0
if openai_key_provided:
    # If several secrets share the name, the last one wins
    os.environ["OPENAI_API_KEY"] = _openai_keys[-1]
ERROR_MESSAGE_MISSING_KEY = "\n\nOpenAI key missing. You need to add it as a user secret. Cf. this project's wiki"
        
# Chunks from the knowledge base, indexed by their "id" column.
# NOTE: get_answer/get_sources address `df` and `corpus_embeddings` with the
# same positional indices, so row i of the embeddings is assumed to describe
# row i of `df` -- confirm this holds for the stored "embeddings.npy".
df = dataiku.Dataset("chunks").get_dataframe().set_index("id")
embeddings = dataiku.Folder("bwli327B")
corpus_embeddings = load(embeddings, "embeddings.npy")

# Semantic similarity model (its name comes from the "model_name" custom variable)
model_name = dataiku.get_custom_variables()["model_name"]
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
# The model is only used for inference: switch it to evaluation mode
model.eval()

# GPT-2 tokenizer, used only to count tokens so that the prompt stays below MAX_TOKENS
tokenizer2 = GPT2TokenizerFast.from_pretrained("gpt2")

# Folders to log answers and positive/negative reactions
# (only `answers_folder` is written to by the code visible in this file)
answers_folder = dataiku.Folder("jI8G2N4I")
feedback_folder = dataiku.Folder("DkbTwRbt")

# Question answering prompt.
# {context} receives the retrieved excerpts and {question} the user's question.
# The raw template is also tokenized once (n_tokens_template) to compute the
# token budget left for the excerpts.
prompt_template = """Use the following extracts of the Dataiku DSS documentation to answer the question at the end.
If you don't know the answer, just say that you don't know.
---------
Dataiku DSS documentation:
{context}
---------
Question: I am a Dataiku DSS user. {question}
Answer: """
PROMPT = PromptTemplate(
    template=prompt_template,
    input_variables=["context", "question"]
)

# The chain is only created when an OpenAI key is available: `chain` is left
# undefined otherwise, and answer_question returns an error message before
# get_answer (the only user of `chain`) can be reached.
if openai_key_provided:
    chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", prompt=PROMPT)

# Layout

# Style of the block holding the answer and the feedback icons.
# It is hidden by default ("display": "none"); answer_question returns a copy
# with "display" set to "flex" once there is an answer to show.
STYLE_ANSWER = {
    "margin-top": "20px",
    "display": "none",
    "align-items": "flex-start",
    "height": "auto"
}

# Style of each of the two feedback links (smile / frown)
STYLE_BUTTON = {
    "width": "20px",
    "text-align": "center",
    "margin": "0px 5px",
}

# Style of the container of the two feedback links
STYLE_FEEDBACK = {
    "margin-left": "10px",
    "display": "flex"
}

# Feedback icons: the "fill" variants are displayed by update_icons when the
# corresponding feedback (positive or negative) is currently selected
ok_icon_fill = html.Span(html.I(className="bi bi-emoji-smile-fill"))
nok_icon_fill = html.Span(html.I(className="bi bi-emoji-frown-fill"))
ok_icon = html.Span(html.I(className="bi bi-emoji-smile"))
nok_icon = html.Span(html.I(className="bi bi-emoji-frown"))

# Input field and send button; both the button click and a submit event of the
# input field ('n_submit') trigger answer_question
send_icon = html.Span(html.I(className="bi bi-send"))
question_bar = dbc.InputGroup(
    [
        dbc.Input(id='query', value='', type='text', minLength=0),
        dbc.Button(send_icon, id='send-btn', title='Get an answer')
    ],
    style = {"margin-top": "20px"}
)

# NOTE(review): `app` is neither defined nor imported in this file; it is
# presumably injected by the hosting environment (Dataiku Dash webapp) -- confirm.
app.title = "Dataiku FAQ"
app.config.external_stylesheets = [
    dbc.themes.ZEPHYR,
    dbc.icons.BOOTSTRAP
]
app.layout = html.Div(
    [
        question_bar,
        # Answer and feedback links (hidden until an answer is available)
        html.Div([
            dbc.Spinner(
                dcc.Markdown(
                    id='answer',
                    link_target="_blank"
                ),
                color="primary"
            ),
            html.Div([
                html.A(ok_icon, id="link_ok", href="#", style=STYLE_BUTTON),
                html.A(nok_icon, id="link_nok", href="#", style=STYLE_BUTTON)
            ], style=STYLE_FEEDBACK)
        ],
            id="answer_div",
            style=STYLE_ANSWER
        ),
        # Sources of the answer (also used to display the missing key error)
        dbc.Spinner(
            dcc.Markdown(
                id='sources',
                link_target="_blank"
            ),
            color="primary"
        ),
        # 'debug' is not targeted by any callback visible in this file;
        # 'debug2' is the (always empty) output of log_answer
        html.Div(id='debug'),
        html.Div(id='debug2'),
        # Feedback state: 1 = positive, -1 = negative, 0 = feedback withdrawn,
        # 2 = no feedback given yet for the current question (logged as 0)
        dcc.Store(id='feedback', data=2, storage_type='memory'),
        # Text of the last submitted question (set by answer_question)
        dcc.Store(id='question', storage_type='memory'),
        # Identifier of the last answer, used as the name of its log file
        dcc.Store(id='question_id', storage_type='memory'),
    ],
    style={
        "margin": "auto",
        "text-align": "left",
        "max-width": "800px"
    }
)

# Callbacks

def get_number_tokens(s):
    """
    Return how many tokens the GPT-2 tokenizer produces for the string `s`
    """
    encoding = tokenizer2(s)
    return len(encoding["input_ids"])

# Number of tokens taken by the prompt template itself (computed once)
n_tokens_template = get_number_tokens(prompt_template)

@functools.lru_cache
def get_answer(query):
    """
    Get the answer to the user's question and the passages provided to GPT-3

    The chunks are ranked by similarity with the question. The best
    N_RESULTS chunks (or fewer if the knowledge base is smaller) are added to
    the prompt in that order, skipping any chunk that would not fit in the
    remaining token budget.

    Results are memoized per question string by `functools.lru_cache`.

    Args:
        query: question asked by the user.

    Returns:
        A tuple (answer, selected) where `answer` is the text generated by the
        chain and `selected` is the list of positional indices (in `df` and
        `corpus_embeddings`) of the chunks that were put in the prompt.
    """
    query_embedding = normalize(compute_embeddings(model, tokenizer, [query]))
    # Dot products rescaled from [-1, 1] to [0, 1]
    # (assumes both sides are unit-length vectors -- TODO confirm for corpus_embeddings)
    scores = (1 + (query_embedding @ np.transpose(corpus_embeddings))[0, :]) / 2
    order = np.argsort(-scores)
    remaining_tokens = MAX_TOKENS - n_tokens_template - get_number_tokens(query)
    excerpts, selected = [], []
    # Slicing (rather than indexing up to N_RESULTS) avoids an IndexError when
    # the knowledge base contains fewer than N_RESULTS chunks
    for index in order[:N_RESULTS]:
        excerpt = df.iloc[index].content
        n_tokens_excerpt = get_number_tokens(excerpt)
        # A chunk that is too long is skipped, but shorter chunks ranked after
        # it can still be added
        if n_tokens_excerpt < remaining_tokens:
            excerpts.append(excerpt)
            remaining_tokens -= n_tokens_excerpt
            selected.append(index)
    docs = [Document(page_content=x) for x in excerpts]
    result = chain({"input_documents": docs, "question": query}, return_only_outputs=True)
    return result["output_text"], selected

def add_link(paragraph, link):
    """
    Turn the first line of a paragraph into a Markdown link pointing to `link`

    The remaining lines, if any, are left untouched.
    """
    first_line, newline, remainder = paragraph.partition("\n")
    return f"[{first_line}]({link}){newline}{remainder}"

def get_sources(answer, order):
    """
    Among the chunks added in the prompt, keep only those most similar to the answer

    The chunks are first ranked by similarity with the answer and only those
    whose score is within TOLERANCE_SCORE_SOURCE of the best one are kept (at
    most N_SOURCES). Each kept chunk is then split in lines, and the line most
    similar to the answer is used to represent the chunk.

    Args:
        answer: text of the generated answer.
        order: positional indices (in `df` and `corpus_embeddings`) of the
            chunks that were put in the prompt.

    Returns:
        A list of Markdown strings, one per cited chunk, sorted by decreasing
        similarity with the answer. Each string is the title of the chunk
        (as a link to its `href`) followed by its most relevant line. The list
        is empty when no chunk was put in the prompt.
    """
    # No chunk fitted in the prompt: there is nothing to cite (and taking the
    # maximum of an empty score array below would fail)
    if len(order) == 0:
        return []

    excerpts_df = df.iloc[order[:N_RESULTS]]
    answer_embedding = normalize(compute_embeddings(model, tokenizer, [answer]))
    excerpt_embeddings = corpus_embeddings[order[:N_RESULTS], :]
    scores2 = (1 + (answer_embedding @ np.transpose(excerpt_embeddings))[0, :]) / 2
    order2 = np.argsort(-scores2)
    n_close_to_best = int(np.sum(scores2 > np.max(scores2) - TOLERANCE_SCORE_SOURCE))
    n_sources = min(n_close_to_best, N_SOURCES)
    excerpts_df2 = excerpts_df.iloc[order2[:n_sources]]
    # The "title + content" texts are built on the side instead of being
    # written back into a slice of the dataframe (SettingWithCopyWarning)
    excerpts2 = list(excerpts_df2.title + "\n" + excerpts_df2.content)

    # One candidate paragraph per line of content, prefixed with the title of
    # its chunk; sources[k] is the position of the chunk of paragraphs[k]
    paragraphs = []
    sources = []
    for i, excerpt in enumerate(excerpts2):
        header, *content = excerpt.split("\n")
        for row in content:
            paragraphs.append(f"{header}\n\n{row}")
            sources.append(i)

    paragraph_embeddings = normalize(compute_embeddings(model, tokenizer, paragraphs))
    scores3 = (1 + (answer_embedding @ np.transpose(paragraph_embeddings))[0, :]) / 2
    order3 = np.argsort(-scores3)

    # Keep only the best paragraph of each chunk
    already_cited = set()
    results = []
    for j in order3:
        source, paragraph = sources[j], paragraphs[j]
        if source not in already_cited:
            already_cited.add(source)
            results.append(add_link(paragraph, excerpts_df2.iloc[source].href))

    return results

def post_process_answer(answer):
    """
    Improve the format of the answer

    Each non-blank line of the answer is displayed in bold, unless the answer
    already contains bold markers ("**"), in which case it is returned as is.
    Blank lines are left untouched: wrapping them would produce "****", which
    Markdown renders as a horizontal rule.

    Args:
        answer: raw text of the answer.

    Returns:
        The answer formatted in Markdown.
    """
    if "**" in answer:
        return answer
    return "\n".join(
        f"**{line}**" if line.strip() else line
        for line in answer.split("\n")
    )

def post_process_sources(sources):
    """
    Convert a list of sources into a string in Markdown format

    Duplicated sources are listed only once (first occurrence wins). The
    numbering and the singular/plural form of the heading are based on the
    deduplicated list, so that the numbering has no gap.

    Args:
        sources: list of Markdown strings, one per source.

    Returns:
        A Markdown string made of a "Source:"/"Sources:" heading followed by
        a numbered list of the distinct sources.
    """
    # dict.fromkeys removes duplicates while preserving the original order
    unique_sources = list(dict.fromkeys(sources))
    heading = f"Source{'s' if len(unique_sources) > 1 else ''}:\n\n"
    items = [
        f"{rank}. {source}\n\n"
        for rank, source in enumerate(unique_sources, start=1)
    ]
    return heading + "".join(items)

@app.callback(
    Output('answer', 'children'),
    Output('sources', 'children'),
    Output('answer_div', 'style'),
    Output('question', 'data'),
    Output('question_id', 'data'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('query', 'value'),
)
def answer_question(n_clicks, n_submit, query):
    """
    Answer the submitted question and display the answer with its sources

    Returns, in this order: the answer, the sources (followed by the response
    time), the style of the answer block, the question and an identifier of
    this answer. Nothing is displayed for an empty question, and an error
    message replaces the sources when the OpenAI key is missing.
    """
    # Empty question: keep the answer block hidden
    if len(query) == 0:
        return "", "", STYLE_ANSWER, query, 0

    started_at = time.time()
    if not openai_key_provided:
        return "", ERROR_MESSAGE_MISSING_KEY, STYLE_ANSWER, query, 0

    raw_answer, selected = get_answer(query)
    sources_markdown = post_process_sources(get_sources(raw_answer, selected))
    answer_markdown = post_process_answer(raw_answer)

    # Same style as the hidden block, but visible
    visible_style = {**STYLE_ANSWER, "display": "flex"}
    elapsed = time.time() - started_at
    question_id = hash(str(started_at) + query)
    return (
        answer_markdown,
        f"{sources_markdown}\n\nDelay: {elapsed:.1f} seconds",
        visible_style,
        query,
        question_id,
    )

@app.callback(
    Output('debug2', 'children'),
    Input('answer', 'children'),
    Input('link_ok', 'children'),
    Input('link_nok', 'children'),
    State('sources', 'children'),
    # The logged question is the one stored with the displayed answer, not the
    # current content of the input field (which the user may have edited since)
    State('question', 'data'),
    State('question_id', 'data'),
    State('feedback', 'data'),
)
def log_answer(answer, ok, nok, sources, query, question_id, feedback):
    """
    Log the question, the answer and the sources

    The callback runs when the answer changes or when the feedback icons are
    updated. One JSON file per answer, named after `question_id`, is written
    in the answers folder. When LOG_ALL_ANSWERS is False, only answers with a
    positive or negative feedback are kept: the file is removed if the
    feedback is withdrawn.

    Args:
        answer: answer currently displayed.
        ok, nok: content of the feedback links (only used as triggers).
        sources: sources currently displayed.
        query: question that led to the displayed answer.
        question_id: identifier of the displayed answer.
        feedback: 1 (positive), -1 (negative), 0 (withdrawn) or 2 (no
            feedback given yet, logged as 0).

    Returns:
        An empty string (the output is only a placeholder).
    """
    # Nothing to log while no answer is displayed (None or empty string)
    if answer:
        path = f"/{str(question_id)}.json"
        if LOG_ALL_ANSWERS or feedback in [1, -1]:
            record = {
                "question": query,
                "answer": f"{answer}\n\n{sources}",
                "feedback": 0 if feedback == 2 else feedback,
                "timestamp": str(datetime.now()),
                "version": VERSION
            }
            with answers_folder.get_writer(path) as w:
                w.write(json.dumps(record).encode("utf-8"))
        else:
            # The feedback was withdrawn: remove the record logged previously
            if path in answers_folder.list_paths_in_partition():
                answers_folder.delete_path(path)
    return ""

@app.callback(
    Output('link_ok', 'children'),
    Output('link_nok', 'children'),
    Input('feedback', 'data')
)
def update_icons(value):
    """
    Pick the feedback icons matching the current feedback

    The filled smile is shown for a positive feedback (1) and the filled frown
    for a negative feedback (-1); the outlined icons are used otherwise.
    """
    liked = value == 1
    disliked = value == -1
    smile = ok_icon_fill if liked else ok_icon
    frown = nok_icon_fill if disliked else nok_icon
    return smile, frown

@app.callback(
    Output('feedback', 'data'),
    Input('link_ok', 'n_clicks'),
    Input('link_nok', 'n_clicks'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('feedback', 'data'),
    State('query', 'value'),
    State('question', 'data'),
)
def provide_feedback(ts_ok, ts_nok, click, submit, value, question, previous_question):
    """
    Compute the new feedback state after a user action

    Clicking on an icon selects the corresponding feedback (1 or -1), or
    withdraws it (0) if it was already selected. For any other trigger, the
    feedback is kept if the question is unchanged and reset to 2 otherwise.
    """
    source = dash.ctx.triggered_id
    if source == "link_ok":
        # Clicking again on the selected icon cancels the feedback
        return 0 if value == 1 else 1
    if source == "link_nok":
        return 0 if value == -1 else -1
    # Question submitted (or initial call)
    if question == previous_question:
        return value
    return 2