import dataiku
import logging
import io
import base64
from functools import lru_cache
import gzip
import json

from colpali_engine.models import ColPali, ColPaliProcessor
import transformers
import torch
import PIL
from colpali_engine.interpretability import (
    get_similarity_maps_from_embeddings,
    plot_similarity_map,
)

import dash
from dash import dcc, html, Input, Output, State, ctx
import dash_bootstrap_components as dbc

# The processor and model are only loaded when CUDA is available; on a
# CPU-only host the names `processor` and `model` are left undefined.
CUDA_AVAILABLE = torch.cuda.is_available()
if CUDA_AVAILABLE:
    _checkpoint = "vidore/colpali-v1.2"
    processor = ColPaliProcessor.from_pretrained(_checkpoint)
    # Load the weights with a 4-bit bitsandbytes quantization config and
    # bfloat16 as the requested torch dtype.
    _quantization = transformers.BitsAndBytesConfig(load_in_4bit=True)
    model = ColPali.from_pretrained(
        _checkpoint,
        torch_dtype=torch.bfloat16,
        quantization_config=_quantization,
    )
    # Switch to inference mode and place the model on the GPU.
    model = model.eval().to("cuda")

def _read_gzip_json(folder, path):
    """Download ``path`` from ``folder`` and return its parsed gzip-compressed JSON.

    Args:
        folder: Object exposing ``get_download_stream(path)`` as a context
            manager (here, the Dataiku folder holding the index).
        path: Path of the ``.json.gz`` file inside the folder.

    Returns:
        The object decoded by ``json.load`` from the decompressed file.
    """
    with folder.get_download_stream(path) as stream:
        # Read everything into memory first: GzipFile needs a file-like object.
        payload = io.BytesIO(stream.read())
    with gzip.GzipFile(fileobj=payload, mode='rb') as gz:
        return json.load(gz)


index_folder = dataiku.Folder("8jv9ZTHM")

# Document id -> file path, and the reverse lookup keyed by bare file name.
id_to_filename = _read_gzip_json(index_folder, '/doc_ids_to_file_names.json.gz')
filename_to_id = {v.split("/")[-1]: k for k, v in id_to_filename.items()}

# Embedding id -> {'doc_id': ..., 'page_id': ...} (keys as used by the lookups below).
embed_id_to_doc_id = _read_gzip_json(index_folder, '/embed_id_to_doc_id.json.gz')

# List the folder once and reuse the listing for both kinds of shard.
_index_paths = index_folder.list_paths_in_partition()

# Collection shards are ordered by the integer that forms the file's base name.
collection_index = [x for x in _index_paths if "collection" in x]
sorted_collection = sorted(collection_index, key=lambda x: int(x.split('/')[-1].split('.')[0]))

collection = []
for file_path in sorted_collection:
    collection.extend(_read_gzip_json(index_folder, file_path).values())

# Embedding shards are ordered by the integer after the last underscore.
embedding_index = [x for x in _index_paths if "embeddings" in x]
sorted_embedding = sorted(embedding_index, key=lambda x: int(x.split('_')[-1].split('.')[0]))

embeddings = []
for file_path in sorted_embedding:
    with index_folder.get_download_stream(file_path) as folder_stream:
        buffer = io.BytesIO(folder_stream.read())
    # NOTE(review): torch.load unpickles its input; only load shards from a
    # trusted folder.
    # map_location="cpu" lets shards saved from a GPU load on a CPU-only host
    # (the app still has to start there to show its "No GPU" message) and keeps
    # the whole index out of GPU memory; single pages are moved to the model
    # device when they are used.
    embeddings.extend(torch.load(buffer, map_location="cpu"))
    
# Exponent applied to the normalised similarity map -> label shown in the
# contrast selector.
alpha_values = {
    1: "Low contrast",
    2: "Medium contrast",
    5: "High contrast",
}

def similarity_images(filename, page, query):
    """
    Compute the similarity heatmaps for a page and a query.

    Args:
        filename: Key of ``filename_to_id`` identifying the document.
        page: Page identifier within the document (anything ``int()`` accepts).
        query: Text query passed to ``processor.process_queries``.

    Returns:
        dict: For each key of ``alpha_values``, a base64-encoded JPEG (str) of
        the page overlaid with the mean similarity map raised to that power.
    """
    embedding_image = get_embedding_from_filename_and_page(filename, page).unsqueeze(0)
    process_query = processor.process_queries([query]).to(model.device)
    with torch.no_grad():
        query_embedding = model(**process_query)

    # The collection stores each page as a base64-encoded image.
    image_b64 = get_image_from_filename_and_page(filename, page)
    image = PIL.Image.open(io.BytesIO(base64.b64decode(image_b64)))

    n_patches = processor.get_n_patches(image_size=image.size, patch_size=model.patch_size)

    # Get the tensor mask to filter out the embeddings that are not related to the image
    image_mask = processor.get_image_mask(processor.process_images([image]))

    # Generate the similarity maps
    batched_similarity_maps = get_similarity_maps_from_embeddings(
        image_embeddings=embedding_image.to(model.device),
        query_embeddings=query_embedding.to(model.device),
        n_patches=n_patches,
        image_mask=image_mask,
    )

    # Get the similarity map for our (only) input image
    similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)
    similarity_map_mean = similarity_maps.mean(dim=0)

    # Min-max normalise to [0, 1]. A constant map has zero range; dividing by
    # it would fill the map with NaN, so fall back to an all-zero map instead.
    lowest = similarity_map_mean.min()
    value_range = similarity_map_mean.max() - lowest
    if value_range > 0:
        similarity_map_mean = (similarity_map_mean - lowest) / value_range
    else:
        similarity_map_mean = torch.zeros_like(similarity_map_mean)

    images_b64 = {}
    for alpha in alpha_values:
        # Raising the normalised map to a higher power sharpens the contrast.
        fig, _ = plot_similarity_map(
            image=image,
            similarity_map=similarity_map_mean**alpha,
            figsize=(8, 8),
            show_colorbar=False,
        )

        with io.BytesIO() as buf:
            fig.savefig(buf, format="JPEG")
            # getvalue() returns the whole buffer regardless of the position.
            images_b64[alpha] = base64.b64encode(buf.getvalue()).decode('utf-8')
        # NOTE(review): `fig` is never closed here. If plot_similarity_map
        # creates it through matplotlib.pyplot (to be confirmed), figures will
        # accumulate across calls; closing it needs a matplotlib import that
        # this file does not currently have.

    return images_b64


def _find_embed_id(filename, page):
    """Return the integer embedding id recorded for ``page`` of ``filename``.

    Args:
        filename: Key of ``filename_to_id``.
        page: Page identifier; converted with ``int()`` before comparison.

    Raises:
        KeyError: If ``filename`` is not in ``filename_to_id``, or if no entry
            of ``embed_id_to_doc_id`` matches the document and page.
    """
    page = int(page)
    doc_id = int(filename_to_id[filename])
    result_id = None
    for embed_id, values in embed_id_to_doc_id.items():
        if values['doc_id'] == doc_id and values['page_id'] == page:
            # No early exit: if several entries match, the last one wins.
            result_id = int(embed_id)
    if result_id is None:
        raise KeyError(f"No page {page} indexed for file {filename!r}")
    return result_id


def get_image_from_filename_and_page(filename, page):
    """Return the ``collection`` entry stored for ``page`` of ``filename``.

    Raises:
        KeyError: If the file or the page is not part of the index.
    """
    return collection[_find_embed_id(filename, page)]


def get_embedding_from_filename_and_page(filename, page):
    """Return the ``embeddings`` entry stored for ``page`` of ``filename``.

    Raises:
        KeyError: If the file or the page is not part of the index.
    """
    return embeddings[_find_embed_id(filename, page)]
    
    
# `app` is not created in this file; it must already exist when this script
# runs (presumably injected by the Dataiku webapp backend - to be confirmed).
app.title = "Interpretability for ColPali"

# Bootstrap theme plus the Bootstrap icon set (needed for the search icon).
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]

app.config.suppress_callback_exceptions = True

# Magnifying-glass icon used as the content of the submit button.
search_icon = html.Span(children=html.I(className="bi bi-search"))

# Inline styles for the layout, keyed by the role of the element they style.
# Dash style dictionaries use camelCased CSS property names, so every key
# follows that convention (the former 'max-width' is now 'maxWidth').
STYLE = {
    # Outer page wrapper: centred column, capped width.
    'container': {
        'display': 'flex',
        'flexDirection': 'column',
        'alignItems': 'center',
        'justifyContent': 'center',
        'maxWidth': '800px',
        'margin': 'auto',
    },
    # Row holding the file and page selectors.
    'select_container': {
        'display': 'flex',
        'flexDirection': 'row',
        'justifyContent': 'center',
        'alignItems': 'center',
        'width': '95%',
        'margin': '3px 0',
    },
    'select': {
        'flex': '1',
        'margin': '3px 1%',
    },
    # Row holding the query input group and the contrast selector.
    'input_alpha_container': {
        'display': 'flex',
        'flexDirection': 'row',
        'alignItems': 'center',
        'justifyContent': 'center',
        'width': '95%',
        'margin': '3px 0',
    },
    'input_group': {
        'flex': '3',
        'margin': '3px 1%',
        'display': 'flex',
    },
    'select_alpha': {
        'flex': '1',
    },
    # Displayed page or heatmap image.
    'image': {
        'width': '600px',
        'height': 'auto',
    },
}

# Default selection: the first indexed file and the first page listed for it.
selected_filename = list(filename_to_id)[0]
_selected_doc_id = int(filename_to_id[selected_filename])
page_ids = [item['page_id'] for item in embed_id_to_doc_id.values() if item['doc_id'] == _selected_doc_id]
page_select_options = [{'label': f"Page {page_id}", 'value': page_id} for page_id in page_ids]

app.layout = html.Div([
    # Warning banner; empty when a GPU is available.
    html.Div(
        "" if CUDA_AVAILABLE else "No GPU available. This webapp requires a GPU. Please note that this webapp cannot be used on the Dataiku project gallery. If you want to test it, download the project and import it on your own Dataiku instance.",
        style=STYLE['select_container'],
    ),

    # File and page selectors.
    html.Div([
        dbc.Select(
            id='filename-select',
            options=[{'label': fname, 'value': fname} for fname in filename_to_id],
            value=selected_filename,
            style=STYLE['select'],
        ),
        html.Div(
            id='page-select-container',
            children=[
                dbc.Select(
                    id='page-select',
                    options=page_select_options,
                    value=page_ids[0],
                    style=STYLE['select'],
                )
            ],
        ),
    ], style=STYLE['select_container']),

    # Query input with its submit button, and the heatmap contrast selector.
    html.Div([
        dbc.InputGroup(
            [
                dbc.Input(
                    id='query-input',
                    placeholder="Type your query here",
                    type="text",
                ),
                dbc.Button(search_icon, id='submit-query', n_clicks=0, className='h-100'),
            ],
            id='task-input-group',
            style=STYLE['input_group'],
        ),
        dbc.Select(
            id='alpha-select',
            options=[{'label': alpha_values[a], 'value': a} for a in alpha_values],
            value=2,
            style=STYLE['select_alpha'],
        ),
    ], style=STYLE['input_alpha_container']),

    # Page image (or heatmap). The initial image is the page pre-selected in
    # 'page-select' above, so the picture and the selector always agree.
    dbc.Spinner(
        html.Img(
            id="image-display",
            src=f"data:image/jpeg;base64,{get_image_from_filename_and_page(selected_filename, page_ids[0])}",
            style=STYLE['image'],
        ),
        color="primary",
    ),

    # Client-side state: the last computed heatmaps, and whether they are
    # still valid for the current file/page.
    dcc.Store(id='images-store', data=None),
    dcc.Store(id='heatmaps_computed', data=None),
], style=STYLE['container'])

@app.callback(
    Output('page-select', 'options'),
    Output('page-select', 'value'),
    Output('image-display', 'src'),
    Output('images-store', 'data'),
    Output('heatmaps_computed', 'data'),
    Input('filename-select', 'value'),
    Input('page-select', 'value'),
    Input('submit-query', 'n_clicks'),
    Input('alpha-select', 'value'),
    State('images-store', 'data'),
    State('query-input', 'value'),
    State('heatmaps_computed', 'data'),
)
def combined_callback(selected_filename, selected_page, query_n_clicks, alpha_value, images_store_data, query_value, heatmaps_computed):
    """Update the page selector, the displayed image and the heatmap state.

    The action depends on which input triggered the callback:

    * ``filename-select``: rebuild the page options, select the first listed
      page of the file and show it.
    * ``page-select``: show the selected page.
    * ``submit-query``: compute the heatmaps for the query and show the one
      matching the selected contrast (needs a GPU and a non-blank query).
    * ``alpha-select``: switch between already computed heatmaps.

    In any other situation the image is left unchanged.

    Returns:
        tuple: ``(page options, page value, image src, stored heatmaps,
        heatmaps_computed flag)``, in the order of the declared outputs.

    Raises:
        dash.exceptions.PreventUpdate: If no input triggered the callback.
    """
    # Local name chosen so it does not shadow the `ctx` imported from dash.
    callback_ctx = dash.callback_context
    if not callback_ctx.triggered:
        raise dash.exceptions.PreventUpdate
    triggered_input = callback_ctx.triggered[0]['prop_id'].split('.')[0]

    page_select_options = dash.no_update
    page_select_value = selected_page
    # Default for the cases where no branch below produces an image (e.g. a
    # query submitted without GPU, or a contrast change before any heatmap
    # exists); previously this name was left unbound and the return failed.
    image_display_src = dash.no_update
    images_store_data_output = images_store_data if images_store_data is not None else {}
    heatmaps_computed = heatmaps_computed if heatmaps_computed is not None else False

    has_query = bool(query_value and query_value.strip())

    if triggered_input == "filename-select":
        doc_id = int(filename_to_id[selected_filename])
        page_ids = [item['page_id'] for item in embed_id_to_doc_id.values() if item['doc_id'] == doc_id]
        page_select_options = [{'label': f"Page {page_id}", 'value': page_id} for page_id in page_ids]
        # Default to the first page listed for the file rather than assuming
        # that a page numbered 1 exists.
        page_select_value = page_ids[0]
        image = get_image_from_filename_and_page(selected_filename, page_select_value)
        image_display_src = f"data:image/jpeg;base64,{image}"
        heatmaps_computed = False

    elif triggered_input == "page-select":
        image = get_image_from_filename_and_page(selected_filename, selected_page)
        image_display_src = f"data:image/jpeg;base64,{image}"
        heatmaps_computed = False

    elif triggered_input == "submit-query" and query_n_clicks > 0 and CUDA_AVAILABLE and has_query:
        # Freshly computed dict: its keys are the integer keys of alpha_values.
        images_store_data_output = similarity_images(selected_filename, selected_page, query_value)
        image_display_src = f"data:image/jpeg;base64,{images_store_data_output[int(alpha_value)]}"
        heatmaps_computed = True

    elif triggered_input == "alpha-select" and heatmaps_computed:
        # Here the heatmaps come back from the dcc.Store; the store content is
        # JSON, so the integer keys are expected to come back as strings.
        heatmap = images_store_data_output.get(str(alpha_value))
        if heatmap is not None:
            image_display_src = f"data:image/jpeg;base64,{heatmap}"

    return (
        page_select_options,
        page_select_value,
        image_display_src,
        images_store_data_output,
        heatmaps_computed
    )
