import json
import base64

import dataiku
import cv2
import numpy as np

import dash
from dash import html
from dash.dependencies import Input, Output
import dash_bootstrap_components as dbc

from project_utils import load_image, process_annotations, add_boxes, get_color_map

# Dataiku folder handle that the callback passes to load_image to fetch images.
image_folder = dataiku.Folder("PRCGY0s7")
# Scored dataset loaded as a DataFrame. Its length drives the pagination and
# its "record_id" column identifies the image to load for each row.
df = dataiku.Dataset("test_scored").get_dataframe()

# Names of the columns corresponding to the ground truth and the predictions.
# If one is missing from the dataset, put the empty string (e.g. "label": "")
name = {
    "label": "label",
    "prediction": "prediction",
}
# Blank out every configured column that the loaded dataset does not contain,
# so the rest of the script can treat "" as "this layer is unavailable".
for kind, column in list(name.items()):
    if column not in df.columns:
        name[kind] = ""

# Decode the JSON-encoded annotation columns in place. The assignment goes
# through df[column] (using the configured column names) rather than through a
# row obtained from df.iloc[i]: such a row may be a temporary copy, in which
# case the decoded list would never reach df and later code would iterate over
# the raw JSON string instead of the annotations.
annotation_columns = [name[kind] for kind in ("label", "prediction") if name[kind]]
for column in annotation_columns:
    df[column] = df[column].apply(json.loads)

# Map every category to an integer id in order of first appearance, scanning
# row by row with the ground truth before the prediction inside each row.
cat2id = {}
for row_annotations in zip(*(df[column] for column in annotation_columns)):
    for annotations in row_annotations:
        for annotation in annotations:
            # len(cat2id) is the next free id; setdefault keeps existing ids.
            cat2id.setdefault(annotation["category"], len(cat2id))
# Number of distinct categories found (kept as a module-level name).
counter = len(cat2id)

# Two color maps built from the same category ids: the callback uses
# `color_map` for the ground-truth annotations and `color_map2` for the
# predictions.
# NOTE(review): `value=0.7` presumably changes the brightness of the
# ground-truth colors so both layers can be told apart - confirm against
# project_utils.get_color_map.
color_map = get_color_map(cat2id, value=0.7)
color_map2 = get_color_map(cat2id)

# Inline style of the wrapper Div that holds the whole page content: a
# horizontally centered flex column, at most 1200px wide and 800px high.
style_container = {
    "display": "flex",
    "flex-flow": "column",
    "align-items": "center",
    "height": "100vh",
    "max-width": "1200px",
    "max-height": "800px",
    "margin": "10px auto"
}

# Base inline style of the "output-image" Div. The callback copies this dict
# and adds a "background-image" entry, so the annotated image is displayed as
# a non-repeating background scaled to fit ("contain") inside the box.
style_image_box = {
    "flex-grow": "1",
    "width": "100%",
    "background-repeat": "no-repeat",
    "background-size": "contain",
    "background-position": "top center",
    "margin-top": "10px"
}

# Checklist entries for the annotation layers whose column is available.
# Value 0 stands for the ground truth and value 1 for the predictions.
options = [
    {"label": caption, "value": layer_value}
    for kind, caption, layer_value in (
        ("label", "Ground truth", 0),
        ("prediction", "Prediction", 1),
    )
    if len(name[kind]) > 0
]

# Checklist that toggles the annotation layers; every available layer starts
# checked. The row is hidden when fewer than two layers are available.
checkboxes = dbc.Form(
    children=dbc.Row(
        children=[
            dbc.Checklist(
                id="selection",
                options=options,
                value=[option["value"] for option in options],
                inline=True,
            ),
        ],
        style={
            "margin-top": "10px",
            "display": "block" if len(options) >= 2 else "none",
        },
    ),
)

# NOTE(review): `app` is not defined in this file - presumably it is provided
# by the environment that runs this webapp; confirm.
app.config.external_stylesheets = [dbc.themes.ZEPHYR]

# Page layout: title, layer checklist, a placeholder "debug" Div, one
# pagination entry per dataset row, and the box that displays the image.
app.layout = html.Div(
    children=[
        html.Div(
            children=[
                html.H4("Visualization of the predictions (visual model)"),
                checkboxes,
                html.Div(id="debug"),
                dbc.Pagination(
                    id="index",
                    max_value=len(df),
                    fully_expanded=False,
                    first_last=True,
                    previous_next=True,
                    style={"margin-top": "10px"},
                ),
                html.Div(id="output-image", style=style_image_box),
            ],
            style=style_container,
        ),
    ],
)

def convert_cv2image_to_base64(img):
    """Encode an OpenCV image as JPEG and return the result base64-encoded.

    Args:
        img: Image array accepted by ``cv2.imencode``.

    Returns:
        bytes: Base64 encoding of the JPEG data; call ``.decode()`` to get
        text suitable for a data URL.

    Raises:
        ValueError: If ``cv2.imencode`` reports that the encoding failed.
    """
    is_success, buffer = cv2.imencode(".jpg", img)
    # Fail loudly instead of base64-encoding whatever a failed encode returned.
    if not is_success:
        raise ValueError("cv2.imencode could not encode the image as JPEG")
    return base64.b64encode(buffer)

@app.callback(
    Output("output-image", "style"),
    Input("index", "active_page"),
    Input("selection", "value")
)
def get_annotated_image(index, selection):
    """Build the style of the image box for the selected page.

    Args:
        index: 1-based active page of the pagination, or None when no page
            has been selected yet (the first row is used in that case).
        selection: Checked checklist values; 0 selects the ground truth and
            1 selects the predictions.

    Returns:
        dict: A copy of ``style_image_box`` with a "background-image" entry
        holding the annotated image as a base64 JPEG data URL.
    """
    row = df.iloc[0 if index is None else index - 1]
    img = load_image(image_folder, row["record_id"])
    # One size unit per full 500 pixels of the larger of the first two image
    # dimensions; forwarded to process_annotations as `size`.
    size = max(dim // 500 for dim in img.shape[:2])

    rectangles, text_boxes = [], []
    # (checklist value, key into `name`, color map) for each annotation layer.
    layers = (
        (0, "label", color_map),
        (1, "prediction", color_map2),
    )
    for layer_value, kind, colors in layers:
        if layer_value not in selection:
            continue
        layer_rectangles, layer_text_boxes = process_annotations(
            row[name[kind]],
            colors,
            size=size
        )
        rectangles += layer_rectangles
        text_boxes += layer_text_boxes

    # NOTE(review): add_boxes presumably draws onto `img` in place, since its
    # return value is not used - confirm against project_utils.add_boxes.
    add_boxes(img, rectangles, text_boxes)

    encoded = convert_cv2image_to_base64(img).decode()
    # Copy the base style so the shared module-level dict is never mutated.
    result = dict(style_image_box)
    result["background-image"] = f'url(data:image/jpeg;base64,{encoded})'
    return result