import os
import re
import time
import functools
from datetime import datetime
import json
import tempfile

import dataiku
import json
import os
from project_utils import encode_image, retrieve_chunks

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc

# Logging of questions and answers
LOG_ALL_ANSWERS = False # Whether all answers or only answers with user feedback should be logged
WEBAPP_NAME = "qanda_webapp" # Name of the app (logged when a conversation is flagged)
VERSION = "1.0" # Version of the app (logged when a conversation is flagged)
ANSWERS_FOLDER = dataiku.Folder("mPr7AnUg") # Managed folder where each logged answer is written as "/<question_id>.json"

# Retrieval
NUM_CHUNKS = 6 # Number of retrieved chunks
KB_ID = "zdqno9RF" # ID of the knowledge bank searched for relevant chunks
# Handle on the current (default) project.
project = dataiku.api_client().get_default_project()
# Knowledge bank object; turned into a LangChain retriever below when an LLM is configured.
kb = project.get_knowledge_bank(KB_ID).as_core_knowledge_bank()

# Managed folder passed to encode_image, together with each image chunk's
# "image_url", to load the images.
folder = dataiku.Folder("vOjkXoGz")

# Parsed JSON metadata of every row of the "metadata" dataset, keyed by the
# row's "index" column (presumably the chunk identifiers that retrieve_chunks
# looks up — confirm against project_utils).
metadata_df = dataiku.Dataset("metadata").get_dataframe()
metadata = {
    metadata_df.at[row, "index"]: json.loads(metadata_df.at[row, "metadata"])
    for row in metadata_df.index
}

# Message shown in place of an answer when no LLM connection is configured
# (i.e. when the LLM_ID custom variable is empty). It uses Markdown backticks.
# NOTE(review): the text asks users to add a *user secret*, but the code reads
# LLM_ID from dataiku.get_custom_variables() — confirm which setup the wiki
# describes and align the message.
ERROR_MESSAGE_MISSING_KEY = """
LLM connection missing. You need to add it as a user secret. Cf. this project's wiki

Please note that the question answering web app is not live on Dataiku’s public project gallery but you can test it by downloading the project and providing an LLM connection.

You can find examples of answers in the `answers` dataset.
"""

# ID of the LLM connection used to answer questions, read from the Dataiku
# custom variables. When the variable is missing, null or blank, no LLM is set
# up and the web app displays ERROR_MESSAGE_MISSING_KEY instead of answers
# (previously a missing variable raised KeyError when the app started).
LLM_ID = (dataiku.get_custom_variables().get("LLM_ID") or "").strip()
llm_provided = len(LLM_ID) > 0
if llm_provided:
    # Reuse the default project handle obtained for the knowledge bank.
    llm = project.get_llm(LLM_ID)
    # NOTE(review): k=10 differs from NUM_CHUNKS=6; presumably retrieve_chunks
    # narrows these candidates down to NUM_CHUNKS — confirm.
    retriever = kb.as_langchain_retriever(search_kwargs={"k": 10})

def get_messages(chunks_with_metadata, question):
    """
    Build the messages sent to the multimodal LLM.

    The conversation is made of:
    - a system message with the answering instructions;
    - a user message containing the question;
    - one user message per retrieved chunk. A chunk whose ``type`` is
      ``"text"`` becomes a single TEXT part ``"Fact: <content>"``. Any other
      chunk is treated as an image: an optional TEXT part ``"Fact: <caption>"``
      (only when the chunk has a ``"caption"`` key) followed by an
      IMAGE_INLINE part built with ``encode_image(folder, chunk["image_url"])``.

    :param chunks_with_metadata: retrieved chunks (dicts with at least a ``type`` key)
    :param question: the user's question
    :return: list of ``{"role": ..., "parts": [...]}`` message dicts
    """
    def text_part(text):
        # A plain-text part of a message.
        return {"type": "TEXT", "text": text}

    messages = [
        {
            "role": "system",
            "parts": [
                text_part(
                    "You are a helpful assistant. Concisely answer the question "
                    "of the user based on the facts provided. If you don't know, "
                    "just say you don't know."
                )
            ]
        },
        {
            "role": "user",
            "parts": [
                text_part(f"Answer the following question: {question}. Use the following facts.")
            ]
        }
    ]

    for chunk in chunks_with_metadata:
        if chunk["type"] == "text":
            parts = [text_part(f"Fact: {chunk['content']}")]
        else:
            # The caption is optional; the image itself is always included.
            parts = [text_part(f"Fact: {chunk['caption']}")] if "caption" in chunk else []
            parts.append({
                "type": "IMAGE_INLINE",
                "inlineImage": encode_image(folder, chunk['image_url'])
            })
        messages.append({"role": "user", "parts": parts})
    return messages

@functools.lru_cache()
def answer_question(question):
    """
    Answer ``question`` with a multimodal retrieval-augmented generation approach.

    NUM_CHUNKS chunks relevant to the question are retrieved, then sent together
    with the question to the LLM (at most 300 output tokens, temperature 0).

    Results are memoized per question with ``functools.lru_cache`` (default size:
    128 entries): asking the same question again returns the cached tuple,
    including the very same chunks list object, without calling the LLM.

    :param question: the user's question
    :return: tuple ``(answer_text, chunks_with_metadata)``
    """
    chunks = retrieve_chunks(retriever, question, metadata, NUM_CHUNKS)

    completion = llm.new_completion()
    completion.cq["messages"] = get_messages(chunks, question)
    settings = completion.settings
    settings["maxOutputTokens"] = 300
    settings["temperature"] = 0

    response = completion.execute()
    answer_text = response.text
    return answer_text, chunks

# Layout

# Style of the row holding the answer and the feedback buttons.
STYLE_ANSWER = {
    "margin-top": "20px",
    "align-items": "flex-start",
    "display": "flex",
    "height": "auto"
}

# Style of each feedback link (like / dislike).
STYLE_BUTTON = {
    "width": "20px",
    "text-align": "center",
    "margin": "0px 5px",
}

# Style of the feedback buttons container: hidden ("display": "none") until an
# answer is displayed, at which point display_answer uses a copy with
# "display" set to "flex".
STYLE_FEEDBACK = {
    "margin-left": "10px",
    "display": "none"
}

# Feedback icons (Bootstrap Icons classes): the filled variant marks the
# feedback the user selected, the outlined variant is shown otherwise.
ok_icon_fill = html.Span(html.I(className="bi bi-emoji-smile-fill"))
nok_icon_fill = html.Span(html.I(className="bi bi-emoji-frown-fill"))
ok_icon = html.Span(html.I(className="bi bi-emoji-smile"))
nok_icon = html.Span(html.I(className="bi bi-emoji-frown"))

# Question bar: a text input and a send button. Both clicking the button and
# pressing Enter in the input (n_submit) trigger display_answer.
send_icon = html.Span(html.I(className="bi bi-send"))
question_bar = dbc.InputGroup(
    [
        dbc.Input(id='query', value='', type='text', minLength=0),
        dbc.Button(send_icon, id='send-btn', title='Get an answer')
    ],
    style = {"margin-top": "20px"}
)

# NOTE: `app` is not defined in this file; presumably it is the Dash app object
# provided by the Dataiku webapp runtime.
app.title = "Question answering"
app.config.external_stylesheets = [
    dbc.themes.ZEPHYR,
    dbc.icons.BOOTSTRAP
]
app.layout = html.Div(
    [
        html.H4(
            "Multimodal Retrieval-Augmented Generation",
            style={"margin-top": "20px", "text-align": "center"}
        ),
        question_bar,
        # Answer row: the answer (wrapped in a spinner) and the feedback buttons.
        html.Div(
            [
                dbc.Spinner(
                    html.Div(
                        id='answer',
                        style={"min-width": "100px"}
                    ),
                    color="primary"
                ),
                html.Div(
                    [
                        html.A(ok_icon, id="link_ok", href="#", style=STYLE_BUTTON),
                        html.A(nok_icon, id="link_nok", href="#", style=STYLE_BUTTON)
                    ],
                    id="feedback_buttons",
                    style=STYLE_FEEDBACK)
            ],
            style=STYLE_ANSWER
        ),
        # Output target of log_answer, which always returns "".
        html.Div(id='debug'),
        # Current feedback: 1 = liked, -1 = disliked, 0 = withdrawn, 2 = none yet.
        dcc.Store(id='feedback', data=2, storage_type='memory'),
        # Question that produced the displayed answer.
        dcc.Store(id='question', storage_type='memory'),
        # JSON string holding the displayed answer and its sources.
        dcc.Store(id='answer_and_sources', storage_type='memory'),
        # Identifier of the displayed answer, used as its log file name.
        dcc.Store(id='question_id', storage_type='memory'),
    ],
    style={
        "margin": "auto",
        "text-align": "left",
        "max-width": "800px"
    }
)

# Callbacks

def format_sources(sources):
    """
    Render the retrieved chunks as a numbered "Sources" section.

    Text chunks are rendered as ``"<n>. **<filename>** (page <page>): <content>"``.
    Any other chunk is treated as an image: the same line with its caption
    (truncated to 1000 characters, empty when the chunk has no caption),
    followed by the image itself inlined as a base64 JPEG data URI.

    :param sources: retrieved chunks with their metadata
    :return: list of Dash components
    """
    result = [dcc.Markdown("**Sources**")]
    for number, source in enumerate(sources, start=1):
        heading = f"{number}. **{source['filename']}** (page {source['page']})"
        if source["type"] == "text":
            result.append(dcc.Markdown(f"{heading}: {source['content']}"))
            continue
        # The caption is optional for image chunks (get_messages only includes
        # it when present), so a missing or None caption renders as empty text
        # instead of raising.
        caption = (source.get("caption") or "")[:1000]
        result.append(dcc.Markdown(f"{heading}: {caption}"))
        image_src = f"data:image/jpeg;base64,{encode_image(folder, source['image_url'])}"
        result.append(html.Div(
            html.Img(src=image_src, style={"width": "600px"}),
            style={"text-align": "center", "margin": "20px 0"}
        ))
    return result

@app.callback(
    Output('answer', 'children'),
    Output('answer_and_sources', 'data'),
    Output('feedback_buttons', 'style'),
    Output('question', 'data'),
    Output('question_id', 'data'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('query', 'value'),
)
def display_answer(n_clicks, n_submit, query):
    """
    Answer the submitted question and display the answer with its sources.

    Triggered by the send button and by pressing Enter in the query input.

    :param n_clicks: click count of the send button (only used as a trigger)
    :param n_submit: submit count of the query input (only used as a trigger)
    :param query: current text of the query input
    :return: tuple (answer components, JSON string ``{"answer", "sources"}``,
        style of the feedback buttons, answered question, question id)
    """
    # `not query` also covers a None value, on which len() would raise.
    if not query:
        return "", "", STYLE_FEEDBACK, query, 0
    if not llm_provided:
        # There is no answer to rate: keep the feedback buttons hidden and
        # render the error message as Markdown (it contains backticks).
        return dcc.Markdown(ERROR_MESSAGE_MISSING_KEY), "", STYLE_FEEDBACK, query, 0
    start = time.time()
    # Copy of the hidden feedback style, made visible now that there is an answer.
    style = dict(STYLE_FEEDBACK, display="flex")
    answer, sources = answer_question(query)
    answer_and_sources = json.dumps({"answer": answer, "sources": sources})
    components = [
        dcc.Markdown(answer),
        html.P(f"\n{(time.time()-start):.1f} seconds\n\n")
    ] + format_sources(sources)
    # The last value identifies this answer; log_answer uses it as file name.
    return components, answer_and_sources, style, query, hash(str(start) + query)

@app.callback(
    Output('debug', 'children'),
    Input('answer_and_sources', 'data'),
    Input('link_ok', 'children'),
    Input('link_nok', 'children'),
    State('question', 'data'),
    State('question_id', 'data'),
    State('feedback', 'data'),
)
def log_answer(answer, ok, nok, query, question_id, feedback):
    """
    Write or remove the log entry of the displayed question and answer.

    Fired when a new answer is stored and whenever the feedback icons are
    re-rendered, i.e. after every feedback change. The entry is written to
    ``/<question_id>.json`` in ANSWERS_FOLDER when LOG_ALL_ANSWERS is set or
    the user liked (1) / disliked (-1) the answer. Otherwise any existing entry
    for this question is deleted, so withdrawing a feedback removes the log.

    :param answer: JSON string with the answer and its sources ("" if none)
    :param ok: children of the "like" link (only used as a trigger)
    :param nok: children of the "dislike" link (only used as a trigger)
    :param query: the question that produced ``answer``. It is read from the
        ``question`` store rather than the live text box, which the user may
        have edited since the answer was produced.
    :param question_id: identifier of the answer, used as file name
    :param feedback: 1 (liked), -1 (disliked), 0 (withdrawn) or 2 (none yet)
    :return: "" (the ``debug`` div is only an output target)
    """
    if not llm_provided or not answer:
        return ""
    path = f"/{question_id}.json"
    if LOG_ALL_ANSWERS or feedback in (1, -1):
        record = {
            "question": query,
            "answer": answer,
            "feedback": 0 if feedback == 2 else feedback,
            "timestamp": str(datetime.now()),
            "webapp": WEBAPP_NAME,
            "version": VERSION
        }
        with ANSWERS_FOLDER.get_writer(path) as writer:
            writer.write(json.dumps(record).encode("utf-8"))
    elif path in ANSWERS_FOLDER.list_paths_in_partition():
        ANSWERS_FOLDER.delete_path(path)
    return ""

@app.callback(
    Output('link_ok', 'children'),
    Output('link_nok', 'children'),
    Input('feedback', 'data')
)
def update_icons(value):
    """
    Render the feedback links, filling the icon that matches the recorded feedback.

    :param value: feedback value (1 = liked, -1 = disliked, anything else = none)
    :return: (children of the "like" link, children of the "dislike" link)
    """
    # `None == 1` is simply False, so a missing value leaves both icons outlined.
    liked = value == 1
    disliked = value == -1
    return (
        ok_icon_fill if liked else ok_icon,
        nok_icon_fill if disliked else nok_icon,
    )

@app.callback(
    Output('feedback', 'data'),
    Input('link_ok', 'n_clicks'),
    Input('link_nok', 'n_clicks'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('feedback', 'data'),
    State('query', 'value'),
    State('question', 'data'),
)
def provide_feedback(ts_ok, ts_nok, click, submit, value, question, previous_question):
    """
    Compute the new feedback value after a click or a new submission.

    - Clicking the "like" (resp. "dislike") icon sets the feedback to 1
      (resp. -1), or to 0 when that feedback was already selected (toggle off).
    - Any other trigger (a submission, or the initial call) keeps the current
      feedback if the query equals the previously answered question, and
      resets it to 2 (no feedback yet) otherwise.

    :param ts_ok: click count of the "like" link (only used as a trigger)
    :param ts_nok: click count of the "dislike" link (only used as a trigger)
    :param click: click count of the send button (only used as a trigger)
    :param submit: submit count of the query input (only used as a trigger)
    :param value: current feedback value
    :param question: current text of the query input
    :param previous_question: value of the ``question`` store when triggered
    :return: the new feedback value
    """
    trigger = dash.ctx.triggered_id
    for link_id, choice in (("link_ok", 1), ("link_nok", -1)):
        if trigger == link_id:
            return 0 if value == choice else choice
    if question != previous_question:
        return 2
    return value