import time
import functools
import dataiku
from datetime import datetime
import json

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc

LOG_ALL_ANSWERS = False # Whether all answers or only answers with user feedback should be logged
WEBAPP_NAME = "qanda_llm_mesh_webapp" # Name of the app (logged when a conversation is flagged)
VERSION = "1.0" # Version of the app (logged when a conversation is flagged)

# ID of the augmented LLM, read from the project variables. A missing variable is
# treated like an empty one so that the webapp still starts and displays
# ERROR_MESSAGE_MISSING_KEY instead of failing with a KeyError at start-up.
LLM_ID = dataiku.get_custom_variables().get("augmented_LLM_id", "")

project = dataiku.api_client().get_default_project()
# Handle on the LLM, or None when no LLM ID is configured (the callbacks check
# LLM_ID before calling get_answer, which uses it).
llm = project.get_llm(LLM_ID) if LLM_ID else None

# Knowledge bank ID, taken from the second ":"-separated field of LLM_ID.
KB_ID = LLM_ID.split(":")[1] if ":" in LLM_ID else ""


def _is_connection_available(project, kb_id):
    """
    Return True if the knowledge bank `kb_id` exists in `project` and its
    embedding LLM (`embeddingLLMId`) is among the project's text-embedding LLMs.
    """
    if not kb_id:
        # No knowledge bank configured: skip the API calls entirely.
        return False
    for kb in project.list_knowledge_banks():
        if kb["id"] == kb_id:
            embedding_ids = {
                embedding_model["id"]
                for embedding_model in project.list_llms(purpose='TEXT_EMBEDDING_EXTRACTION')
            }
            return kb["embeddingLLMId"] in embedding_ids
    return False


CONNECTION_AVAILABLE = _is_connection_available(project, KB_ID)

# Folder to log answers and positive/negative reactions
answers_folder = dataiku.Folder("zhYE6r2h")

ERROR_MESSAGE_MISSING_KEY = """
LLM Connection missing. You need to add it as a project variable. Cf. this project's [wiki](https://gallery.dataiku.com/projects/EX_LLM_STARTER_KIT/wiki/1/Project%20description).

**Please note that the question answering web app is not live on Dataiku’s public project gallery but you can test it by downloading the project and providing an LLM Connection**.

You can find examples of answers in the `answers` dataset.
"""

# Question answering helpers: Markdown escaping and LLM prompt

def escape_markdown(text):
    """
    Backslash-escape the Markdown characters ` _ ~ > [ ] ( ) in `text`.

    Characters that are already escaped are first unescaped, so applying the
    function twice gives the same result as applying it once.
    """
    special_chars = "`_~>[]()"
    # Undo any existing escaping so that already-escaped text is not escaped twice.
    for char in special_chars:
        text = text.replace("\\" + char, char)
    # Escape every special character in one pass.
    escape_table = {ord(char): "\\" + char for char in special_chars}
    return text.translate(escape_table)

@functools.lru_cache(maxsize=128)
def get_answer(query):
    """
    Send `query` to the configured LLM and return the text of its answer.

    Results are memoized per query string (up to 128 entries). This fits the
    deterministic settings used here (temperature 0). Requires the module-level
    `llm` handle, so it must only be called when LLM_ID is set.
    """
    completion = llm.new_completion()
    completion.settings["temperature"] = 0
    completion.settings["maxOutputTokens"] = 200
    # The behaviour instruction must use the "system" role; with the default
    # ("user") role it would be sent as if the user had typed it.
    completion.with_message(
        "You are a helpful assistant. Answer in a clear, concise and factual way.",
        role="system",
    )
    completion.with_message(query)
    return completion.execute().text

# Layout

# Row holding the answer text and, next to it, the feedback buttons.
STYLE_ANSWER = {
    "margin-top": "20px",
    "align-items": "flex-start",
    "display": "flex",
    "height": "auto",
}

# Each like/dislike link.
STYLE_BUTTON = {"width": "20px", "text-align": "center", "margin": "0px 5px"}

# Feedback buttons container; hidden ("display": "none") until an answer is shown.
STYLE_FEEDBACK = {"margin-left": "10px", "display": "none"}


def _bootstrap_icon(class_name):
    """Return a span wrapping an <i> element with the given Bootstrap Icons classes."""
    return html.Span(html.I(className=class_name))


# Feedback icons: the "-fill" variants mark the option currently selected.
ok_icon_fill = _bootstrap_icon("bi bi-emoji-smile-fill")
nok_icon_fill = _bootstrap_icon("bi bi-emoji-frown-fill")
ok_icon = _bootstrap_icon("bi bi-emoji-smile")
nok_icon = _bootstrap_icon("bi bi-emoji-frown")

send_icon = _bootstrap_icon("bi bi-send")

# Query input followed by the send button.
_query_input = dbc.Input(id='query', value='', type='text', minLength=0)
_send_button = dbc.Button(send_icon, id='send-btn', title='Get an answer')
question_bar = dbc.InputGroup(
    [_query_input, _send_button],
    style={"margin-top": "20px"},
)

app.title = "Question answering"
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]

# Page title and a description of the documents the answers are based on.
_header = [
    html.H4(
        "Question answering over documents",
        style={"margin-top": "20px", "text-align": "center"},
    ),
    dcc.Markdown("Get answers based on the [Global assessment report on biodiversity and ecosystem services (summary for policymakers)](https://www.ipbes.net/global-assessment) and the Wikipedia pages on [Biodiversity](https://en.wikipedia.org/wiki/Biodiversity), [CITES](https://en.wikipedia.org/wiki/CITES) and [IUCN](https://en.wikipedia.org/wiki/International_Union_for_Conservation_of_Nature) (retrieved May 2023)"),
]

# Like/dislike links, initially hidden through STYLE_FEEDBACK.
_feedback_buttons = html.Div(
    [
        html.A(ok_icon, id="link_ok", href="#", style=STYLE_BUTTON),
        html.A(nok_icon, id="link_nok", href="#", style=STYLE_BUTTON),
    ],
    id="feedback_buttons",
    style=STYLE_FEEDBACK,
)

# Answer text (with a spinner while it is being computed) and the feedback links.
_answer_area = html.Div(
    [
        dbc.Spinner(
            dcc.Markdown(id='answer', link_target="_blank", style={"min-width": "100px"}),
            color="primary",
        ),
        _feedback_buttons,
    ],
    style=STYLE_ANSWER,
)

# Client-side state shared by the callbacks:
# - feedback: 1 (liked), -1 (disliked), 0 (feedback withdrawn) or 2 (new
#   question, no feedback yet; logged as 0)
# - question: the last submitted query
# - question_id: identifier of the displayed answer, used as its log file name
_stores = [
    dcc.Store(id='feedback', data=2, storage_type='memory'),
    dcc.Store(id='question', storage_type='memory'),
    dcc.Store(id='question_id', storage_type='memory'),
]

app.layout = html.Div(
    [*_header, question_bar, _answer_area, html.Div(id='debug'), *_stores],
    style={"margin": "auto", "text-align": "left", "max-width": "800px"},
)

# Callbacks

@app.callback(
    Output('answer', 'children'),
    Output('feedback_buttons', 'style'),
    Output('question', 'data'),
    Output('question_id', 'data'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('query', 'value'),
)
def answer_question(n_clicks, n_submit, query):
    """
    Display the answer to the submitted query.

    Fired by the send button or by pressing Enter in the query field;
    `n_clicks` and `n_submit` only serve as triggers.

    Returns, in output order:
    - the Markdown answer followed by the elapsed time;
    - the style of the feedback buttons container (shown only under a real answer);
    - the submitted query, stored as the current question;
    - an identifier of this answer, used to name its log file (0 when no LLM
      answer was produced).
    """
    # Nothing to answer (None, empty or blank input): clear the output and
    # hide the feedback buttons.
    if not query or not query.strip():
        return "", STYLE_FEEDBACK, query, 0
    start = time.time()
    if not LLM_ID or not CONNECTION_AVAILABLE:
        # Show the configuration error but keep the feedback buttons hidden:
        # there is no LLM answer to rate.
        return ERROR_MESSAGE_MISSING_KEY, STYLE_FEEDBACK, query, 0
    style = dict(STYLE_FEEDBACK, display="flex")
    answer = get_answer(query)
    answer = f"{answer}\n\n{(time.time() - start):.1f} seconds"
    # Combining the start time with the query makes the ID differ between two
    # submissions of the same question.
    return answer, style, query, hash(str(start) + query)

@app.callback(
    Output('debug', 'children'),
    Input('answer', 'children'),
    Input('link_ok', 'children'),
    Input('link_nok', 'children'),
    State('question', 'data'),
    State('question_id', 'data'),
    State('feedback', 'data'),
)
def log_answer(answer, ok, nok, query, question_id, feedback):
    """
    Log the question and the answer in `answers_folder`.

    Fired whenever the displayed answer or the feedback icons change; `ok` and
    `nok` only serve as triggers. The entry is written to `/<question_id>.json`
    when LOG_ALL_ANSWERS is set or the user gave feedback (1 or -1). Otherwise
    any existing entry for this answer is deleted, so un-clicking a feedback
    button removes the log.

    `query` comes from the `question` store, i.e. the question that produced
    `answer`. The live text of the input box may have been edited since.

    Returns an empty string for the hidden `debug` div.
    """
    if not (LLM_ID and CONNECTION_AVAILABLE and answer):
        # No LLM answer is displayed (empty query or configuration error).
        return ""
    path = f"/{question_id}.json"
    if LOG_ALL_ANSWERS or feedback in (1, -1):
        record = {
            "question": query,
            "answer": answer,
            # 2 means "new question, no feedback yet" and is logged as 0.
            "feedback": 0 if feedback == 2 else feedback,
            "timestamp": str(datetime.now()),
            "webapp": WEBAPP_NAME,
            "version": VERSION
        }
        with answers_folder.get_writer(path) as w:
            w.write(json.dumps(record).encode("utf-8"))
    elif path in answers_folder.list_paths_in_partition():
        answers_folder.delete_path(path)
    return ""

@app.callback(
    Output('link_ok', 'children'),
    Output('link_nok', 'children'),
    Input('feedback', 'data')
)
def update_icons(value):
    """
    Show the filled icon for the selected feedback option.

    `value` is the stored feedback. 1 fills the smile icon and -1 fills the
    frown icon. Any other value (0, 2 or None) shows both outlined icons.
    """
    liked = value == 1
    disliked = value == -1
    ok = ok_icon_fill if liked else ok_icon
    nok = nok_icon_fill if disliked else nok_icon
    return ok, nok

@app.callback(
    Output('feedback', 'data'),
    Input('link_ok', 'n_clicks'),
    Input('link_nok', 'n_clicks'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('feedback', 'data'),
    State('query', 'value'),
    State('question', 'data'),
)
def provide_feedback(ts_ok, ts_nok, click, submit, value, question, previous_question):
    """
    Compute the new feedback value after a click on a feedback link or a query submission.

    - Feedback links act as toggles. Clicking one sets 1 (like) or -1 (dislike);
      clicking the selected one again resets the value to 0.
    - Any other trigger (send button, Enter key, initial call) keeps the
      current value if the query is the question already answered, and
      resets it to 2 ("new question, no feedback yet") otherwise.

    The four Input arguments are only used as triggers.
    """
    trigger = dash.ctx.triggered_id
    if trigger in ("link_ok", "link_nok"):
        selected = 1 if trigger == "link_ok" else -1
        return 0 if value == selected else selected
    if question != previous_question:
        return 2
    return value