import dataiku
import random
import logging
import json

import dash
from dash import html, dcc, ctx
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc

# When True, the "debug" Markdown area shows the current rankings.
DEBUG = False

# Input dataset. The columns read in this file are: question_id, question,
# reference_answer, approach, generated_answer and context_with_metadata.
df = dataiku.Dataset("answers_stacked").get_dataframe()

# Output dataset receiving one row per recorded comparison; appendMode is
# switched on in its spec before any writer is opened.
output_dataset = dataiku.Dataset("preferences")
output_dataset.spec_item["appendMode"] = True

# Distinct approaches in a stable (sorted) order.
approaches = sorted(set(df["approach"]))

# question_id -> question text / reference answer. When a question_id occurs
# on several rows, the value of the last such row is kept.
questions = {}
reference_answers = {}
for i in df.index:
    questions[df.loc[i, "question_id"]] = df.loc[i, "question"]
    reference_answers[df.loc[i, "question_id"]] = df.loc[i, "reference_answer"]

# question_id -> {approach: answer text}; filled in further down the file.
generated_answers = {q: {} for q in questions}

def get_sources(i):
    """
    Build the Markdown "Sources" section for row ``i`` of ``df``.

    The ``context_with_metadata`` cell of that row is parsed as JSON and
    iterated; every item must provide a ``content`` and a ``url`` entry.
    Each item is rendered as its stripped content followed by a Markdown
    link, and consecutive items are separated by a horizontal rule.
    With no items, only the section header is returned.
    """
    header = "\n\n**Sources**"
    entries = [
        f"{context['content'].strip()} [link]({context['url']})"
        for context in json.loads(df.at[i, "context_with_metadata"])
    ]
    if not entries:
        return header
    # A blank line after the header, then the entries split by "---" rules.
    return header + "\n\n" + "\n\n---\n\n".join(entries)

# Store, for every (question, approach) pair, the generated answer followed
# by its rendered "Sources" section.
for i in df.index:
    generated_answers[df.loc[i, "question_id"]][df.at[i, "approach"]] = f"{df.at[i, 'generated_answer']}" + get_sources(i)

# question_id -> list of approaches ranked so far (seeded further below).
preferences = {k: [] for k in questions}

# comparisons[question_id][current_approach][new_approach] -> recorded result.
comparisons = {k: {a: {} for a in approaches} for k in set(df["question_id"])}
try:
    comparisons_df = dataiku.Dataset("preferences").get_dataframe()
    for i in comparisons_df.index:
        question_id, current_approach, new_approach, result = [
            comparisons_df.at[i, c]
            for c in ["question_id", "current_approach", "new_approach", "result"]
        ]
        comparisons[question_id][current_approach][new_approach] = result
except Exception:
    # Loading earlier results stays best-effort: on any failure the app keeps
    # whatever was loaded before the error and carries on. Unlike a bare
    # ``except``, this does not swallow KeyboardInterrupt/SystemExit, and the
    # failure is logged instead of being silently discarded.
    logging.warning(
        "Could not (fully) load previously saved comparisons from 'preferences'",
        exc_info=True,
    )

# Seed every ranking with one approach. The shuffle is seeded with the
# question id, so the same order is obtained every time it is recomputed.
for question_id in questions:
    random.seed(question_id)
    order = list(range(len(approaches)))
    random.shuffle(order)
    preferences[question_id] = [approaches[order[0]]]

def print_current_order():
    """
    Return a Markdown dump of the current ranking of every question.

    For each question, the question text is printed in bold, followed by
    the answers of the approaches ranked so far (each suffixed with its
    approach name) joined by " < ".
    """
    sections = []
    for question_id, question_text in questions.items():
        ranked = " < ".join(
            f"{generated_answers[question_id][approach]} ({approach})"
            for approach in preferences[question_id]
        )
        sections.append(f"\n\n**{question_text}**\n\n{ranked}\n\n")
    return "".join(sections)
    
def get_next_comparison(start=False):
    """
    Find the next pair of approaches for which no comparison is recorded.

    Questions are visited in the iteration order of ``questions``; those
    whose ranking in ``preferences`` already holds every approach are
    skipped. For each remaining question, the not-yet-ranked approaches are
    tried in a shuffled order and checked against the approaches of the
    current ranking, from position 0 upwards.

    Args:
        start: when True, comparisons already present in ``comparisons``
            are replayed through ``process_comparison(..., save=False)``,
            which may insert approaches into ``preferences``. When False,
            recorded comparisons are simply skipped.

    Returns:
        A tuple ``(question_id, current_approach, j, new_approach)`` where
        ``j`` is the position of ``current_approach`` in the ranking, or
        ``(-1, "", 0, "")`` when no unrecorded pair was found.

    Side effects:
        Re-seeds the global ``random`` generator with each visited
        question id.
    """
    for question_id in questions:
        # Every approach is already ranked for this question.
        if len(preferences[question_id]) == len(approaches):
            continue
        # Same seeding as used when the rankings were initialised, so the
        # shuffled order is reproduced; order[0] is the approach that seeded
        # the ranking, hence the loop below starts at order[1].
        random.seed(question_id)
        order = list(range(len(approaches)))
        random.shuffle(order)
        for i in order[1:]:
            new_approach = approaches[i]
            if new_approach in preferences[question_id]:
                continue
            # Walk the ranking; its length is read once, and the loop is left
            # (break) as soon as a replayed result has placed new_approach.
            for j in range(len(preferences[question_id])):
                current_approach = preferences[question_id][j]
                if current_approach in comparisons[question_id]:
                    if new_approach in comparisons[question_id][current_approach]:
                        result = comparisons[question_id][current_approach][new_approach]
                        if start:
                            # Replay without writing to the output dataset.
                            if process_comparison(question_id, current_approach, j, new_approach, result, save=False):
                                break
                    else:
                        # No recorded result for this pair: ask for it.
                        return (question_id, current_approach, j, new_approach)
    # Sentinel: nothing left to compare.
    return (-1, "", 0, "")

def process_comparison(question_id, current_approach, j, new_approach, result, save=True):
    """
    Apply the outcome of one comparison to the ranking of a question.

    Args:
        question_id: key of the question in ``preferences``/``comparisons``.
        current_approach: already-ranked approach that was compared.
        j: position of ``current_approach`` in the ranking.
        new_approach: approach being inserted into the ranking.
        result: outcome of the comparison; ``1`` moves ``new_approach``
            past position ``j``, any other value inserts it at ``j``.
        save: when True, the comparison is also written to
            ``output_dataset`` and recorded in ``comparisons``.

    Returns:
        True if ``new_approach`` was added to the ranking, False otherwise.
    """
    if save:
        record = {
            "question_id": question_id,
            "current_approach": current_approach,
            "new_approach": new_approach,
            "result": result
        }
        with output_dataset.get_writer() as writer:
            writer.write_row_dict(record)
        comparisons[question_id][current_approach][new_approach] = result

    ranking = preferences[question_id]
    if new_approach in ranking:
        # Already placed: nothing to change.
        return False

    if result == 1:
        # The new approach goes after position j; it can only be placed now
        # if j is the end of the ranking, otherwise more comparisons are needed.
        if j == len(ranking) - 1:
            ranking.append(new_approach)
            return True
        return False

    # Any other result: insert right before position j.
    preferences[question_id] = ranking[:j] + [new_approach] + ranking[j:]
    return True

# Replay the comparisons loaded at start-up and fetch the first pair that
# still has to be judged.
comparison = get_next_comparison(start=True)

# Display

# Inline styles shared by the layout and the callbacks.
STYLE_CARD = {"width": "49%"}
STYLE_QUESTION = {"margin-top": "20px"}
STYLE_ANSWER = {"text-align": "left", "margin-top": "20px"}
STYLE_ANSWERS = {"display": "flex", "justify-content": "space-evenly"}

# Left card: selection button (wrapped in a span) above the answer text.
_left_card = dbc.Card(
    dbc.CardBody(
        [
            html.Span(dbc.Button("Select this answer", id="left_button")),
            dcc.Markdown(id="left_answer", style=STYLE_ANSWER),
        ],
    ),
    style=STYLE_CARD,
    class_name="text-center",
)

# Right card: selection button above the answer text.
_right_card = dbc.Card(
    dbc.CardBody(
        [
            dbc.Button("Select this answer", id="right_button"),
            dcc.Markdown(id="right_answer", style=STYLE_ANSWER),
        ]
    ),
    style=STYLE_CARD,
    class_name="text-center",
)

# Container holding the two answer cards side by side.
answer_cards = html.Div(
    [_left_card, _right_card],
    style=STYLE_ANSWERS,
    id="answers",
)

# NOTE(review): `app` is not defined in the visible part of this file;
# presumably it is provided by the hosting webapp environment - confirm.
app.title = "Answer comparison"
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]

# Centered page of bounded width.
_PAGE_STYLE = {
    "margin": "auto",
    "text-align": "left",
    "max-width": "1200px",
}

# Page content, top to bottom: question, the two answers, debug output,
# and a store for the comparison currently displayed.
# NOTE(review): the store is initialised through `data`, whereas the callbacks
# of this file address the `value` property of "current_comparison" - confirm
# which property is intended.
_page_children = [
    html.Div(id="question", style=STYLE_QUESTION),
    answer_cards,
    dcc.Markdown(id="debug", style=STYLE_QUESTION),
    dcc.Store(data=comparison, id="current_comparison"),
]

app.layout = html.Div(_page_children, style=_PAGE_STYLE)

# Callbacks

def emphasize(s):
    """
    Italicise a (possibly multi-line) string for Markdown rendering.

    Every non-empty line is wrapped in asterisks; empty lines are kept
    empty so that line breaks are preserved.
    """
    wrapped_lines = []
    for line in s.split("\n"):
        wrapped_lines.append(f"*{line}*" if line else "")
    return "\n".join(wrapped_lines)

@app.callback(
    Output("left_answer", "children"),
    Output("right_answer", "children"),
    Output("question", "children"),
    Output("answers", "style"),
    Output("debug", "children"),
    Input("current_comparison", "value"),
)
def update_display(comparison):
    """
    Render the page for the comparison held in "current_comparison".

    ``comparison`` is unpacked as (question_id, current_approach, j,
    new_approach). The new approach's answer is shown on the left and the
    current approach's answer on the right. When ``question_id`` is -1,
    the answer cards are hidden and a closing message is displayed.

    Returns the values for, in order: left answer, right answer, question
    area, style of the answers container, and debug area (empty unless
    DEBUG is set).
    """
    question_id, current_approach, j, new_approach = comparison

    # Sentinel produced when nothing is left to compare.
    if question_id == -1:
        closing_message = dcc.Markdown("All answers have been assessed. Thanks!")
        debug_text = print_current_order() if DEBUG else ""
        return "", "", closing_message, {"display": "none"}, debug_text

    left_text = generated_answers[question_id][new_approach]
    right_text = generated_answers[question_id][current_approach]

    instructions = "*Given the question and its reference answer below, select the best proposed answer*\n\n---\n\n"
    question_text = f"**{questions[question_id]}**"
    reference_text = emphasize(f"Reference answer: {reference_answers[question_id]}")
    question_block = [dcc.Markdown(instructions + question_text + "\n\n" + reference_text)]

    if DEBUG:
        debug_text = print_current_order() + "\n\n" + str(comparisons)
    else:
        debug_text = ""

    return left_text, right_text, question_block, STYLE_ANSWERS, debug_text

@app.callback(
    Output("current_comparison", "value"),
    Input("left_button", "n_clicks"),
    Input("right_button", "n_clicks"),
    State("current_comparison", "value"),
)
def update_comparison(left_n_clicks, right_n_clicks, comparison):
    """
    Record the user's choice and move on to the next comparison.

    When no comparison is stored yet (``comparison`` is None), nothing is
    recorded and the next pair is looked up with ``start=True``. Otherwise
    the outcome is 1 if the left button triggered the callback and 0 in
    every other case; it is passed to ``process_comparison`` together with
    the stored comparison, and the next pair is returned.
    """
    if comparison is None:
        return get_next_comparison(start=True)

    # True/False converted to the 1/0 outcome expected by process_comparison.
    outcome = int(ctx.triggered_id == "left_button")
    process_comparison(*comparison, outcome)
    return get_next_comparison()