import dataiku
import os
import logging
import tempfile
import io
import base64
from functools import lru_cache

from PIL import Image
import torch
import transformers
from qwen_vl_utils import process_vision_info
import outlines

import dash
from dash import dcc, html, Input, Output, State, ctx
import dash_bootstrap_components as dbc

# Upper bound on generated tokens for an OCR transcription; longer outputs are cut off.
MAX_TOKENS_OCR = 200

# Dataiku managed folder holding the sample documents (presumably DocVQA pages — confirm).
sample_docVQA = dataiku.Folder("sUGOSMqr")
# Paths of every file in the folder; the webapp pages through this list by index.
files = sample_docVQA.list_paths_in_partition()

# The model is only loaded on a GPU host; without CUDA, `processor` and `model`
# are never defined, so any code using them must be guarded by this flag.
CUDA_AVAILABLE = torch.cuda.is_available()

if CUDA_AVAILABLE:
    # Define the pre-trained model to use
    model_name = "Qwen/Qwen2-VL-7B-Instruct"
    # Processor used to render chat templates and to tokenize model output.
    processor = transformers.AutoProcessor.from_pretrained(model_name)

    # Use 8-bit bitsandbytes quantization to reduce GPU memory usage
    bnb_config = transformers.BitsAndBytesConfig(load_in_8bit=True)

    # Load the Qwen2-VL model through outlines so generation can optionally be
    # constrained (e.g. with a grammar) rather than free-form text only.
    model = outlines.models.transformers_vision(
        model_name,
        model_class=transformers.Qwen2VLForConditionalGeneration,
        processor_class=transformers.AutoProcessor,
        device="cuda",
        model_kwargs={
            "quantization_config": bnb_config,
            "device_map": "auto",
        },
    )

# Inline CSS snippets for the layout, keyed by the role of the styled element.
STYLE = dict(
    # Vertical breathing room around a group of components.
    group={"margin": "20px 0px"},
    # Centered flex row capped at 600px wide.
    container={
        "display": "flex",
        "justify-content": "space-evenly",
        "margin": "10px auto",
        "max-width": "600px",
        "list-style-type": "none",
    },
    # Centered page column.
    page={"margin": "auto", "max-width": "800px", "text-align": "center"},
    # Document preview thumbnail; height follows the aspect ratio.
    image={"width": "300px", "height": "auto"},
    # Bordered box that keeps line breaks of the model output.
    result_box={
        "border": "1px solid #ddd",
        "padding": "10px",
        "white-space": "pre-wrap",
        "background-color": "#f9f9f9",
    },
    # Query input row; hidden by default.
    task_input={
        "margin": "10px auto 0px auto",
        "max-width": "600px",
        "display": "none",
    },
    # Submit button look.
    submit_box={
        "color": "white",  # text color
        "padding": "10px 24px",  # inner spacing
        "border-radius": "12px",  # rounded corners
        "border": "none",  # no border
        "font-size": "16px",  # text size
        "cursor": "pointer",  # hand cursor on hover
    },
)

@lru_cache(maxsize=20)
def encode_image(folder, image_path):
    """
    Return the bytes of `image_path` inside `folder` as a base64 text string.

    The result is memoised for the 20 most recently requested
    (folder, image_path) pairs, so revisiting a page skips the download.
    """
    with folder.get_download_stream(image_path) as stream:
        raw_bytes = stream.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

def get_image(image_idx, width=800):
    """
    Load the document at position `image_idx` of `files` as a resized RGB image.

    Args:
        image_idx: 0-based index into the module-level `files` list.
        width: target width in pixels (default 800). The height is scaled to
            keep the aspect ratio.

    Returns:
        A PIL RGB image that is `width` pixels wide.

    Raises:
        IndexError: if `image_idx` is out of range for `files`.
    """
    path = files[image_idx]
    with sample_docVQA.get_download_stream(path) as stream:
        # Copy into memory so decoding does not depend on the stream staying open.
        buf = io.BytesIO(stream.read())
    image = Image.open(buf).convert("RGB")
    w, h = image.size
    # Clamp to 1px: for very wide images, h * width // w would be 0 and PIL
    # rejects zero-sized resizes.
    new_height = max(1, h * width // w)
    return image.resize((width, new_height))


def _build_messages(image, prompt):
    """Build a single-turn user chat message pairing `image` with a text `prompt`."""
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]


def perform_task_qwen(task, image, user_query=""):
    """
    Run one document-AI task on `image` with the Qwen2-VL model.

    Args:
        task: "OCR" (transcribe the text), "VQA" (answer `user_query` with a
            short value) or "KIE" (extract information as JSON, generation
            constrained by outlines' JSON grammar).
        image: the document image, placed in the chat message as-is.
        user_query: the user's question or extraction request. Not used for
            OCR. ``None`` is treated as an empty string.

    Returns:
        The generated text. For OCR, a truncation notice is appended when the
        output hit the `MAX_TOKENS_OCR` budget.

    Raises:
        ValueError: if `task` is not one of "OCR", "VQA" or "KIE".
    """
    # The Dash input starts out as None; avoid a TypeError when concatenating.
    user_query = user_query or ""

    if task == 'OCR':
        prompt = """You are an OCR processor. Transcribe the text given in a way that I can copy paste it in a notepad. Do not write markdown or LateX. Always provide an answer even if the picture looks low quality."""
    elif task == "KIE":
        prompt = user_query + " Return a json."
    elif task == "VQA":
        prompt = "Answer the question. Do not write a full sentence, just provide a value. " + user_query
    else:
        raise ValueError(f"Unknown task {task!r}; expected 'OCR', 'VQA' or 'KIE'")

    messages = _build_messages(image, prompt)
    image_inputs = process_vision_info(messages)[0]  # Prepare vision inputs
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    sampler = outlines.samplers.GreedySampler()

    if task == "KIE":
        # Constrain decoding to outlines' JSON grammar.
        generator = outlines.generate.cfg(model, outlines.grammars.json, sampler=sampler)
        return generator(text, image_inputs)

    generator = outlines.generate.text(model, sampler=sampler)
    if task == "VQA":
        return generator(text, image_inputs)

    # OCR: cap generation length for responsiveness.
    output = generator(text, image_inputs, max_tokens=MAX_TOKENS_OCR)
    # Re-tokenizing decoded text only approximates the generated token count,
    # so use >= rather than == to detect that the budget was reached.
    len_output = processor.tokenizer(output, return_tensors="pt")['input_ids'].shape[-1]
    if len_output >= MAX_TOKENS_OCR:
        return output + f'...[Output truncated at {MAX_TOKENS_OCR} tokens for performance reasons]'
    return output

# Bootstrap Icons magnifier, used as the label of the query submit button.
search_icon = html.Span(html.I(className="bi bi-search"))

# ---- Inputs --------------------------------------------------------------

# Pager over the documents of the folder (one page per file, 1-based).
_document_pager = html.Div(
    dbc.Pagination(
        id="image-idx",
        active_page=1,
        max_value=len(files),
        first_last=True,
        previous_next=True,
    ),
    style=STYLE["container"],
)

# Dropdown selecting which document-AI task to run.
_task_selector = dbc.Select(
    id="select-task",
    options=[
        {"label": "OCR", "value": "OCR"},
        {"label": "Visual Question Answering", "value": "VQA"},
        {"label": "Key Information Extraction", "value": "KIE"},
    ],
    placeholder="Select a task",
    value="",
    style=STYLE["container"],
)

# Free-text query + submit button; hidden until a query-based task is chosen.
_query_row = dbc.InputGroup(
    [
        dbc.Input(id='task-input', placeholder="Type your query here", type="text", minLength=0),
        dbc.Button(search_icon, id='submit-task', n_clicks=0),
    ],
    id='task-input-group',
    style=STYLE["task_input"],
)

inputs = html.Div(
    [
        _document_pager,
        html.Div([_task_selector, _query_row], style=STYLE["group"]),
    ]
)

# Outputs
# ---- Outputs -------------------------------------------------------------

# The first document is shown before the user touches the pager.
_initial_image_src = "data:image/jpeg;base64,{}".format(encode_image(sample_docVQA, files[0]))

# Document preview, OCR output (with a loading spinner) and VQA/KIE answers.
outputs = html.Div(
    [
        html.Img(id="image-display", src=_initial_image_src, style=STYLE["image"]),
        dbc.Spinner(html.Div(id="output-container", style=STYLE["group"]), color="primary"),
        html.Div(id="task-result", style=STYLE["group"]),
    ]
)

# NOTE(review): `app` is not created in this file — presumably injected by the
# Dataiku webapp runtime; confirm before running this module elsewhere.
app.title = "Document AI"

# Bootstrap theme plus the Bootstrap Icons font (provides the "bi bi-*" icon classes).
app.config.external_stylesheets = [
    dbc.themes.ZEPHYR,
    dbc.icons.BOOTSTRAP
]

# Turn off Dash's callback validation errors (e.g. for ids missing from the initial layout).
app.config.suppress_callback_exceptions = True

# Page = controls on top, document preview and results below, centered.
app.layout = html.Div(
    [
        inputs,
        outputs,
    ],
    style=STYLE["page"]
)

def _render_result(text):
    """Wrap a model answer in the bordered, line-preserving result box."""
    return html.Div([
        html.Div(text, style=STYLE["result_box"])
    ])


def _run_task_on_page(task, page, user_query=""):
    """
    Run `task` on the document currently shown by the pager.

    `page` is the pager's 1-based `active_page`; `get_image` expects a 0-based index.
    """
    image = get_image(page - 1)
    return perform_task_qwen(task, image, user_query)


@app.callback(
    [
        Output("image-display", "src"),
        Output("output-container", "children"),
        Output("select-task", "value"),
        Output('task-input-group', 'style'),
        Output("task-result", "children"),
        Output('task-input', 'value'),
    ],
    [
        Input("image-idx", "active_page"),
        Input("select-task", "value"),
        Input("submit-task", "n_clicks"),
    ],
    [
        State("task-input", "value"),
        State("image-idx", "active_page"),
        State("select-task", "value"),
    ],
    prevent_initial_call=True
)
def combined_callback(image_page, task_value, task_n_clicks,
                      task_input_value, image_state, task_state):
    """
    Single callback driving the page; dispatches on the component that fired.

    - "image-idx" (pager): show the selected document and reset the task UI.
    - no GPU: any task interaction shows an explanatory message instead.
    - "select-task": run OCR immediately, or reveal the query box for VQA/KIE.
    - "submit-task": run the selected VQA/KIE task with the typed query.

    Returns a 6-tuple in the order of the Output list; outputs a branch does
    not touch are `dash.no_update`. The query box style defaults to hidden.

    Raises:
        dash.exceptions.PreventUpdate: when nothing relevant triggered.
    """
    image_src = dash.no_update
    output_container = dash.no_update
    select_task_value = dash.no_update
    task_result = dash.no_update
    task_input_style = {**STYLE["task_input"], 'display': 'none'}
    task_input_value_output = dash.no_update

    # Id of the component whose property fired this callback (None if nothing did).
    triggered_input = ctx.triggered_id
    if triggered_input is None:
        raise dash.exceptions.PreventUpdate

    query_tasks = ("VQA", "KIE")

    if triggered_input == "image-idx":
        image_path = files[image_page - 1]
        image_src = "data:image/jpeg;base64,{}".format(encode_image(sample_docVQA, image_path))
        # New document: clear previous results, task choice and query text.
        task_result = None
        output_container = None
        select_task_value = None
        task_input_value_output = ''

    elif not CUDA_AVAILABLE:
        task_result = "No GPU available on this instance. Please note that this webapp is not live on the Dataiku Project Gallery. If you want to use it, please download this project and add it to your own Dataiku instance."

    elif triggered_input == "select-task" and task_value:
        task_result = None
        if task_value == "OCR":
            # OCR needs no query: run it right away on the current document.
            ocr_text = _run_task_on_page(task_value, image_state)
            output_container = _render_result(ocr_text)
        elif task_value in query_tasks:
            # Query-based task: show the input row and clear old output.
            task_input_style['display'] = 'flex'
            output_container = None
            task_input_value_output = ''

    elif (triggered_input == "submit-task" and (task_n_clicks or 0) > 0
          and task_state in query_tasks):
        generated_text = _run_task_on_page(task_state, image_state, task_input_value or "")
        task_result = _render_result(generated_text)
        # Keep the query row visible and its text as typed.
        task_input_style = dash.no_update

    else:
        # No relevant input triggered, do nothing
        raise dash.exceptions.PreventUpdate

    return (
        image_src,
        output_container,
        select_task_value,
        task_input_style,
        task_result,
        task_input_value_output,
    )

