import dataiku
import pandas as pd
import numpy as np

from dash import Dash, dash_table, callback, dcc
from dash.dependencies import Input, Output
import dash_html_components as html
import dash_bootstrap_components as dbc
from dash.dash_table.Format import Format, Scheme, Trim

# --- Data preparation --------------------------------------------------------
# `df`         : full evaluation records, read by the detail callback.
# `df2`        : subset of `df` shown in the table, with display-friendly names.
# `tooltip_df` : same shape/column names as `df2`, holding markdown tooltips.

df = dataiku.Dataset("answers_evaluated").get_dataframe()

# Percentile boundaries of the BERT score; the table layout uses them as
# thresholds to colour the BERT score cells.
q33_bert_score, q67_bert_score = np.percentile(
    df["BERT score/score"], [33.33, 66.66]
)

# Columns never needed by the app.
df.drop(
    columns=["context_with_metadata"] + [c for c in df.columns if "_flag" in c],
    inplace=True,
)

# Long text / bookkeeping columns stay in `df` (for tooltips and the detail
# view) but are not displayed as table columns.
df2 = df.drop(
    columns=[
        "context",
        "reference_answer",
        "num_tokens_context",
        "num_tokens_answer",
    ]
    + [c for c in df.columns if "/justification" in c],
)


def _as_text(series):
    """Return `series` as strings, rendering missing values as empty text."""
    return series.fillna("").astype(str)


def _display_name(column):
    """Turn a raw dataset column name into the header shown in the table."""
    return (
        column.capitalize()
        .replace("_", " ")
        .replace("/v1", "")
        .replace("/score", "")
    )


# Build the tooltips column by column (vectorised) rather than cell by cell.
# Every column defaults to an empty tooltip.
tooltip_df = pd.DataFrame("", index=df.index, columns=df2.columns)
for c in df2.columns:
    if "/v1" in c:
        # The justification lives in the sibling column whose name has
        # "score" replaced by "justification".
        justification_column = c.replace("score", "justification")
        tooltip_df[c] = "**Justification**: " + df[justification_column].astype(str)
    elif c == "question":
        tooltip_df[c] = (
            "**Question**: "
            + _as_text(df["question"])
            + "\n\n**Reference answer**: "
            + _as_text(df["reference_answer"])
        )
    elif c == "generated_answer":
        # The context is deliberately not part of this tooltip; the
        # `display_answer` callback shows it in the detail card instead.
        tooltip_df[c] = "**Answer**: " + _as_text(df["generated_answer"])

# Both frames must be renamed identically: tooltips are matched to table
# cells by column name.
df2 = df2.rename(columns=_display_name)
tooltip_df = tooltip_df.rename(columns=_display_name)


# Layout
#
# NOTE(review): `app` is not defined in this file; it is presumably injected
# by the hosting webapp backend before this script runs -- confirm.

# Free-text columns (left-aligned, wide); every other column is numeric.
_TEXT_COLUMNS = ("Question", "Generated answer")
# Score columns coloured by comparing their value against 3.
_GRADED_COLUMNS = (
    "Answer correctness",
    "Answer relevance",
    "Relevance",
    "Faithfulness",
)
# Score column coloured against the percentile thresholds computed above.
_BERT_COLUMN = "Bert score"

_COLOR_GOOD = "rgb(213, 240, 230)"
_COLOR_MEDIUM = "rgb(253, 242, 223)"
_COLOR_BAD = "rgb(248, 212, 213)"
_COLOR_NEGATIVE = "#fdfdfd"
_COLOR_ACTIVE = "rgb(214, 222, 250)"


def _highlight(column, condition, color):
    """Return a conditional style colouring cells of `column`.

    Args:
        column: Displayed column name (also its id).
        condition: Right-hand side of the filter query, e.g. ``"> 3"``.
        color: CSS background colour applied to matching cells.
    """
    return {
        "if": {
            "filter_query": f"{{{column}}} {condition}",
            "column_id": column,
        },
        "backgroundColor": color,
    }


def _column_spec(name):
    """Return the DataTable column definition for the `df2` column `name`."""
    # Inspect the column dtype rather than the first value, so that an empty
    # dataset does not raise.
    if pd.api.types.is_float_dtype(df2[name]):
        number_format = Format(precision=3, scheme=Scheme.fixed, trim=Trim.yes)
    else:
        number_format = Format()
    return {
        "name": name,
        "id": name,
        "type": "text" if name in _TEXT_COLUMNS else "numeric",
        "format": number_format,
    }


app.config.external_stylesheets = [
    dbc.themes.ZEPHYR,
    dbc.icons.BOOTSTRAP,
]
app.layout = html.Div(
    [
        dash_table.DataTable(
            id='tbl',
            data=df2.to_dict('records'),
            # One markdown tooltip per cell, taken from `tooltip_df`.
            tooltip_data=[
                {
                    column: {'value': str(value), 'type': 'markdown'}
                    for column, value in row.items()
                }
                for row in tooltip_df.to_dict('records')
            ],
            tooltip_delay=0,
            tooltip_duration=None,
            columns=[_column_spec(name) for name in df2.columns],
            filter_action="native",
            filter_options={"placeholder_text": "Filter column..."},
            sort_action="native",
            page_size=10,
            css=[{
                'selector': '.dash-table-tooltip',
                'rule': "font-size:14px"
            }],
            style_data={
                "lineHeight": "15px"
            },
            style_cell={
                "whiteSpace": "normal",
                "height": "auto",
                "padding": "5px",
                "font-size": "14px",
                "font-family": "sans-serif",
                "textAlign": "center",
            },
            # Column-level overrides (apply to header, filter and data cells).
            style_cell_conditional=[
                {
                    "if": {"column_id": "Question"},
                    "textAlign": "left",
                    "width": "25%"
                },
                {
                    "if": {'column_id': 'Generated answer'},
                    "textAlign": "left",
                    "width": "40%"
                },
            ],
            # Value/state dependent rules belong to `style_data_conditional`,
            # the property documented to accept `filter_query` and `state`.
            # Order matters: when several rules match a cell, the later one
            # wins.
            style_data_conditional=(
                [
                    {
                        "if": {'state': 'active'},
                        'backgroundColor': _COLOR_ACTIVE,
                        'border': "1px solid " + _COLOR_ACTIVE,
                    }
                ]
                + [_highlight(col, "> 3", _COLOR_GOOD) for col in _GRADED_COLUMNS]
                + [_highlight(col, "= 3", _COLOR_MEDIUM) for col in _GRADED_COLUMNS]
                + [_highlight(col, "< 3", _COLOR_BAD) for col in _GRADED_COLUMNS]
                + [
                    _highlight(_BERT_COLUMN, f">= {q67_bert_score}", _COLOR_GOOD),
                    _highlight(_BERT_COLUMN, f"< {q67_bert_score}", _COLOR_MEDIUM),
                    # Listed after the rule above so it overrides it for the
                    # lowest scores.
                    _highlight(_BERT_COLUMN, f"< {q33_bert_score}", _COLOR_BAD),
                ]
                # Negative scores get a neutral background, overriding the
                # "bad" colouring set by the rules above.
                + [
                    _highlight(col, "< 0", _COLOR_NEGATIVE)
                    for col in (_BERT_COLUMN,) + _GRADED_COLUMNS
                ]
            ),
            style_header={
                "fontWeight": "bold"
            },
        ),
        # Detail card filled in by the `display_answer` callback.
        html.Div(
            dbc.Card(
                dbc.CardBody(dcc.Markdown(id="answer")),
                style={"max-width": "800px", "margin": "20px auto"}
            ),
            id="container"
        )
    ],
    style={"max-width": "1400px", "margin": "20px auto"}
)

@callback(
    Output('answer', 'children'),
    Output('container', 'style'),
    Input('tbl', 'active_cell'),
    Input('tbl', 'page_current'),
    Input('tbl', 'page_size'),
    Input('tbl', 'derived_viewport_indices'),
)
def display_answer(active_cell, page_current, page_size, derived_viewport_indices=None):
    """Render the details of the record under the table's active cell.

    Args:
        active_cell: The table's `active_cell` property (a dict with a "row"
            key counted from the top of the current page), or None when no
            cell is selected.
        page_current: Zero-based index of the current page; may be None.
        page_size: Number of rows per page.
        derived_viewport_indices: For each row of the current page, its
            position in the original data after native sorting/filtering.
            When None, the position is computed from the page offset, which
            is only correct while the table is neither sorted nor filtered.

    Returns:
        A tuple ``(markdown_text, container_style)``. The container is hidden
        when there is no active cell or it does not map to a record.
    """
    hidden = ("", {"display": "none"})
    if not active_cell:
        return hidden

    row_on_page = active_cell["row"]
    if derived_viewport_indices is not None:
        # The table is sorted/filtered client-side, so the displayed order
        # can differ from `df`; map the visible row back to its record.
        if row_on_page >= len(derived_viewport_indices):
            # Stale active cell, e.g. a filter shrank the current page.
            return hidden
        idx = derived_viewport_indices[row_on_page]
    else:
        idx = row_on_page + page_size * (page_current or 0)

    if not 0 <= idx < len(df):
        return hidden

    record = df.iloc[idx]
    details = (
        f"**Question**: {record['question']}\n\n"
        f"**Reference answer**: {record['reference_answer']}\n\n"
        f"**Generated answer** ({int(record['num_tokens_answer'])} tokens): "
        f"{record['generated_answer']}\n\n"
        f"**Context** ({int(record['num_tokens_context'])} tokens): \n\n"
        f"{record['context']}\n\n"
    )
    return details, {"display": "block"}