import dataiku
import time
from datetime import datetime
import json

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc

LOG_ALL_ANSWERS = False  # Whether all answers or only answers with user feedback should be logged
WEBAPP_NAME = "agent_webapp"  # Name of the app (logged when an answer is flagged)
VERSION = "1.0"  # Version of the app (logged when an answer is flagged)

# Folder to log answers and positive/negative reactions
FOLDER = dataiku.Folder("zhYE6r2h")

# Customers listed in the "Context" selector of the layout
customers_df = dataiku.Dataset("customers").get_dataframe()

LLM_ID = "agent:2k5v5N3O:v1"
# The app counts as configured when the "LLM_id" variable holds a non-empty value.
# Using `.get` keeps a missing variable from raising KeyError at startup, so the
# web app can still load and show its "LLM connection missing" message instead.
llm_provided = bool(dataiku.get_custom_variables().get("LLM_id"))
# Always bind `llm` so the name exists even when no LLM is configured.
llm = (
    dataiku.api_client().get_default_project().get_llm(LLM_ID)
    if llm_provided
    else None
)
    

# Markdown text returned by `answer_question` (as the intermediate-steps content)
# when `llm_provided` is False, i.e. when no LLM has been configured.
# NOTE(review): the availability check reads the "LLM_id" variable while this
# message mentions a user secret — confirm the wording matches the actual setup.
ERROR_MESSAGE_MISSING_KEY = """
LLM Connection missing. You need to add it as a user secret. Cf. this project's wiki

**Please note that this web app is not live on Dataiku’s public project gallery but you can test it by downloading the project and providing an LLM connection**.

You can find examples of answers in the `requests_processed` dataset.
"""

def process_request(request, customer_id=None):
    """
    Send a customer request to the LLM agent and unpack its JSON answer.

    The request and the (optional) customer id are serialized as a JSON object
    and sent as a single message. The text of the response is parsed as JSON.

    Returns a tuple made of the "reply", "actions" and "intermediate_steps"
    entries of the parsed response, in that order.
    """
    completion = llm.new_completion()
    payload = json.dumps({"request": request, "customer_id": customer_id})
    completion.with_message(payload)
    response = completion.execute()
    parsed = json.loads(response.text)
    # Keys are read in this order, so the first missing one raises KeyError
    return tuple(parsed[key] for key in ("reply", "actions", "intermediate_steps"))

# Layout

# Style of the container holding the answer text and the feedback buttons
STYLE_ANSWER = {
    "margin-top": "20px",
    "align-items": "flex-start",
    "display": "flex",
    "height": "auto"
}

# Style of each feedback link (smile / frown icon)
STYLE_BUTTON = {
    "width": "20px",  
    "text-align": "center",
    "margin": "0px 5px",
}

# Style of the feedback buttons container: hidden by default,
# `answer_question` switches "display" to "flex" once an answer is shown
STYLE_FEEDBACK = {
    "margin-left": "10px",
    "display": "none"
}

# Base style of the intermediate steps area; `display_intermediate_steps`
# adds a "display" entry on top of it
STYLE_STEPS = {
    "white-space": "pre-line"
}

def _bootstrap_icon(css_class):
    """Build a span containing the Bootstrap icon identified by `css_class`."""
    return html.Span(html.I(className=css_class))


# Feedback icons: filled variants mark the currently selected reaction
ok_icon_fill = _bootstrap_icon("bi bi-emoji-smile-fill")
nok_icon_fill = _bootstrap_icon("bi bi-emoji-frown-fill")
ok_icon = _bootstrap_icon("bi bi-emoji-smile")
nok_icon = _bootstrap_icon("bi bi-emoji-frown")

send_icon = _bootstrap_icon("bi bi-send")

# Text input and send button used to submit a question
_query_input = dbc.Input(id='query', value='', type='text', minLength=0)
_send_button = dbc.Button(send_icon, id='send-btn', title='Get an answer')
question_bar = dbc.InputGroup(
    [_query_input, _send_button],
    style={"margin-top": "20px"},
)

# Page title and stylesheets (Bootstrap theme + Bootstrap icons)
app.title = "Question answering"
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]

# Choices of the "Context" selector: one entry per customer, followed by an
# "All customers" entry whose value is -1.
# The four columns are iterated in lockstep instead of materializing a row with
# `iloc` several times per customer.
options_context = [
    {
        "label": f"Customer #{customer_id} ({title} {first_name} {last_name})",
        "value": customer_id,
    }
    for customer_id, title, first_name, last_name in zip(
        customers_df["id"],
        customers_df["title"],
        customers_df["first_name"],
        customers_df["last_name"],
    )
] + [{"label": "All customers", "value": -1}]

# Choices of the "Display" selector
options_display = [
    {"label": "Reply only", "value": 0},
    {"label": "Reply and intermediate steps", "value": 1},
]

# Page heading
_page_title = html.H4(
    "LLM agent",
    style={"margin-top": "20px", "text-align": "center"}
)

# Form with the two selectors: customer context and display mode
_context_select = dbc.Select(id='customer_id', options=options_context, value=1)
_display_select = dbc.Select(id='display_mode', options=options_display, value=1)
_settings_form = dbc.Form(
    dbc.Row(
        [
            dbc.Label("Context", width="auto"),
            dbc.Col(_context_select, className="me-3"),
            dbc.Label("Display", width="auto"),
            dbc.Col(_display_select, className="me-3"),
        ],
        className="g-2"
    ),
    style={"margin-top": "20px"}
)

# Answer text next to the smile/frown feedback links
_answer_text = dbc.Spinner(
    dcc.Markdown(id='answer', link_target="_blank"),
    color="primary",
    spinner_style={"display": "none"}
)
_feedback_links = html.Div(
    [
        html.A(ok_icon, id="link_ok", href="#", style=STYLE_BUTTON),
        html.A(nok_icon, id="link_nok", href="#", style=STYLE_BUTTON),
    ],
    id="feedback_buttons",
    style=STYLE_FEEDBACK
)
_answer_row = html.Div([_answer_text, _feedback_links], style=STYLE_ANSWER)

# Intermediate steps of the agent
_steps_panel = dbc.Spinner(
    dcc.Markdown(id='intermediate_steps', link_target="_blank", style=STYLE_STEPS),
    color="primary"
)

# In-memory stores: current feedback (initially 2), last question and its id
_stores = [
    dcc.Store(id='feedback', data=2, storage_type='memory'),
    dcc.Store(id='question', storage_type='memory'),
    dcc.Store(id='question_id', storage_type='memory'),
]

app.layout = html.Div(
    [
        _page_title,
        _settings_form,
        question_bar,
        _answer_row,
        _steps_panel,
        # Output target of the `log_answer` callback
        html.Div(id='debug'),
        *_stores,
    ],
    style={
        "margin": "auto",
        "text-align": "left",
        "max-width": "800px"
    }
)

# Callbacks

@app.callback(
    Output('answer', 'children'),
    Output('intermediate_steps', 'children'),
    Output('feedback_buttons', 'style'),
    Output('question', 'data'),
    Output('question_id', 'data'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('customer_id', 'value'),
    State('query', 'value'),
)
def answer_question(n_clicks, n_submit, customer_id, query):
    """
    Answer the question typed by the user and display the result.

    Triggered by a click on the send button or a submit of the query input.

    Returns, in the order of the callback outputs: the Markdown answer
    (with the elapsed time appended), the intermediate steps, the style of
    the feedback buttons, the question asked and an identifier for it.
    When the query is empty or no LLM is configured, the feedback buttons
    keep their default (hidden) style and the identifier is 0.
    """
    # Function-scope import: `logging` is not imported at the top of this file
    import logging

    # `not query` also covers a None value, for which `len(query)` would fail
    if not query:
        return "", "", STYLE_FEEDBACK, query, 0
    if not llm_provided:
        return "", ERROR_MESSAGE_MISSING_KEY, STYLE_FEEDBACK, query, 0

    start = time.time()
    # -1 is the "All customers" entry: no specific customer is passed on.
    # The comparison goes through `str` so that both "-1" and -1 are recognized
    # (the layout sets the initial selector value as an int).
    if customer_id is None or str(customer_id) == "-1":
        selected_customer = None
    else:
        selected_customer = int(customer_id)
    answer, _, actions_details = process_request(
        query,
        customer_id=selected_customer
    )
    logging.info(actions_details)
    # Backslash-escape square brackets in the reply before rendering it as
    # Markdown (explicit "\\" instead of the invalid escape sequences "\[" / "\]")
    answer = answer.replace("[", "\\[").replace("]", "\\]")
    resources = f"\n\n{(time.time()-start):.1f} seconds"
    # Reveal the feedback buttons now that there is an answer to rate
    style = {**STYLE_FEEDBACK, "display": "flex"}
    return f"** {answer} **{resources}", actions_details, style, query, hash(str(start) + query)

@app.callback(
    Output('intermediate_steps', 'style'),
    Input('display_mode', 'value'),
    prevent_initial_call=True
)
def display_intermediate_steps(mode):
    """
    Show the intermediate steps when the display mode is "1", hide them otherwise
    """
    visibility = "block" if mode == "1" else "none"
    # Copy of the base style with the "display" entry added last
    return {**STYLE_STEPS, "display": visibility}

@app.callback(
    Output('debug', 'children'),
    Input('answer', 'children'),
    Input('link_ok', 'children'),
    Input('link_nok', 'children'),
    Input('intermediate_steps', 'children'),
    # Read the question that produced the displayed answer (stored by
    # `answer_question`) rather than the current content of the text input,
    # which the user may have edited since the answer was generated.
    State('question', 'data'),
    State('question_id', 'data'),
    State('feedback', 'data'),
)
def log_answer(answer, ok, nok, intermediate_steps, query, question_id, feedback):
    """
    Log the question and the answer in the log folder.

    Runs when the answer, the intermediate steps or the feedback icons change.
    A JSON file named after the question id is written when all answers are
    logged or when the feedback is 1 or -1; otherwise an existing file for
    this question id is deleted.

    `ok` and `nok` only serve as triggers. Always returns an empty string
    (content of the 'debug' div).
    """
    # Nothing to log without an LLM or without an answer (None or empty)
    if not (llm_provided and answer):
        return ""
    path = f"/{question_id}.json"
    if LOG_ALL_ANSWERS or feedback in (1, -1):
        # Serialize before opening the writer so a failure here writes nothing
        record = json.dumps({
            "question": query,
            "answer": f"Reply: {answer}\n\nIntermediate steps: {intermediate_steps}",
            # 2 is the initial "no feedback" value of the store; logged as 0
            "feedback": 0 if feedback == 2 else feedback,
            "timestamp": str(datetime.now()),
            "version": VERSION,
            "webapp": WEBAPP_NAME
        })
        with FOLDER.get_writer(path) as w:
            w.write(record.encode("utf-8"))
    elif path in FOLDER.list_paths_in_partition():
        # Feedback withdrawn: remove the previously logged record
        FOLDER.delete_path(path)
    return ""

@app.callback(
    Output('link_ok', 'children'),
    Output('link_nok', 'children'),
    Input('feedback', 'data')
)
def update_icons(value):
    """
    Pick the filled or outlined version of each feedback icon from the feedback value
    """
    # 1 means "liked", -1 means "disliked"; any other value (None included)
    # leaves both icons outlined
    liked = value == 1
    disliked = value == -1
    smile = ok_icon_fill if liked else ok_icon
    frown = nok_icon_fill if disliked else nok_icon
    return smile, frown

@app.callback(
    Output('feedback', 'data'),
    Input('link_ok', 'n_clicks'),
    Input('link_nok', 'n_clicks'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('feedback', 'data'),
    State('query', 'value'),
    State('question', 'data'),
)
def provide_feedback(ts_ok, ts_nok, click, submit, value, question, previous_question):
    """
    Compute the new feedback value after a click on a feedback icon or a new submission
    """
    # Feedback value associated with each feedback link
    reactions = {"link_ok": 1, "link_nok": -1}
    source = dash.ctx.triggered_id
    if source in reactions:
        chosen = reactions[source]
        # Clicking the already selected reaction toggles it off (0)
        return 0 if value == chosen else chosen
    # Not a feedback click: keep the feedback only if the question is unchanged,
    # otherwise fall back to 2
    if question != previous_question:
        return 2
    return value