import io
import json
import re
import pickle
from collections import Counter

import dataiku

import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModel

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State, ALL
import dash_bootstrap_components as dbc
import plotly.graph_objs as go

from project_utils import normalize, load, compute_embeddings

# Feature switches and display settings
USE_FAISS = False  # True: query the FAISS index loaded below; False: exhaustive dot product with all embeddings
EXACT_SEARCH = True  # add a keyword-match score to the semantic score when the query text is not empty
N_RESULTS = 20  # maximum number of documents displayed for a search
MAX_BAR_CHART_ROWS = 5  # maximum number of bars in each sidebar chart

# Project variables: name of the document id column and of the column holding the text to display
id_label = dataiku.get_custom_variables()["id_label"]
text_label = dataiku.get_custom_variables()["text_label"]
# Documents, indexed by their id
df = dataiku.Dataset("data").get_dataframe().set_index(id_label)
# Managed folder holding "embeddings.npy" and "ids.npy" (folder ids are specific to this project)
embeddings = dataiku.Folder("P4SttKJS")

if EXACT_SEARCH:
    # Inverted index for the keyword search; it is looked up with lower-cased words and
    # presumably maps each word to the positions of the documents containing it — confirm
    # NOTE(review): pickle.load can execute arbitrary code; acceptable only if this folder
    # is written exclusively by the project's own recipes — confirm.
    exact_search_index = dataiku.Folder("evJsZfu6")
    with exact_search_index.get_download_stream("index.pickle") as stream:
        index_words = pickle.load(io.BytesIO(stream.read()))

if USE_FAISS:
    # Imported here so that faiss is only required when USE_FAISS is enabled
    import faiss
    faiss_index = dataiku.Folder("FpWcIx1Z")
    with faiss_index.get_download_stream("index.index") as stream:
        # Deserialize the index directly from the folder stream
        reader = faiss.PyCallbackIOReader(stream.read)
        index = faiss.read_index(reader)
        # Probe a fifth of the inverted lists at query time (assumes an IVF index, which
        # exposes nlist): more probed lists = better recall but slower search
        index.nprobe = index.nlist//5
else:
    # Presumably L2-normalized so that a dot product with a normalized query is a cosine similarity
    corpus_embeddings = normalize(load(embeddings, "embeddings.npy"))

# Document ids, presumably in the same order as the rows of the embeddings — confirm
corpus_ids = load(embeddings, "ids.npy")
# Document id -> row position in the embeddings and in the score arrays
inverse_corpus_ids = {corpus_ids[k]: k for k in range(len(corpus_ids))}

# Load model and tokenizer
# Hugging Face model used to embed the query text (passed to compute_embeddings in the search callback)
model_name = dataiku.get_custom_variables()["model_name"]
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
# Inference mode: the model is only used to compute query embeddings
model.eval()

# CSS styles

# Whole page: centered, bounded width
PAGE_STYLE = {
    "max-width": "1200px",
    "margin": "auto",
    "text-align": "center",
}

# Fixed column on the left holding the filters and charts
SIDEBAR_STYLE = {
    "position": "fixed",
    "top": 0,
    "left": 0,
    "bottom": 0,
    "width": "24rem",
    "padding": "2rem 1rem",
    "background-color": "#f8f9fa",
}

# Spacing between two consecutive filters
FILTER_COMPONENT_STYLE = {
    "margin-bottom": "10px",
}

# Main area, shifted right so it is not covered by the 24rem sidebar
CONTENT_STYLE = {
    "margin-left": "26rem",
    "padding": "2rem 1rem",
}

SEARCH_BAR_STYLE = {
    "margin": "auto",
    "text-align": "center",
}

RESULTS_STYLE = {
    "text-align": "justify",
    "margin-top": "10px",
}

#### Edit below to adjust filters
filter_components = []

# Year of each document, taken from the first four characters of its "date" value
df["year"] = [int(date[:4]) for date in df["date"]]
possible_years = set(df["year"])
min_year = min(possible_years)
max_year = max(possible_years)

# Year range slider, with one mark per year and the full range selected by default
filter_components.append(
    html.Div(
        dcc.RangeSlider(
            id='year_range',
            min=min_year,
            max=max_year,
            step=1,
            marks={year: str(year) for year in range(min_year, max_year + 1)},
            value=[min_year, max_year],
        ),
        style=FILTER_COMPONENT_STYLE,
    )
)

# Distinct category codes found in the ", "-separated "category" column, sorted
categories = sorted({
    code.strip()
    for row_categories in df["category"]
    for code in row_categories.split(", ")
})
# Category code -> human-readable category name
categories_df = dataiku.Dataset("categories").get_dataframe()
labels = dict(zip(categories_df['code'], categories_df['category']))

def shorten(string):
    """
    Create a shortened version of a label to be included in a dropdown box.

    Surrounding whitespace is removed. A label longer than 36 characters is cut
    to its first 33 characters followed by '...', so the result never exceeds
    36 characters.
    """
    stripped = string.strip()
    if len(stripped) > 36:
        return stripped[:33] + '...'
    return stripped

# Multi-select dropdown of the categories, displayed with their (shortened) names
filter_components.append(
    html.Div(
        dcc.Dropdown(
            id='categories_selected',
            options=[
                {'label': shorten(labels[code]), 'value': code}
                for code in categories
            ],
            value=[],
            placeholder="Select one or several categories",
            multi=True,
        ),
        style=FILTER_COMPONENT_STYLE,
    )
)

# Distinct organizations, sorted; `org == org` is False only for NaN, so missing values are dropped
organizations = sorted(org for org in set(df["organization"]) if org == org)

filter_components.append(
    html.Div(
        dcc.Dropdown(
            id='orgs_selected',
            options=[{'label': shorten(org), 'value': org} for org in organizations],
            value=[],
            placeholder="Select one or several organizations",
            multi=True,
        ),
        style=FILTER_COMPONENT_STYLE,
    )
)

# One callback input per filter, in the order the filters were appended
filter_inputs = [Input(component.children.id, 'value') for component in filter_components]

def filter(results, *args):
    """
    Keep only the results corresponding to the filters.

    :param results: ranked documents, with "year", "category" (", "-separated
        codes) and "organization" columns
    :param args: the filter values in the order of ``filter_components``: the
        [first, last] year range, the selected category codes and the selected
        organizations
    :return: the rows of ``results`` passing every active filter, in their
        original order (columns are kept even when no row remains)
    """
    year_range, categories_selected, orgs_selected = args
    # The year filter is skipped when the slider covers every year
    if (year_range[0] != min_year) or (year_range[1] != max_year):
        results = results[results["year"].between(year_range[0], year_range[1])]
    if len(categories_selected) > 0:
        selected = set(categories_selected)  # built once rather than once per row

        def has_selected_category(row_categories):
            # Split and strip the codes exactly as the dropdown options were built,
            # so that irregular spacing in the data does not prevent a match
            if not isinstance(row_categories, str):
                return False
            return not selected.isdisjoint(code.strip() for code in row_categories.split(", "))

        # astype(bool): apply() on an empty column does not return a boolean Series,
        # and pandas would then treat it as a column selection instead of a row mask
        mask = results["category"].apply(has_selected_category).astype(bool)
        results = results[mask]
    if len(orgs_selected) > 0:
        results = results[results["organization"].isin(orgs_selected)]
    return results

def no_results_displayed(*args):
    """
    In case of an empty query, determine when no results are displayed.

    :param args: year range, selected categories, selected organizations
    :return: True when neither a category nor an organization is selected
        (the year range is not taken into account)
    """
    _, categories_selected, orgs_selected = args
    return not (len(categories_selected) or len(orgs_selected))

def no_filter(*args):
    """
    Determine when no filter is applied
    (in which case the results are just the N_RESULTS documents with the best scores).

    :param args: year range, selected categories, selected organizations
    :return: True when no category and no organization is selected and the
        year range spans every year of the dataset
    """
    year_range, categories_selected, orgs_selected = args
    if len(categories_selected) > 0 or len(orgs_selected) > 0:
        return False
    return year_range[0] == min_year and year_range[1] == max_year

#### Edit below to adjust the display of charts

def create_fig(l):
    """
    Create a bar chart displaying the number of occurrences of the Top-X elements of the input list.

    :param l: iterable of hashable values (e.g. category names or years)
    :return: plotly figure with one horizontal bar, labelled with its count, for
        each of the MAX_BAR_CHART_ROWS most frequent values
    """
    top_values = Counter(l).most_common(MAX_BAR_CHART_ROWS)
    # Least frequent first, presumably so that the most frequent bar is drawn at the top
    top_values.reverse()
    bar_names = [shorten(str(value)) + ' ' for value, _ in top_values]
    bar_counts = [occurrences for _, occurrences in top_values]
    figure = go.Figure([
        go.Bar(y=bar_names, x=bar_counts, text=bar_counts, orientation='h', marker_color='#3459e6')
    ])
    figure.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',  # transparent backgrounds
        plot_bgcolor='rgba(0,0,0,0)',
        margin=dict(t=60, b=0, r=10, l=10),
        height=30 * len(bar_counts) + 60,  # 30px per bar plus the top margin
        xaxis=dict(visible=False, tickmode='linear', tick0=0, dtick=1),
    )
    return figure

num_figs = 2  # Must equal the number of figures returned by create_figs below

def create_figs(results):
    """
    Derive the figures from the results dataframe.

    :param results: documents to summarize, with "category" (", "-separated
        codes) and "year" columns
    :return: [bar chart of the category names, bar chart of the years]
    """
    category_names = [
        shorten(labels[code.strip()])
        for row_categories in results["category"]
        for code in row_categories.split(", ")
    ]
    years = list(results["year"])
    return [create_fig(category_names), create_fig(years)]

# Ids of the sidebar bar charts, one per figure produced by create_figs
bar_ids = [f"bar-{number}" for number in range(num_figs)]

# Empty charts (filled by the search callback), without the plotly mode bar
fig_components = [
    dcc.Graph(id=bar_id, figure={}, config={'displayModeBar': False})
    for bar_id in bar_ids
]

fig_outputs = [Output(bar_id, "figure") for bar_id in bar_ids]

#### Edit below to adjust the display of results
def format_results(results):
    """
    Display the content of each individual document matching the query.

    :param results: documents to display, indexed by document id, in display order
    :return: one ``html.P`` per document with its rank, title, organization badge
        (only when the organization is known), year, text, category names and a
        "Similar results" link whose pattern-matching id carries the document id
    """
    paragraphs = []
    for rank, (doc_id, row) in enumerate(results.iterrows(), start=1):
        organization = row['organization']
        # A missing organization may be a float NaN (never equal to the string 'nan')
        # or the string 'nan': skip the badge in both cases
        if pd.notna(organization) and organization != 'nan':
            badge = dbc.Badge(organization, color="primary", className="ml-1", style={'font-size': '16px'})
        else:
            badge = ''
        category_names = ' / '.join(shorten(labels[code.strip()]) for code in row['category'].split(','))
        paragraphs.append(html.P([
            html.Span(html.B(f"{rank}. {row['title']} ")),
            badge,
            html.Span(f" ({row['date'][:4]}): {row[text_label]}"),
            html.Br(),
            html.Span(html.I(category_names + ' ')),
            html.A('Similar results', id={'type': 'similar', 'index': str(doc_id)}, href='#'),
        ]))
    return paragraphs

#### No need to edit the code below

# Layout

# Download button (Bootstrap icon) and the dcc.Download target filled by the download_results callback
download_icon = html.Span(html.I(className="bi bi-cloud-arrow-down-fill"))
download_button = [html.Span([
    dbc.Button(download_icon, id="download-button"),
    dcc.Download(id="download-csv")
])]

# Left sidebar: the filters, then a container with the download button and the bar charts.
# The container starts hidden; the search callback sets its style to show it when there are results.
sidebar = html.Div(
    filter_components + [
        html.Div(download_button + fig_components, id='graph-container', style={"display": "none"})
    ],
    style=SIDEBAR_STYLE,
)

# Search box ('my-input', read and rewritten by the search callback) and search button ('submit-val')
search_icon = html.Span(html.I(className="bi bi-search"))
input_group = dbc.InputGroup(
    [
        dbc.Input(id='my-input', value='', type='text'),
        dbc.Button(search_icon, id='submit-val')
    ]
)

search_bar = html.Div(
    [
        html.H4('Semantic search'),
        input_group,
        html.Br()
    ],
    style=SEARCH_BAR_STYLE
)

# Container for the formatted results
results_section = html.Div(id='my-output', style=RESULTS_STYLE)
# Browser-side store of the last results as JSON, read by the CSV download callback
memory = dcc.Store(id='memory', storage_type='memory')

main_section = html.Div(
    [
        search_bar,
        results_section,
        memory
    ],
    style=CONTENT_STYLE
)

# NOTE: `app` is not defined in this file; presumably injected by the Dataiku webapp runtime — confirm.
# dbc.icons.BOOTSTRAP provides the "bi bi-..." icon classes used above.
app.config.external_stylesheets =[dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]
app.title = "Semantic search"
app.layout = html.Div(
    [
        sidebar,
        main_section
    ],
    style=PAGE_STYLE
)       

# Search functions

def compute_exact_search_score(query, inverted_index=None, n_docs=None):
    """
    Compute the exact search score.

    The score of a document is the fraction of the distinct (case-insensitive)
    words of the query that it contains according to the inverted index. It lies
    in [0, 1] and grows with the number of query words a document matches.

    :param query: text of the query
    :param inverted_index: mapping from a word to the positions of the documents
        containing it; queried with lower-cased words. Defaults to the
        module-level ``index_words``
    :param n_docs: number of documents; defaults to ``len(df)``
    :return: numpy array with one score per document position
    """
    if inverted_index is None:
        inverted_index = index_words
    if n_docs is None:
        n_docs = len(df)
    result = np.zeros(n_docs)
    # Repetitions and case variants of a word count as a single query word
    words = {word.lower() for word in re.findall(r"\b\w+\b", query)}
    if not words:
        return result
    for word in words:
        if word in inverted_index:
            # set(): a document counts once per word even if the index lists it several times
            for position in set(inverted_index[word]):
                result[position] += 1
    return result / len(words)

def compute_similarity_search_score(query_embeddings, n_results):
    """
    Compute the semantic search score of every document, rescaled to [0, 1].

    :param query_embeddings: embedding of the query; only its first row is used
    :param n_results: number of neighbours requested from the FAISS index
        (ignored by the exhaustive search)
    :return: numpy array with one score per document position. With FAISS,
        documents that are not among the retrieved neighbours score 0.
    """
    if not USE_FAISS:
        # Exhaustive search: dot product with every corpus embedding, mapped from [-1, 1] to [0, 1]
        similarities = (query_embeddings @ corpus_embeddings.T)[0, :]
        return (1 + similarities) / 2

    scores = -np.ones(len(df))
    distances, neighbours = index.search(query_embeddings, n_results)
    for distance, position in zip(distances[0], neighbours[0]):
        # FAISS reports -1 for neighbour slots it could not fill
        if position != -1:
            # NOTE(review): FAISS L2 indexes return *squared* distances; for unit vectors
            # cosine = 1 - d2/2, not 1 - d2**2/2. Verify against how the index was built.
            scores[position] = 1 - distance**2 / 2
    return (1 + scores) / 2

def compute_rank(query_embeddings, query, n_results=None):
    """
    Order the documents by similarity with the query vector.

    :param query_embeddings: embedding of the query
    :param query: query text; when EXACT_SEARCH is on and it is not empty, the
        exact search score is added to the semantic score
    :param n_results: number of positions to return; defaults to every document
    :return: positions of the best-scoring documents, best first
    """
    if n_results is None:
        n_results = len(df)
    scores = compute_similarity_search_score(query_embeddings, n_results)
    if EXACT_SEARCH and len(query) > 0:
        scores += compute_exact_search_score(query)
    descending = np.argsort(scores)[::-1]
    return descending[:n_results]

def get_vector(i):
    """
    Get the vector representation of the ith document in the corpus.

    :param i: row position of the document (see inverse_corpus_ids)
    :return: a 2-D array holding that single embedding as its only row
    """
    if not USE_FAISS:
        return corpus_embeddings[i:i + 1, :]
    return index.reconstruct_n(i, 1)

# Search callback
@app.callback(
    Output('my-output', 'children'),
    *fig_outputs,
    Output("graph-container", "style"),
    Output('my-input', 'value'),
    Output('memory', 'data'),
    Input('submit-val', 'n_clicks'),
    Input('my-input', 'n_submit'),
    *filter_inputs,
    Input({'type': 'similar', 'index': ALL}, 'n_clicks_timestamp'),
    State('my-input', 'value'),
)
def search(n_clicks, n_submit, *args):
    """
    Compute the search results.

    Runs when the search button is clicked, Enter is pressed in the search box,
    a filter changes, or the "Similar results" links change or are clicked.

    :param n_clicks: clicks on the search button (only used as a trigger)
    :param n_submit: Enter presses in the search box (only used as a trigger)
    :param args: the filter values in the order of ``filter_inputs``, then the
        ``n_clicks_timestamp`` list of every "Similar results" link, then the
        text of the search box
    :return: the results children, one figure per sidebar chart, the style of the
        chart container, the text to put back in the search box, and the results
        serialized to JSON for the CSV download
    """
    filter_values = args[:-2]
    similar_timestamps = args[-2]
    # Treat a missing search box value like an empty query instead of failing on len(None)
    query_text = args[-1] or ""
    new_query = query_text

    for position, timestamp in enumerate(similar_timestamps):
        # Case of a click on a "similar results" link
        if timestamp is not None:
            link_id = dash.callback_context.inputs_list[-1][position]['id']
            query_id = int(link_id['index'])
            query_embeddings = get_vector(inverse_corpus_ids[query_id])
            # Written into the search box so that later triggers (e.g. a filter
            # change) take the 'similar:' branch below and keep this query
            new_query = f'similar:{query_id}'
            query_text = ""
            break
    else:
        # Case of an empty query, without filters applied
        if len(query_text) == 0 and no_results_displayed(*filter_values):
            return [''] + [{}] * num_figs + [{'display': 'none'}, '', "{}"]
        # Case of the search of results similar to a given line
        if 'similar:' in query_text:
            try:
                query_id = query_text.split('similar:')[1]
                idx = inverse_corpus_ids[int(query_id)]
            # Case of a malformed id (ValueError) or of an id absent from the corpus (KeyError)
            except (ValueError, KeyError):
                return [[html.P('No results')]] + [{}] * num_figs + [{'display': 'none'}, new_query, "{}"]
            query_embeddings = get_vector(idx)
            query_text = ""
        # General case of a non-empty query
        else:
            query_embeddings = compute_embeddings(model, tokenizer, [str(query_text)])
            query_embeddings = normalize(query_embeddings)

    # Computation of the relevance score and reordering of the dataset.
    # Without filters only the N_RESULTS best documents are needed; with filters
    # every document is ranked so that up to N_RESULTS remain after filtering.
    unfiltered = no_filter(*filter_values)
    ranks = compute_rank(query_embeddings, query_text, n_results=N_RESULTS if unfiltered else None)
    results = df.reindex([corpus_ids[rank] for rank in ranks])
    if not unfiltered:
        results = filter(results, *filter_values).iloc[:N_RESULTS]

    # Display of results
    if len(results) > 0:
        output = format_results(results)
        visibility = {'display': 'block'}
    else:
        output = [html.P('No results')]
        visibility = {'display': 'none'}
    return [output] + create_figs(results) + [visibility, new_query, results.to_json(orient='columns')]

@app.callback(
    Output("download-csv", "data"),
    Input("download-button", "n_clicks"),
    State("memory", "data"),
    prevent_initial_call=True,
)
def download_results(n_clicks, memory):
    """
    Generate a CSV file to be downloaded by the user.

    :param n_clicks: clicks on the download button (only used as a trigger)
    :param memory: results of the last search, serialized by the search callback
        with ``DataFrame.to_json(orient='columns')``
    :return: the payload for the ``dcc.Download`` component, or
        ``dash.no_update`` when no search result has been stored yet
    """
    # The store is created without data, so it holds None until the search callback has run
    if not memory:
        return dash.no_update
    output_df = pd.DataFrame.from_dict(json.loads(memory))
    # The JSON round trip drops the index name; restore the document id column header
    output_df.index.name = id_label
    return dcc.send_data_frame(output_df.to_csv, "data.csv")