import dataiku
from datetime import datetime
import json

import dash
from dash import dcc, html, ctx
from dash.dependencies import Input, Output, State, ALL
import dash_bootstrap_components as dbc

WEBAPP_NAME = "agent_chat_webapp"  # Name of the app (logged when a conversation is flagged)
VERSION = "1.0"  # Version of the app (logged when a conversation is flagged)
# NOTE(review): this hard-coded agent id is what the app actually queries; the
# "LLM_id" project variable below is only used as an "is configured" flag — confirm intended.
LLM_ID = "agent:NH0cSHSy:v1"
llm = dataiku.api_client().get_default_project().get_llm(LLM_ID)
log_folder = dataiku.Folder("zhYE6r2h")  # Folder to log conversations flagged by users

# A missing or empty "LLM_id" project variable must not crash the app at import
# time: the callbacks check this flag and display ERROR_MESSAGE_MISSING_KEY instead.
llm_provided = bool(dataiku.get_custom_variables().get("LLM_id"))

ERROR_MESSAGE_MISSING_KEY = """
LLM connection missing. You need to add it as a project variable. Cf. this project's wiki.

Please note that this web app is not live on Dataiku’s public project gallery but you can test it by downloading the project and specifying an LLM connection.
"""

# Layout

# Scrollable area holding the messages. "column-reverse" is presumably there so
# the view stays anchored on the most recent message as the history grows.
STYLE_CONVERSATION = {
    "overflow-y": "auto",
    "display": "flex",
    "height": "calc(90vh - 50px)",
    "flex-direction": "column-reverse",
    "width": "100%",
}

# Base look of a single message card; textbox() adds the left/right margins
# that align user and AI messages on opposite sides.
STYLE_MESSAGE = {
    "max-width": "80%",
    "width": "max-content",
    "padding": "5px 10px",
    "border-radius": 10,
    "margin-bottom": 10,
}

# Horizontal, centred row of buttons (conversation actions, Yes/No prompt).
STYLE_BUTTON_BAR = {
    "display": "flex",
    "justify-content": "center",
    "gap": "10px",
    "margin-top": "10px",
}

# Spacing above the collapsible "Actions taken" section of a message.
STYLE_ACCORDION = {"margin-top": "10px"}

def _bi_icon(name, padding_right=None):
    """
    Wrap the Bootstrap Icons glyph ``bi-<name>`` in a span.

    When ``padding_right`` is given it becomes the span's right padding, which
    separates the icon from a label placed after it.
    """
    glyph = html.I(className=f"bi bi-{name}")
    if padding_right is None:
        return html.Span(glyph)
    return html.Span(glyph, style={"paddingRight": padding_right})


reset_icon = _bi_icon("trash3", padding_right="5px")
flag_icon = _bi_icon("flag", padding_right="5px")

# Conversation-level actions displayed under the input field.
button_bar = html.Div(
    children=[
        dbc.Button(
            html.Span([reset_icon, "Reset conversation"]),
            id="reset-btn",
            title="Delete all previous messages",
        ),
        dbc.Button(
            html.Span([flag_icon, "Flag conversation"]),
            id="flag-btn",
            title="Flag the conversation, e.g. in case of inappropriate or erroneous replies",
        ),
    ],
    style=STYLE_BUTTON_BAR,
)

# Accept/reject controls for an action proposed by the agent. They use
# pattern-matching ids ({"type": ..., "index": ...}) because they are only part
# of the page while a confirmation prompt is rendered.
yes_no_buttons = html.Div(
    children=[
        dbc.Button(label, id={"type": kind, "index": kind}, title=tooltip)
        for label, kind, tooltip in (
            ("Yes", "yes-btn", "Accept the proposed action"),
            ("No", "no-btn", "Reject the proposed action"),
        )
    ],
    style=STYLE_BUTTON_BAR,
)

send_icon = _bi_icon("send")

# Free-text input with its send button, followed by the conversation actions.
question_bar = html.Div(
    children=[
        dbc.InputGroup([
            dbc.Input(id="query", value="", type="text", minLength=0),
            dbc.Button(send_icon, id="send-btn", title="Get an answer"),
        ]),
        button_bar,
    ]
)

# Scrollable container into which the rendered messages are written.
conversation = html.Div(html.Div(id="conversation"), style=STYLE_CONVERSATION)

# Spinner around an empty Markdown placeholder; presumably visible while a
# callback that outputs to "spinning_wheel" is running.
spinning_wheel = dbc.Spinner(
    dcc.Markdown(id="spinning_wheel", style={"height": "20px", "margin-bottom": "20px"}),
    color="primary",
)

app.title = "Chatbot"
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]

# Page skeleton: message history, spinner and input bar, followed by the
# client-side stores the callbacks read and write:
#   - "logged": path of the last flagged-conversation log ("" when none)
#   - "messages": the conversation as returned by the agent
#   - "next_state": agent state to send along with the next query
app.layout = html.Div(
    children=[
        conversation,
        spinning_wheel,
        question_bar,
        dcc.Store(id="logged"),
        dcc.Store(id="messages", storage_type="memory", data=[]),
        dcc.Store(id="next_state", storage_type="memory", data=None),
    ],
    style={"margin": "auto", "text-align": "left", "max-width": "800px"},
)

# Callbacks

def textbox(text, box="AI", accordion="", buttons=False):
    """
    Render one chat message as a Bootstrap card.

    Args:
        text: content of the message (a string or Dash components).
        box: "user" right-aligns the card with the primary colour; any other
            value left-aligns it with the light colour.
        accordion: Markdown shown in a collapsed "Actions taken" section under
            the message; nothing is added when it is empty.
        buttons: unused by this function.

    Returns:
        The ``dbc.Card`` holding the message.
    """
    from_user = box == "user"
    card_style = {
        **STYLE_MESSAGE,
        "margin-left": "auto" if from_user else 0,
        "margin-right": 0 if from_user else "auto",
    }
    card_color, card_inverse = ("primary", True) if from_user else ("light", False)

    body = [html.Div(text)]
    if len(accordion):
        actions = dbc.AccordionItem(dcc.Markdown(accordion), title="Actions taken")
        body.append(dbc.Accordion([actions], start_collapsed=True, style=STYLE_ACCORDION))

    return dbc.Card(
        html.Div(body),
        style=card_style,
        body=True,
        color=card_color,
        inverse=card_inverse,
    )

def get_function_call(dictionary):
    """
    Render the first tool call of a serialized AI message as ``name(k=v, ...)``.

    The call is read from ``dictionary["additional_kwargs"]["tool_calls"][0]``;
    its ``arguments`` field is a JSON object encoded as a string.
    """
    call = dictionary["additional_kwargs"]["tool_calls"][0]["function"]
    func_name = call["name"]
    kwargs = json.loads(call["arguments"])
    rendered_args = ", ".join(f"{key}={kwargs[key]}" for key in kwargs)
    return f"{func_name}({rendered_args})"

@app.callback(
    Output("conversation", "children"),
    Input("messages", "data"),
    prevent_initial_call=True
)
def update_display(messages):
    """
    Display the messages of the conversation.

    ``messages`` is the list held by the "messages" store: each item is a
    ``(construct, dictionary)`` pair where ``construct`` names the message type
    ("HumanMessage", "AIMessage", "ToolMessage", ...) and ``dictionary`` holds
    at least a ``content`` entry.

    Human messages and AI messages without tool calls get their own box. An AI
    message requesting a tool call is not shown; the call (and the result of the
    following ToolMessage) is attached as an "Actions taken" section to the next
    AI message. If the conversation ends on an AI message with empty content,
    the user is asked to accept or reject the pending action.
    """
    if not messages:
        return []
    result, function_call, function_call_result = [], None, None

    for construct, dictionary in messages:
        box = "user" if construct == "HumanMessage" else "AI"
        if construct == "AIMessage":
            # Tolerate a missing "additional_kwargs" and an empty/None tool-call
            # list, which would otherwise crash the whole display.
            additional_kwargs = dictionary.get("additional_kwargs") or {}
            if additional_kwargs.get("tool_calls"):
                function_call = get_function_call(dictionary)
                continue
            if function_call is not None:
                # f-string: a tool call with no recorded ToolMessage result must
                # not raise a str + None TypeError.
                actions = f"{function_call} = {function_call_result}"
                result.append(textbox(dictionary["content"], box=box, accordion=actions))
                function_call, function_call_result = None, None
                continue
        elif construct == "ToolMessage":
            function_call_result = dictionary["content"]
            continue
        result.append(textbox(dictionary["content"], box=box))

    # construct/dictionary still refer to the last message here (messages is non-empty).
    if construct == "AIMessage" and len(dictionary["content"]) == 0:
        text = [
            html.Div(html.B("Do you accept the following action?")),
            html.Div(function_call),
            yes_no_buttons
        ]
        result.append(textbox(text, box="AI"))
    return result

@app.callback(
    Output('messages', 'data'),
    Output('next_state', 'data'),
    Output('query', 'value'),
    Output('spinning_wheel', 'children'),
    Input('reset-btn', 'n_clicks'),
    Input('send-btn', 'n_clicks'),
    Input({'type': 'yes-btn', 'index': ALL}, 'n_clicks'),
    Input({'type': 'no-btn', 'index': ALL}, 'n_clicks'),
    Input('query', 'n_submit'),
    State('query', 'value'),
    State('messages', 'data'),
    State('next_state', 'data'),
    prevent_initial_call=True
)
def receive_query(reset, n_clicks, yes_click, no_click, n_submit, message, messages, state):
    """
    Receive the new input from the user and forward it to the agent.

    The agent receives a JSON object ``{"query": ..., "state": ...}`` (``state``
    only when one is stored) and answers with a JSON object holding
    ``messages`` and ``state``.

    Returns:
        (messages, next_state, query_value, spinner_text): the conversation,
        the agent state for the next call, and empty strings that clear the
        input field and the spinner placeholder.

    Raises:
        PreventUpdate: when the call was not caused by a real user action, or
            when an empty free-text query is submitted.
    """
    from dash.exceptions import PreventUpdate

    if not llm_provided:
        # None, like reset and the initial store, so no bogus state is stored.
        return [("HumanMessage", {"content": ERROR_MESSAGE_MISSING_KEY})], None, "", ""

    trigger = ctx.triggered_id
    # Dash also runs this callback when update_display (re)inserts the
    # pattern-matching Yes/No buttons; such calls carry no click (value None)
    # and must not send anything to the agent.
    if trigger is None or ctx.triggered[0]["value"] is None:
        raise PreventUpdate
    if trigger == "reset-btn":
        return [], None, "", ""

    # Pattern-matching ids come back as a dict-like object; plain ids as strings.
    if isinstance(trigger, dict) and trigger.get("type") in ("yes-btn", "no-btn"):
        message = "Yes" if trigger["type"] == "yes-btn" else "No"
    elif not message:
        # Pressing Enter bypasses the disabled send button: ignore empty queries.
        raise PreventUpdate

    inputs = {"query": message}
    if state is not None:
        inputs["state"] = state

    completion = llm.new_completion()
    completion.with_message(json.dumps(inputs))
    resp = completion.execute()
    output = json.loads(resp.text)

    return output["messages"], output["state"], "", ""

@app.callback(
    Output('logged', 'data'),
    Input('flag-btn', 'n_clicks'),
    Input('messages', 'data'),
)
def log_conversation(n_clicks, messages):
    """
    Log the current conversation when the user flags it.

    The conversation, a timestamp, the app name and its version are written as
    JSON to ``log_folder``. The file is named after a SHA-256 digest of the
    messages, so the name is stable across processes and restarts. The builtin
    ``hash()`` is randomized per process for strings.

    Returns:
        The path written in the folder, or "" when nothing was logged. Any
        change of the conversation therefore resets this value.
    """
    import hashlib

    if messages and ctx.triggered_id == "flag-btn":
        canonical = json.dumps(messages, sort_keys=True)
        digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()
        path = f"/{digest}.json"
        record = {
            "messages": messages,
            "timestamp": str(datetime.now()),
            "webapp": WEBAPP_NAME,
            "version": VERSION
        }
        with log_folder.get_writer(path) as w:
            w.write(json.dumps(record).encode("utf-8"))
        return path
    return ""

@app.callback(
    Output('reset-btn', 'disabled'),
    Output('flag-btn', 'disabled'),
    Output('send-btn', 'disabled'),
    Input('messages', 'data'),
    Input('logged', 'data'),
    Input('query', 'value')
)
def disable_buttons(messages, flagged, query):
    """
    Disable buttons when appropriate.

    - Reset: disabled while the conversation is empty.
    - Flag: disabled while the conversation is empty or already logged
      (``flagged`` differs from "").
    - Send: disabled while the query is empty, or while the last message is
      an AI message with empty content (a pending action awaiting Yes/No).

    A None ``messages`` or ``query`` is treated as empty instead of raising
    TypeError from ``len``.
    """
    messages = messages or []
    disable_reset = not messages
    disable_flag = not messages or flagged != ""
    awaiting_answer = (
        bool(messages)
        and messages[-1][0] == "AIMessage"
        and len(messages[-1][1]["content"]) == 0
    )
    disable_send = not query or awaiting_answer
    return disable_reset, disable_flag, disable_send