import dataiku
import numpy as np
import io
import torch
from datasets import Dataset
import transformers
import shap

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
import plotly.graph_objs as go

# RGBA colors (red, green, blue, alpha) used to highlight tokens with
# positive and negative Shapley values respectively.
COLOR_POS = (218, 41, 46, 0.8)
COLOR_NEG = (52, 89, 230, 0.8)

# Load the "test" dataset; the sorted distinct values of its label_text
# column are the class names shown in the chart and in the explanations.
df = dataiku.Dataset("test").get_dataframe()
# NOTE(review): test_ds is not referenced anywhere else in the visible code.
test_ds = Dataset.from_pandas(df)
classes = sorted(list(set(df.label_text)))

### Replace the code block below to replace the SetFit model with another
# The tokenizer is handed to the SHAP explainer further down.
tokenizer = transformers.AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
folder = dataiku.Folder("Ey4Ge7PK")
# NOTE(review): torch.load deserializes with pickle, which can execute
# arbitrary code; only load "model.pt" from a trusted managed folder.
with folder.get_download_stream("model.pt") as stream:
    model = torch.load(io.BytesIO(stream.read()), map_location=torch.device('cpu'))
# Overrides a private attribute of the model body with the CPU device
# (presumably so that inference runs on CPU after loading - TODO confirm).
model.model_body._target_device = torch.device('cpu')

def f(x):
    """
    Compute the class probabilities predicted by the model for the text
    input x and return them as a NumPy array.
    """
    probabilities = model.predict_proba(x)
    return probabilities.numpy()
### End of the model-specific code block

# SHAP explainer wrapping the prediction function f; the tokenizer is passed
# as the second argument (the masker in shap.Explainer's signature - confirm
# against the installed shap version) and outputs are named after the classes.
explainer = shap.Explainer(f, tokenizer, output_names=classes)

def float_to_color(x, thresh=0.25):
    """
    Convert a value (here a Shapley value) into an RGBA color tuple.

    Positive values map to COLOR_POS and negative values to COLOR_NEG. The
    alpha channel grows linearly with the magnitude of x and saturates at
    the base color's alpha once the magnitude reaches thresh.

    Args:
        x: The value to convert (builtin or NumPy scalar).
        thresh: Magnitude at which the color reaches full intensity.

    Returns:
        A (red, green, blue, alpha) tuple made only of builtin numbers.
    """
    # Callers interpolate the tuple into an "rgba(...)" CSS string. A NumPy
    # scalar in the tuple would be rendered as "np.float64(...)" with
    # NumPy >= 2 and break the CSS, so work with a builtin float.
    x = float(x)
    if x != x:
        # NaN compares false with everything; render it as fully transparent
        # rather than failing on an unassigned color.
        x = 0.0

    base = COLOR_POS if x >= 0 else COLOR_NEG
    magnitude = abs(x)
    if magnitude >= thresh:
        return base
    return (base[0], base[1], base[2], magnitude / thresh * base[3])

def get_explanations(tokens, shap_values):
    """
    Create the table with the Shapley values.

    Args:
        tokens: The tokens of the explained text, in order.
        shap_values: SHAP explanation for a single text. Its values are
            indexed here as [sample, token, class] and only sample 0 is used.

    Returns:
        A list with a heading (including the color legend) followed by a
        table holding one row per class: the shortened class name and the
        text colored with that class's Shapley values. Rows are ordered by
        decreasing score.
    """
    # Score of each class: sum of the token contributions plus the base value.
    scores = np.sum(shap_values.values, axis=1) + shap_values.base_values
    ranking = np.argsort(scores[0, :])[::-1]

    heading = html.H5([
        html.Span("Explanation ("),
        html.Span("positive", style={"background-color": f"rgba{COLOR_POS}", "color": "white"}),
        html.Span(" and "),
        html.Span("negative", style={"background-color": f"rgba{COLOR_NEG}", "color": "white"}),
        html.Span(" Shapley values)")
    ])
    rows = [
        html.Tr([
            html.Td(shorten(classes[i]), style={"padding": "0 5px"}),
            html.Td(get_sentence(tokens, list(shap_values.values[0, :, i])))
        ])
        for i in ranking
    ]
    return [heading, dbc.Table(rows)]
        
def get_sentence(tokens, values):
    """
    Build the colored rendering of a text from its tokens.

    Each token becomes a span whose background color is derived from the
    matching entry of values and whose tooltip shows that value with three
    decimals. A single trailing space of a token is moved out of the span
    and emitted as plain text after it.
    """
    html_list = []
    for i, token in enumerate(tokens):
        has_trailing_space = token.endswith(" ")
        span = token[:-1] if has_trailing_space else token
        value = values[i]
        html_list.append(html.Span(
            span,
            title=f"{span}: {value:.3f}",
            style={
                "background-color": f"rgba{float_to_color(value)}",
                "padding": "0"
            }
        ))
        if has_trailing_space:
            html_list.append(" ")
    return html_list

def shorten(s, max_width=40):
    """
    Truncate a label to its first max_width characters followed by "..."
    when it is longer than max_width; return it unchanged otherwise.
    """
    fits = len(s) <= max_width
    return s if fits else s[:max_width] + "..."

def create_fig(x, y):
    """
    Build a horizontal bar chart of the values x labelled by y.

    Bars are sorted by increasing value, annotated with the value formatted
    with three decimals, and labelled with the shortened entries of y.
    Hover information is disabled and the x axis is hidden.
    """
    ranking = np.argsort(x)
    x, y = np.array(x)[ranking], np.array(y)[ranking]

    bar = go.Bar(
        y=[shorten(s).strip() + " " for s in y],
        x=x,
        text=[f"{v:.3f}" for v in x],
        orientation='h',
        marker_color='#3459e6',
        hovertemplate="",
        hoverinfo='none'
    )

    x_axis = dict(visible=False, tickmode='linear', tick0=0, dtick=1)
    y_axis = {
        'tickfont': {
            'size': 16,
            'color': 'rgb(73, 80, 87)',
            'family': 'Inter,-apple-system,BlinkMacSystemFont,"Segoe UI"'
        }
    }

    fig = go.Figure(data=[bar])
    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        margin=dict(t=0, b=0, r=0, l=0),
        # The height scales with the number of bars.
        height=20*len(x) + 60,
        xaxis=x_axis,
        yaxis=y_axis
    )
    return fig

# Layout

# Style shared by the query bar and the explanation container, and applied
# to the graph container by the scoring callback when there is a result.
STYLE_DIV = {"margin-top": "20px"}

# Query bar: a text input followed by a "send" button with a Bootstrap icon.
send_icon = html.Span(html.I(className="bi bi-send"))
question_bar = dbc.InputGroup(
    [
        dbc.Input(id='query', value='', type='text', minLength=0),
        dbc.Button(send_icon, id='send-btn', title='Score text')
    ],
    style=STYLE_DIV
)

# NOTE(review): `app` is not defined in this file; presumably it is the Dash
# application object provided by the hosting webapp runtime - confirm.
app.title = "Text classification interactive scoring"
app.config.external_stylesheets = [
    dbc.themes.ZEPHYR,
    dbc.icons.BOOTSTRAP
]
app.layout = html.Div(
    [
        html.H4(
            "Interactive scoring",
            style={"margin-top": "20px", "text-align": "center"}
        ),
        question_bar,
        # Scores chart: starts hidden and is filled with a placeholder figure.
        html.Div(
            [
                html.H5("Scores"),
                dcc.Graph(
                id="predictions",
                figure=create_fig([1], ["test"]),
                config={'displayModeBar': False},
                ),                
            ],
            id="graph-container",
            style={"display": "none"}
        ),
        # Explanation area, wrapped in a spinner component.
        dbc.Spinner(html.Div(id="explanation", style=STYLE_DIV))
    ],
    style={
        "margin": "auto",
        "text-align": "left",
        "max-width": "800px"
    }
)

# Callbacks

@app.callback(
    Output("predictions", "figure"),
    Output("graph-container", "style"),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('query', 'value'),
)
def score_text(n_clicks, n_submit, query):
    """
    Classify the input text.

    Triggered by a click on the send button or a submit of the query input.

    Args:
        n_clicks: Click count of the send button (only used as a trigger).
        n_submit: Submit count of the query input (only used as a trigger).
        query: Current content of the query input.

    Returns:
        A tuple (figure, style): the bar chart of the class probabilities
        with the style that shows its container, or an empty figure with a
        style hiding the container when there is no text to score.
    """
    # Also covers a missing (None) value, on which len() would raise.
    if not query:
        return {}, {"display": "none"}

    # Go through f, the model-specific prediction wrapper, so that swapping
    # the model only requires changing the model-specific code block.
    scores = f([query])[0, :]
    return create_fig(scores, classes), STYLE_DIV
    

@app.callback(
    Output('explanation', 'children'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('query', 'value'),
)
def explain(n_clicks, n_submit, query):
    """
    Generate the table with the Shapley values.

    Triggered by a click on the send button or a submit of the query input.

    Args:
        n_clicks: Click count of the send button (only used as a trigger).
        n_submit: Submit count of the query input (only used as a trigger).
        query: Current content of the query input.

    Returns:
        The components built by get_explanations for the query, or an empty
        string when there is no text to explain.
    """
    # Also covers a missing (None) value, on which len() would raise.
    if not query:
        return ""

    shap_values = explainer([query])
    # Tokens of the first (and only) explained text.
    tokens = list(shap_values.data[0])

    return get_explanations(tokens, shap_values)
