import base64
import io
import random

import dash
import dash_bootstrap_components as dbc
import dataiku
from dash import dcc, html, ctx
from dash.dependencies import Input, Output, State
from PIL import Image

app.config.external_stylesheets = [
    "https://fonts.googleapis.com/css2?family=Outfit:wght@100;200;300;400;500;600;700;800;900&display=swap",
    dbc.themes.ZEPHYR
]
font_family = "Outfit"

# Managed folder containing the images to display
image_folder = dataiku.Folder("rwNI6KPi")

# Dataset mapping image paths ("images" column) to their captions ("captions" column)
captions_dataset = dataiku.Dataset("images_captions")
captions_df = captions_dataset.get_dataframe()

# Managed folder containing one narration file per image, named "<image path>.mp3"
audio_folder = dataiku.Folder("ShQW2Upa")

# Display width in pixels; images are scaled to it while keeping their aspect ratio
base_width = 300

# List the image folder once and reuse the result instead of re-listing it
image_files = image_folder.list_paths_in_partition()
nb_images = len(image_files)

# Initialize the app by picking a random image from the folder
image_filename = random.choice(image_files)

# Caption for the first image. Fall back to a placeholder instead of crashing
# the whole webapp at startup when the dataset has no row for this image.
matching_captions = captions_df.loc[captions_df["images"] == image_filename, "captions"]
if matching_captions.empty:
    caption = "No caption available for this image."
else:
    caption = matching_captions.iloc[0]

# Narration for the first image, embedded as a base64 data URI
audio_path = image_filename + '.mp3'
with audio_folder.get_download_stream(audio_path) as f:
    encoded_audio = base64.b64encode(f.read())
audio_src = 'data:audio/mpeg;base64,{}'.format(encoded_audio.decode())

# Scale the first image to base_width pixels wide, keeping its aspect ratio.
# Image.open only reads the header lazily, so the resize (which decodes the
# pixels) has to run while the download stream is still open.
with image_folder.get_download_stream(path=image_filename) as stream:
    im = Image.open(stream)
    wpercent = base_width / float(im.size[0])
    hsize = int(float(im.size[1]) * wpercent)
    im = im.resize((base_width, hsize), Image.Resampling.LANCZOS)

# Embed the image explicitly as a PNG data URI, the same way the audio is
# embedded, rather than handing a PIL object to html.Img and relying on the
# JSON serializer to convert it.
with io.BytesIO() as buffer:
    im.save(buffer, format="PNG")
    image_src = "data:image/png;base64,{}".format(base64.b64encode(buffer.getvalue()).decode())


app.layout = html.Div(
    [
        html.H2("Image captioning with GPT-4V", style={"textAlign": "center"}),
        html.Img(src=image_src, id="image-display", style={"marginTop": 30}),
        html.Div(
            html.Audio(
                src=audio_src,
                controls=True,
                style={"height": "30px"},
                id="audio-play",
            ),
            style={"marginTop": 20},
        ),
        html.P(caption, id="image-info", style={"marginTop": 20}),
        html.Div([
            dbc.Button("Next", id="next-button", n_clicks=0, style={"marginTop": 10, "marginBottom": 30}),
        ]),
    ],
    # Dash/React style dicts use camelCase CSS property names.
    style={
        "margin": "auto",
        "textAlign": "center",
        "maxWidth": "700px",
        "fontFamily": font_family,
        # NOTE: justifyContent/alignItems only affect flex or grid containers,
        # and this Div sets no `display`, so they currently have no effect;
        # horizontal centering comes from textAlign + margin: auto.
        "justifyContent": "center",
        "alignItems": "center",
        "height": "100vh",
        "padding": "20px",
    },
)

@app.callback(
    [Output('image-display', 'src'),
     Output('image-info', 'children'),
     Output('audio-play', 'src')],
    [Input('next-button', 'n_clicks')],
    prevent_initial_call=True
)
def update_image_display(next_clicks):
    """Show a new random image together with its caption and narration.

    Runs when the "Next" button is clicked; never on page load because of
    ``prevent_initial_call=True``.

    Args:
        next_clicks: Click count of the "Next" button. Not used: only the fact
            that the button triggered the callback matters.

    Returns:
        ``(image_src, caption, audio_src)`` for the image ``src``, the caption
        paragraph and the audio player ``src``. Both ``src`` values are base64
        data URIs. If the callback was triggered by anything other than the
        button, all three outputs are left unchanged.
    """
    # Falling through and returning None is rejected by Dash for a callback
    # with several outputs, so explicitly skip the update instead.
    if ctx.triggered_id != 'next-button':
        return dash.no_update, dash.no_update, dash.no_update

    # Choose directly from a fresh listing. Indexing it with a count taken at
    # startup goes out of range when images are removed and never reaches
    # images added later.
    image_filename = random.choice(image_folder.list_paths_in_partition())

    # Use a placeholder instead of raising IndexError when the captions
    # dataset has no row for this image.
    matching_captions = captions_df.loc[captions_df["images"] == image_filename, "captions"]
    if matching_captions.empty:
        caption = "No caption available for this image."
    else:
        caption = matching_captions.iloc[0]

    # Scale to base_width pixels wide, keeping the aspect ratio. Image.open is
    # lazy, so the resize (which decodes the pixels) must run while the
    # download stream is still open.
    with image_folder.get_download_stream(path=image_filename) as stream:
        im = Image.open(stream)
        wpercent = base_width / float(im.size[0])
        hsize = int(float(im.size[1]) * wpercent)
        im = im.resize((base_width, hsize), Image.Resampling.LANCZOS)

    # Return the image as an explicit PNG data URI, matching the audio,
    # rather than a PIL object left for the JSON serializer to convert.
    with io.BytesIO() as buffer:
        im.save(buffer, format="PNG")
        image_src = "data:image/png;base64,{}".format(base64.b64encode(buffer.getvalue()).decode())

    # Narration files are named "<image path>.mp3" in the audio folder.
    audio_path = image_filename + '.mp3'
    with audio_folder.get_download_stream(audio_path) as f:
        encoded_audio = base64.b64encode(f.read())
    audio_src = 'data:audio/mpeg;base64,{}'.format(encoded_audio.decode())

    return image_src, caption, audio_src



