import dataiku
import json
import io
import numpy as np
from functools import lru_cache
from PIL import Image, ImageDraw
import torch

from transformers import OwlViTProcessor, OwlViTForObjectDetection

import plotly.express as px
import plotly.graph_objs as go
from dash import dcc, html, Input, Output, State
import dash_bootstrap_components as dbc

from project_utils import get_color_map

# Safety switch for the save button: while True, save_annotation only shows a
# notification and writes nothing to the output folder.
SAVE_DISABLED = True # Change this to False if you want users to be able to save predictions

# Dataiku managed folders, referenced by folder id.
# `folder` is read by get_image; `output_folder` is written to by save_annotation.
folder = dataiku.Folder("Tmf77vDr")
output_folder = dataiku.Folder("sr6n0lYE")
# Sorted listing of the input folder; images are addressed by their position in
# this list (see get_image and save_annotation).
paths = sorted(folder.list_paths_in_partition())

# The checkpoint name is read from the custom variable "owlvit_model_name".
model_name = dataiku.get_custom_variables()["owlvit_model_name"]
# NOTE(review): torch_dtype is passed to the processor while the model below is
# loaded without it -- confirm whether this argument has any effect here.
processor = OwlViTProcessor.from_pretrained(
    model_name,
    torch_dtype=torch.float16
)
model = OwlViTForObjectDetection.from_pretrained(model_name)
# Put the model in evaluation mode; the return value is bound to `_` and not used.
_ = model.eval()

def get_bounding_boxes(texts, image, k=1, threshold=0):
    """
    Get bounding boxes corresponding to the objects provided as texts.

    Args:
        texts: list of text prompts, one per object to look for.
        image: PIL image to run the detection on.
        k: maximum number of boxes to return.
        threshold: minimum score a box needs in order to be returned.

    Returns:
        A tuple (boxes, labels) of two lists of equal length, ordered by
        decreasing score. Each box is an integer numpy array taken from the
        post-processed model output (unpacked elsewhere in this file as
        [xmin, ymin, xmax, ymax]); each label is the index in `texts` of the
        prompt that the box was matched to.
    """
    inputs = processor(text=[texts], images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); the post-processing is given (height, width)
    target_sizes = torch.Tensor([image.size[::-1]])
    results = processor.post_process(outputs=outputs, target_sizes=target_sizes)

    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
    # Count the qualifying boxes in a single vectorised pass and convert the
    # count to a plain int so that it can safely be used as a slice bound
    # (the built-in sum() would iterate over the tensor element by element).
    num_above_threshold = int((scores >= threshold).sum().item())
    max_results = min(k, num_above_threshold)
    # Scores are sorted in decreasing order, so the first `max_results`
    # entries are exactly the best boxes that pass the threshold.
    idx = np.argsort(-scores.cpu().numpy())[:max_results]
    return list(np.round(boxes[idx].cpu().numpy()).astype(int)), list(labels[idx].cpu().numpy())

@lru_cache()
def get_image(image_idx):
    """
    Load the image stored at position `image_idx` of `paths`.

    Returns a tuple (resized image, original width): the image is rescaled to
    a width of 800 pixels with a proportional height, and the original width
    is returned alongside it. Results are memoized per index.
    """
    with folder.get_download_stream(paths[image_idx]) as stream:
        picture = Image.open(io.BytesIO(stream.read()))
    original_width, original_height = picture.size
    resized_height = original_height * 800 // original_width
    return picture.resize((800, resized_height)), original_width

def get_figure(image_idx):
    """
    Build the bare Plotly figure showing the image at the given index:
    no axes, no hover information and no drag interaction.
    """
    picture, _original_width = get_image(image_idx)
    figure = px.imshow(picture)
    # Hide both axes and silence hovering so that only the picture is shown
    figure.update_yaxes(visible=False)
    figure.update_xaxes(visible=False)
    figure.update_traces(hoverinfo='none', hovertemplate=None)
    figure.update_layout(
        margin=dict(l=0, r=0, t=30, b=0),
        dragmode=False,
    )
    return figure

## CSS styles

# Centered component with a bounded width; base of several styles below
STYLE_COMPONENT = {"margin": "10px auto 0px auto", "max-width": "650px"}

# Vertical spacing between groups of components
STYLE_GROUP = {"margin-top": "20px"}

# Same as a regular component, but hidden until predictions are available
STYLE_DOWNLOAD_BUTTON = {**STYLE_COMPONENT, "display": "none"}

# Each slider takes a bit less than half of its row
STYLE_SLIDER = {"width": "45%"}

# Flex row sharing the margin and width of a regular component
STYLE_CONTAINER = {
    "display": "flex",
    "justify-content": "space-evenly",
    **STYLE_COMPONENT,
}

# Toast centered horizontally at the top of the page, above the other components
STYLE_NOTIFICATION = {
    "position": "absolute",
    "top": 10,
    "left": "50%",
    "transform": "translateX(-50%)",
    "width": 250,
    "z-index": "10",
}

# Centered page with a bounded width
STYLE_PAGE = {"margin": "auto", "max-width": "800px", "text-align": "center"}

## Input bar

# Free-text query followed by a button that launches the detection
search_icon = html.Span(html.I(className="bi bi-search"))
search_bar = dbc.InputGroup(
    children=[
        dbc.Input(id='query', value='', type='text', minLength=0),
        dbc.Button(search_icon, id='search-btn', title='Detect object'),
    ],
    style=STYLE_COMPONENT,
)

# Maximum number of boxes to display: 1 to 20, ticks at 1 and every multiple of 5
num_results_slider = html.Div(
    children=[
        html.Label("Maximum number of results"),
        dcc.Slider(
            id='num-results',
            min=1,
            max=20,
            step=1,
            value=1,
            marks={tick: {"label": str(tick)} for tick in (1, 5, 10, 15, 20)},
        ),
    ],
    style=STYLE_SLIDER,
)

# Minimum score of the boxes to display: 0 to 0.3, ticks every 0.05
threshold_slider = html.Div(
    children=[
        html.Label("Threshold"),
        dcc.Slider(
            id='threshold',
            min=0,
            max=0.3,
            step=0.01,
            value=0,
            marks={
                pct / 100: {"label": str(pct / 100)} for pct in range(0, 35, 5)
            },
        ),
    ],
    style=STYLE_SLIDER,
)

# One page per image of the input folder
_pagination_row = html.Div(
    dbc.Pagination(
        id="image-idx",
        active_page=1,
        max_value=len(paths),
        first_last=True,
        previous_next=True,
    ),
    style=STYLE_CONTAINER,
)

# Search bar on top of the two sliders displayed side by side
_search_group = html.Div(
    children=[
        search_bar,
        html.Div(
            children=[num_results_slider, threshold_slider],
            style=STYLE_CONTAINER,
        ),
    ],
    style=STYLE_GROUP,
)

inputs = html.Div(
    children=[
        _pagination_row,
        _search_group,
        # Browser-side stores used to pass data between callbacks
        dcc.Store(id='detected', storage_type='memory'),
        dcc.Store(id='detected2', storage_type='memory'),
        dcc.Store(id='image-width', storage_type='memory'),
    ]
)

## Outputs

download_icon = html.Span(html.I(className="bi bi-cloud-arrow-down-fill"))

# Picture of the current image; the only mode bar button erases a shape
_picture = dcc.Graph(
    id="graph-picture",
    figure=get_figure(0),
    config={
        "displayModeBar": True,
        "modeBarButtons": [['eraseshape']],
        "displaylogo": False,
    },
    style=STYLE_COMPONENT,
)

# Radio buttons to pick the label of new boxes (filled in by the detection
# callback), followed by the initially hidden save button
_annotation_controls = html.Div(
    children=[
        html.Div(
            dbc.RadioItems(
                id="label",
                options=[],
                value="#ffffff",
                inline=True,
                label_style={"margin": "0 10px 0 2px"},
            ),
            style=STYLE_COMPONENT,
        ),
        dbc.Button(
            download_icon,
            id="download-btn",
            style=STYLE_DOWNLOAD_BUTTON,
        ),
    ],
    style=STYLE_GROUP,
)

# Short-lived message shown after a click on the save button
_notification = dbc.Toast(
    "",
    id="notification",
    header_style={"display": "none"},
    duration=1000,
    style=STYLE_NOTIFICATION,
    is_open=False,
)

outputs = html.Div(
    children=[
        _picture,
        _annotation_controls,
        _notification,
    ]
)

## Overall layout

# `app` is not defined in this file; it is expected to be provided by the
# environment running the webapp.
app.title = "Few-shot object detection"
# Bootstrap theme plus the Bootstrap icon font used by the buttons
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]
# Inputs on top, outputs below, inside a centered page
app.layout = html.Div(children=[inputs, outputs], style=STYLE_PAGE)

def restrict_box(box, w, h):
    """
    Clamp the coordinates of a box [xmin, ymin, xmax, ymax] so that they lie
    inside an image of width w and height h, i.e. x within [0, w-1] and
    y within [0, h-1]. A new list is returned.
    """
    left, top, right, bottom = box
    last_column, last_row = w - 1, h - 1
    bounded = zip(
        (left, top, right, bottom),
        (last_column, last_row, last_column, last_row),
    )
    return [min(max(0, value), upper) for value, upper in bounded]
    
def remove_redundant_boxes(predictions, w, h):
    """
    Filter out, label by label, the boxes that overlap too much with a box
    already retained for the same label.

    `predictions` maps each label to a list of boxes. Boxes are first clamped
    to the image dimensions (w, h), then examined in order: a box is retained
    only if keep_box accepts it against every box retained so far. Labels
    without any box do not appear in the returned dictionary.
    """
    kept = {}
    for label, boxes in predictions.items():
        clamped_boxes = [restrict_box(box, w, h) for box in boxes]
        for candidate in clamped_boxes:
            # The entry is created lazily so that empty labels stay absent
            retained = kept.setdefault(label, [])
            if all(keep_box(candidate, other) for other in retained):
                retained.append(candidate)
    return kept

def keep_box(box1, box2, threshold_iou=0.8, threshold_iob=0.8):
    """
    Decide whether box1 should be kept given that box2 has already been selected.

    Both boxes are given as [xmin, ymin, xmax, ymax].

    Args:
        box1: candidate box.
        box2: box that has already been selected.
        threshold_iou: box1 is rejected if the intersection-over-union of the
            two boxes is strictly greater than this value.
        threshold_iob: box1 is rejected if the share of its own area covered
            by box2 is strictly greater than this value.

    Returns:
        True if box1 should be kept, False if it is redundant with box2.
        Boxes that do not overlap are always kept. A box1 with a zero area
        (e.g. flattened against the image border) that touches box2 is
        considered redundant.
    """
    xmin1, ymin1, xmax1, ymax1 = box1
    xmin2, ymin2, xmax2, ymax2 = box2
    # Disjoint boxes cannot be redundant with each other
    if xmax1 < xmin2 or xmax2 < xmin1 or ymax1 < ymin2 or ymax2 < ymin1:
        return True
    area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
    # A degenerate box has no area to compare: dropping it also avoids the
    # divisions by zero in the ratios below
    if area1 == 0:
        return False
    # For overlapping intervals, the two middle values of the four sorted
    # bounds delimit the intersection
    xmin_inter, xmax_inter = sorted([xmin1, xmax1, xmin2, xmax2])[1:3]
    ymin_inter, ymax_inter = sorted([ymin1, ymax1, ymin2, ymax2])[1:3]
    area_inter = (xmax_inter - xmin_inter) * (ymax_inter - ymin_inter)
    area2 = (xmax2 - xmin2) * (ymax2 - ymin2)
    if area_inter / (area1 + area2 - area_inter) > threshold_iou:
        return False
    return area_inter / area1 <= threshold_iob

def get_annotation(x0, y0, x1, y1, color):
    """
    Describe an editable rectangle shape spanning (x0, y0) to (x1, y1):
    transparent fill and a solid outline, 2 pixels wide, of the given color.
    """
    outline = dict(color=color, width=2, dash="solid")
    rectangle = dict(
        editable=True,
        xref="x",
        yref="y",
        layer="above",
        opacity=1,
        line=outline,
        fillcolor="rgba(0,0,0,0)",
        fillrule="evenodd",
        type="rect",
    )
    rectangle.update(x0=x0, y0=y0, x1=x1, y1=y1)
    return rectangle

def _parse_query(query):
    """
    Split a raw query into two parallel lists (classes, prompts).

    Items are separated by "," and a class label can be added after ":".
    Example: "laptop:computer,desktop:computer,mouse" searches for "laptop",
    "desktop" and "mouse" and reports the first two as "computer".
    Surrounding whitespace is ignored, items without a prompt are skipped and
    an item without a label uses its prompt as label.
    """
    classes, prompts = [], []
    for item in query.split(","):
        prompt, _sep, label = item.partition(":")
        prompt, label = prompt.strip(), label.strip()
        if not prompt:
            # e.g. trailing comma or blank query: nothing to search for
            continue
        prompts.append(prompt)
        classes.append(label or prompt)
    return classes, prompts


@app.callback(
    Output('detected', 'data'),
    Output('label', 'options'),
    Output('label', 'value'),
    Output('download-btn', 'style'),
    Output('image-width', 'data'),
    Input('search-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    Input('num-results', 'value'),
    Input('threshold', 'value'),
    Input('image-idx', 'active_page'),
    State('query', 'value'),
    prevent_initial_call=True
)
def detect_object(n_clicks, n_submit, k, threshold, page, query):
    """
    Predict the bounding boxes.

    Triggered by the search button, a submission of the query, a move of
    either slider or a change of page.

    Args:
        n_clicks, n_submit: only used as triggers.
        k: maximum number of boxes to return.
        threshold: minimum score of the boxes to return.
        page: 1-based page number of the current image.
        query: raw text of the search bar.

    Returns:
        A tuple of 5 values: the JSON-encoded detection (keys "predictions",
        "color" and "idx"), the options of the label radio buttons, the
        selected label value, the style of the download button and the
        original width of the image. Without any usable prompt, the detection
        is empty, the download button stays hidden and the width is 0.
    """
    image_idx = page - 1
    output = {'predictions': {}, 'color': {}, 'idx': image_idx}
    style = STYLE_DOWNLOAD_BUTTON.copy()

    # Case of an empty query (missing, blank, or made of separators only)
    classes, prompts = _parse_query(query or "")
    if not prompts:
        return json.dumps(output), [], "", style, 0

    # Object detection
    image, w = get_image(image_idx)
    boxes, labels = get_bounding_boxes(prompts, image, k=k, threshold=threshold)
    color_map = get_color_map(classes)
    output["color"] = {label: f"rgba{color_map[label]}" for label in color_map}

    # Every class gets an entry, then each box goes to the class of the prompt
    # it was matched to (labels hold indices into prompts/classes)
    predictions = {label: [] for label in classes}
    for box, label_idx in zip(boxes, labels):
        predictions[classes[label_idx]].append([int(x) for x in box])
    output["predictions"] = remove_redundant_boxes(predictions, *image.size)

    # Radio input component and download button
    options = [
        {"label": label, "value": color}
        for label, color in output["color"].items()
    ]
    style["display"] = "block"
    selected = options[0]["value"] if options else ""

    return json.dumps(output), options, selected, style, w

@app.callback(
    Output('graph-picture', 'figure'),
    Output('detected2', 'data'),
    Input('detected', 'data'),
    Input('label', 'value'),
    State('detected2', 'data'),
    State('graph-picture', 'figure'),
    prevent_initial_call=True
)
def update_image(detected, default_color, detected2, fig):
    """
    Update the image.

    Args:
        detected: JSON-encoded detection produced by detect_object.
        default_color: color of the label currently selected in the radio
            buttons; new shapes are drawn with it.
        detected2: detection that was used to draw the current figure.
        fig: current figure, including the shapes edited by the user.

    Returns:
        A tuple (figure, detected): the figure to display and the detection
        it corresponds to, stored so that the next call can tell a change of
        label from new predictions.
    """
    output = json.loads(detected)
    image_idx = output["idx"]
    predictions = output["predictions"]
    # Case of the absence of predictions, e.g. if the query is empty
    if not predictions:
        return get_figure(image_idx), detected

    # Case of a change of label triggering this callback: the current shapes
    # are kept and only the color of the next shapes changes
    if detected == detected2:
        fig2 = go.Figure(fig)
        fig2.update_layout(
            newshape=dict(
                fillcolor="rgba(0,0,0,0)",
                line=dict(color=default_color, width=2))
        )
        return fig2, detected

    # Case of new predictions
    fig2 = get_figure(image_idx)
    for label, boxes in predictions.items():
        color = output["color"][label]
        for box in boxes:
            fig2.add_shape(**get_annotation(*box, color))
    # New shapes must use the color of the label selected in the radio
    # buttons, which is not necessarily the first label with predictions.
    # Only fall back on the latter when no label is selected.
    if not default_color:
        default_color = output["color"][next(iter(predictions))]
    fig2.update_layout(
        dragmode="drawrect",
        newshape=dict(
            fillcolor="rgba(0,0,0,0)",
            line=dict(color=default_color, width=2))
    )
    return fig2, detected

@app.callback(
    Output('notification', 'children'),
    Output('notification', 'is_open'),
    Input('download-btn', 'n_clicks'),
    State('detected', 'data'),
    State('graph-picture', 'figure'),
    State('image-width', 'data'),
    prevent_initial_call=True
)
def save_annotation(n_clicks, detected, fig, w):
    """
    Save the annotation of the current image.

    The shapes currently drawn on the figure are written to the output folder
    as a JSON list of {"bbox": [x, y, width, height], "category": label},
    in a file named after the image path followed by ".json". Coordinates are
    rescaled from the 800 pixel wide displayed image to the original width.

    Args:
        n_clicks: only used as trigger.
        detected: JSON-encoded detection produced by detect_object, or None.
        fig: current figure, including the shapes edited by the user.
        w: original width of the image.

    Returns:
        A tuple (message, True) to fill in and open the notification.
    """
    if SAVE_DISABLED:
        return "Save button disabled. Change SAVE_DISABLED to enable it", True

    if detected is None:
        return "0 bounding box saved", True

    output = json.loads(detected)
    image_idx = output["idx"]
    # Shapes only carry a color: map it back to the corresponding label
    inverse_color_map = {color: label for label, color in output["color"].items()}
    boxes = []
    # The layout may not have a "shapes" entry when nothing has been drawn
    for shape in fig["layout"].get("shapes") or []:
        x0, y0, x1, y1 = (int(shape[key]) for key in ("x0", "y0", "x1", "y1"))
        label = inverse_color_map[shape["line"]["color"]]
        # A rectangle can be drawn in any direction: anchor the box on its
        # top-left corner so that width and height are never negative
        box = [min(x0, x1), min(y0, y1), abs(x1 - x0), abs(y1 - y0)]
        # The displayed image is 800 pixels wide (see get_image)
        box = [w * value // 800 for value in box]
        boxes.append({"bbox": box, "category": label})
    num_boxes = len(boxes)
    with output_folder.get_writer(paths[image_idx] + ".json") as writer:
        writer.write(json.dumps(boxes).encode("utf-8"))
    return f"{num_boxes} bounding box{'es' if num_boxes > 1 else ''} saved", True