import os
import re
import time
import functools
from datetime import datetime
import json
import tempfile

import dataiku
import json
import os
import torch

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc

# Logging policy: when False, only answers that received a user rating are logged.
LOG_ALL_ANSWERS = False # Whether all answers or only answers with user feedback should be logged
WEBAPP_NAME = "qanda_webapp_colpali" # Name of the app (stored in each logged answer record)
VERSION = "1.0" # Version of the app (stored in each logged answer record)
# Managed folder receiving one JSON record per logged question/answer pair.
ANSWERS_FOLDER = dataiku.Folder("mPr7AnUg")
# Managed folder holding the byaldi index files; they are copied to a local
# temporary directory at startup so the index can be loaded from disk.
INDEX_FOLDER = dataiku.Folder("8jv9ZTHM")
NUM_CHUNKS = 3 # Number of retrieved pages (k passed to RAG.search; each is attached to the prompt)
# Name of the byaldi index; also the name of its directory on local disk.
index_name = "docs"


# Text displayed by display_answer in place of an answer when the
# question-answering prerequisites are missing.
ERROR_MESSAGE_MISSING_KEY = """
LLM connection or GPU missing. Cf. this project's wiki.

Please note that the question answering web app is not live on Dataiku’s public project gallery but you can test it by downloading the project and providing an LLM connection.

You can find examples of answers in the `answers_colpali` dataset.
"""

# LLM connection ID, read from the project variable "LLM_ID". A missing or
# empty variable means "no LLM configured": the web app then shows an error
# message instead of failing at startup with a KeyError.
LLM_ID = dataiku.get_custom_variables().get("LLM_ID") or ""
llm_provided = len(LLM_ID) > 0
# `llm`, `RAG` and `id2filename` are only defined when both an LLM and a GPU
# are available; callers of answer_question must check both conditions.
if llm_provided and torch.cuda.is_available():
    import shutil

    # Get a handle on a multimodal LLM (e.g. GPT-4o)
    llm = dataiku.api_client().get_default_project().get_llm(LLM_ID)

    # byaldi loads an index from the local filesystem, so the managed
    # folder's content is first copied into a temporary directory.
    # NOTE(review): temp_dir is deleted when this `with` exits; this assumes
    # from_index keeps everything it needs in memory — confirm if RAG is ever
    # asked to write to its index later.
    with tempfile.TemporaryDirectory() as temp_dir:
        index_directory = os.path.join(temp_dir, index_name)
        os.mkdir(index_directory)
        for folder_path in INDEX_FOLDER.list_paths_in_partition():
            # Same convention as before: drop the first path component (the
            # empty string before the leading "/") and keep the rest.
            parts = folder_path.split("/")[1:]
            filepath = os.path.join(index_directory, *parts)
            # Recreate sub-directories at any depth (not only one level).
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            with INDEX_FOLDER.get_download_stream(folder_path) as stream:
                with open(filepath, "wb") as local_file:
                    # Stream in chunks instead of loading whole index files in memory.
                    shutil.copyfileobj(stream, local_file)

        # Load the index
        from byaldi import RAGMultiModalModel
        RAG = RAGMultiModalModel.from_index(index_name, index_root=temp_dir)
        # Map document IDs to bare file names (last path component) for display.
        id2filename = RAG.get_doc_ids_to_file_names()
        id2filename = {k: id2filename[k].split("/")[-1] for k in id2filename}


@functools.lru_cache()
def answer_question(question):
    """
    Answer `question` with a ColPali-based multimodal RAG pipeline.

    The NUM_CHUNKS best-matching pages are retrieved from the byaldi index and
    attached, each as a labelled inline image, to a single LLM completion.
    Results are memoized per question string (functools.lru_cache), so the
    returned lists are shared between calls and must not be mutated.

    Returns:
        (answer_text, sources, sources_displayed) where answer_text is the
        completion text, sources is a list of "<filename> (page <n>)" labels,
        and sources_displayed is a list of Dash components: a "Sources"
        header followed, for each page, by its label and its image.
    """
    hits = RAG.search(question, k=NUM_CHUNKS)

    displayed = [html.B("Sources")]
    completion = llm.new_completion()
    # The question goes first; each retrieved page then follows as its own
    # multipart message made of its label and its image.
    completion.with_message(f"Concisely answer the following question: {question}. Use the documents attached below.")

    labels = []
    for hit in hits:
        label = f'{id2filename[hit["doc_id"]]} (page {hit["page_num"]})'
        labels.append(label)

        part = completion.new_multipart_message()
        part.with_text(label)
        part.with_inline_image(hit["base64"])
        part.add()

        displayed.append(html.P(html.I(label)))
        displayed.append(
            html.Img(src=f'data:image/jpeg;base64,{hit["base64"]}', style={"width": "100%"})
        )

    # Cap the answer length and use temperature 0.
    completion.settings["maxOutputTokens"] = 300
    completion.settings["temperature"] = 0
    response = completion.execute()
    return response.text, labels, displayed

# Layout

STYLE_ANSWER = {
    "margin-top": "20px",
    "align-items": "flex-start",
    "display": "flex",
    "height": "auto",
}

STYLE_BUTTON = {"width": "20px", "text-align": "center", "margin": "0px 5px"}

# Hidden by default; display_answer switches "display" to "flex" once an
# answer is shown.
STYLE_FEEDBACK = {"margin-left": "10px", "display": "none"}


def _icon_span(class_name):
    """Return an `html.I` icon with the given CSS classes wrapped in an `html.Span`."""
    return html.Span(html.I(className=class_name))


# Feedback icons: the filled variants mark the currently selected rating.
ok_icon_fill = _icon_span("bi bi-emoji-smile-fill")
nok_icon_fill = _icon_span("bi bi-emoji-frown-fill")
ok_icon = _icon_span("bi bi-emoji-smile")
nok_icon = _icon_span("bi bi-emoji-frown")

send_icon = _icon_span("bi bi-send")
# Text field + send button; both Enter (n_submit) and a click (n_clicks) submit.
_query_field = dbc.Input(id='query', value='', type='text', minLength=0)
_send_button = dbc.Button(send_icon, id='send-btn', title='Get an answer')
question_bar = dbc.InputGroup([_query_field, _send_button], style={"margin-top": "20px"})

# NOTE: `app` is not defined in this file; presumably it is provided by the
# Dataiku Dash web app runtime.
app.title = "Question answering"
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]

_page_title = html.H4(
    "Multimodal Retrieval-Augmented Generation",
    style={"margin-top": "20px", "text-align": "center"},
)
# The answer area shows a spinner while display_answer is running.
_answer_spinner = dbc.Spinner(
    html.Div(id='answer', style={"min-width": "100px"}),
    color="primary",
)
_rating_buttons = html.Div(
    [
        html.A(ok_icon, id="link_ok", href="#", style=STYLE_BUTTON),
        html.A(nok_icon, id="link_nok", href="#", style=STYLE_BUTTON),
    ],
    id="feedback_buttons",
    style=STYLE_FEEDBACK,
)
# Client-side state shared between callbacks.
_state_stores = [
    dcc.Store(id='feedback', data=2, storage_type='memory'),
    dcc.Store(id='question', storage_type='memory'),
    dcc.Store(id='answer_and_sources', storage_type='memory'),
    dcc.Store(id='question_id', storage_type='memory'),
]

app.layout = html.Div(
    [
        _page_title,
        question_bar,
        html.Div([_answer_spinner, _rating_buttons], style=STYLE_ANSWER),
        html.Div(id='debug'),
        *_state_stores,
    ],
    style={"margin": "auto", "text-align": "left", "max-width": "800px"},
)

# Callbacks

@app.callback(
    Output('answer', 'children'),
    Output('answer_and_sources', 'data'),
    Output('feedback_buttons', 'style'),
    Output('question', 'data'),
    Output('question_id', 'data'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('query', 'value'),
)
def display_answer(n_clicks, n_submit, query):
    """
    Display the answer to the submitted question.

    Triggered by the send button or by pressing Enter in the query field.

    Returns a 5-tuple matching the callback outputs:
    - children of the 'answer' div (answer, elapsed time, sources),
    - {"answer": ..., "sources": ...} for logging, or "" when there is none,
    - style of the feedback buttons (visible only when an answer is shown),
    - the question that was answered,
    - an identifier for this question/answer pair (0 when there is none).
    """
    if not query:
        # Covers the empty field (e.g. initial call) and a None value.
        return "", "", STYLE_FEEDBACK, query, 0
    start = time.time()
    # answer_question needs `llm`, `RAG` and `id2filename`, which only exist
    # when an LLM ID is set AND a GPU is available; checking the LLM alone
    # led to a NameError on machines without CUDA.
    if not (llm_provided and torch.cuda.is_available()):
        # Keep the feedback buttons hidden: there is no answer to rate.
        return ERROR_MESSAGE_MISSING_KEY, "", STYLE_FEEDBACK, query, 0
    style = dict(STYLE_FEEDBACK)
    style["display"] = "flex"
    answer, sources, sources_displayed = answer_question(query)
    answer_and_sources = {"answer": answer, "sources": sources}
    children = [
        dcc.Markdown(answer),
        html.P(f"\n{(time.time()-start):.1f} seconds\n\n")
    ] + sources_displayed
    # NOTE: str hashes are salted per process (PYTHONHASHSEED), so this ID is
    # fine for naming log files but is not reproducible across restarts.
    return children, answer_and_sources, style, query, hash(str(start) + query)

@app.callback(
    Output('debug', 'children'),
    Input('answer_and_sources', 'data'),
    Input('link_ok', 'children'),
    Input('link_nok', 'children'),
    State('query', 'value'),
    State('question_id', 'data'),
    State('feedback', 'data'),
    State('question', 'data'),
)
def log_answer(answer, ok, nok, query, question_id, feedback, question=None):
    """
    Log the question and the answer to ANSWERS_FOLDER.

    Fires when a new answer is stored or when the feedback icons change. One
    JSON file named after `question_id` is written when LOG_ALL_ANSWERS is
    set or the user gave a rating (1 or -1); otherwise any file previously
    written for this `question_id` is deleted.

    Parameters:
        answer: {"answer": ..., "sources": ...} as stored by display_answer,
            "" when there is no answer, or None before the store is filled.
        ok, nok: current feedback icons (only used as triggers).
        query: current content of the text field (fallback question).
        question_id: identifier of the question/answer pair.
        feedback: stored feedback value (1, -1, 0 or 2).
        question: the question that `answer` actually answers.

    Returns an empty string for the (otherwise unused) 'debug' div.
    """
    # `not answer` also covers the initial None value, on which len() failed.
    if not (llm_provided and answer):
        return ""
    # Prefer the question that was actually answered: the text field may have
    # been edited (without being submitted) before the user rated the answer.
    asked = question if question is not None else query
    path = f"/{question_id}.json"
    if LOG_ALL_ANSWERS or feedback in (1, -1):
        record = {
            "question": asked,
            "answer": answer,
            # 2 means "no rating yet for this question"; it is logged as 0.
            "feedback": 0 if feedback == 2 else feedback,
            "timestamp": str(datetime.now()),
            "webapp": WEBAPP_NAME,
            "version": VERSION,
        }
        with ANSWERS_FOLDER.get_writer(path) as w:
            w.write(bytes(json.dumps(record), "utf-8"))
    elif path in ANSWERS_FOLDER.list_paths_in_partition():
        ANSWERS_FOLDER.delete_path(path)
    return ""

@app.callback(
    Output('link_ok', 'children'),
    Output('link_nok', 'children'),
    Input('feedback', 'data')
)
def update_icons(value):
    """
    Show the filled variant of the icon matching the stored rating.

    A value of 1 fills the smile icon and -1 fills the frown icon; any other
    value (0, 2, None) leaves both icons outlined.
    """
    if value == 1:
        return ok_icon_fill, nok_icon
    if value == -1:
        return ok_icon, nok_icon_fill
    return ok_icon, nok_icon

@app.callback(
    Output('feedback', 'data'),
    Input('link_ok', 'n_clicks'),
    Input('link_nok', 'n_clicks'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('feedback', 'data'),
    State('query', 'value'),
    State('question', 'data'),
)
def provide_feedback(ts_ok, ts_nok, click, submit, value, question, previous_question):
    """
    Update the stored rating after an icon click or a new submission.

    Stored values: 1 (liked), -1 (disliked), 0 (rating withdrawn) and
    2 (reset because a different question was submitted).

    Clicking an icon toggles its rating on or off. Submitting keeps the
    current value when the question is unchanged and resets it to 2 otherwise.
    """
    icon_ratings = {"link_ok": 1, "link_nok": -1}
    trigger = dash.ctx.triggered_id
    if trigger in icon_ratings:
        rating = icon_ratings[trigger]
        # Clicking the already-selected icon clears the rating.
        return 0 if value == rating else rating
    if question != previous_question:
        return 2
    return value