import gradio as gr
import os
import re
import time
import dataiku
import base64
from io import BytesIO

def expand_env_vars(value):
    """
    Replace every ``${NAME}`` placeholder in ``value`` with environment variable NAME.

    Unset variables are replaced with an empty string. ``None`` is returned unchanged,
    so a missing browser-path variable does not crash the app at import time.
    """
    if value is None:
        return None
    # group(2) is the bare variable name inside "${...}".
    return env_var_pattern.sub(lambda match: os.getenv(match.group(2), ''), value)


# Path Gradio is served under for port 7860 (presumably the Code Studio proxy prefix);
# it may contain ${VAR} placeholders that must be expanded before use.
browser_path = os.getenv("DKU_CODE_STUDIO_BROWSER_PATH_7860")
# Group 1 is the whole "${NAME}" token and group 2 is NAME. "[^}]*" stops at the first
# closing brace, so several placeholders in one path are expanded independently.
env_var_pattern = re.compile(r'(\$\{([^}]*)\})')
browser_path = expand_env_vars(browser_path)

# ID of the LLM Mesh model to query, read from the project variable "LLM_id"
# (IDs look like "openai:openai:gpt-4-vision-preview"). The model must accept
# inline images, since get_messages sends IMAGE_INLINE parts.
custom_variables = dataiku.get_custom_variables()
if "LLM_id" not in custom_variables:
    raise KeyError(
        "Project variable 'LLM_id' is not defined; set it to the ID of a "
        "vision-capable LLM Mesh model."
    )
LLM_ID = custom_variables["LLM_id"]

client = dataiku.api_client()
project = client.get_default_project()
# LLM handle used by chat_function to create streamed completions.
llm = project.get_llm(LLM_ID)

def encode_image(image_path):
    """
    Read the file at ``image_path`` and return its bytes as a base64 text string.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode("utf-8")


def get_messages(question, images, history):
    """
    Build the messages sent to the multimodal LLM.
    """
    messages = []
    
    # System message: instruction for the multimodal LLM
    messages.append({
        "role": "system",
        "parts": [
            {
                "type": "TEXT",
                "text": "You are a helpful assistant. Concisely answer the user's question based on the provided facts. If you don't know, just say you don't know."
            }
        ]
    })

    if history:
        for msg in history:
            history_prompt = msg[0]
            history_answer = msg[1]
            if isinstance(history_prompt, tuple):
                messages.append({
                    "role": "user",
                    "parts": [ 
                        {
                            "type": "IMAGE_INLINE",
                            "inlineImage": encode_image(history_prompt[0])
                        }
                    ]
                })
            elif history_prompt:
                    messages.append({
                    "role": "user",
                    "parts": [
                        {
                            "type": "TEXT",
                            "text": f"Answer the following question: {history_prompt}."
                        }
                    ]
                })
            if history_answer:
                    messages.append({
                    "role": "assistant",
                    "parts": [
                        {
                            "type": "TEXT",
                            "text": f"{history_answer}."
                        }
                    ]
                })

    
    # User message: the question posed by the user
    if question:
        messages.append({
            "role": "user",
            "parts": [
                {
                    "type": "TEXT",
                    "text": f"Answer the following question: {question}."
                }
            ]
        })
    
    # Iterate over each chunk with metadata
    if images:
        for image in images:
            messages.append({
                "role": "user",
                "parts": [ 
                    {
                        "type": "IMAGE_INLINE",
                        "inlineImage": encode_image(image)
                    }
                ]
            })
    
    # Return the list of messages
    return messages

def chat_function(message, history):
    """
    Stream the LLM's reply to one multimodal chat message.

    ``message["text"]`` is the typed question and ``message["files"]`` the uploaded
    image paths; ``history`` holds the previous turns. Each yielded value is the whole
    answer accumulated so far (cumulative, not a delta).
    """
    question = message["text"]
    uploaded_images = message["files"]
    completion = llm.new_completion()
    completion.cq["messages"] = get_messages(question, uploaded_images, history)

    partial_answer = ""
    for chunk in completion.execute_streamed():
        piece = chunk.data.get('text')
        # Chunks without text (empty or missing) add nothing to the answer.
        if not piece:
            continue
        partial_answer += str(piece)
        yield partial_answer

# Visual theme for the chat UI.
theme = gr.themes.Soft(
    secondary_hue="indigo",
    neutral_hue="gray",
    spacing_size="sm",
    radius_size="sm",
)
# Input box accepting text plus image uploads; chat_function reads its
# message["text"] and message["files"] entries.
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
demo = gr.ChatInterface(fn=chat_function, multimodal=True, textbox=chat_input, theme=theme)

# WARNING: make sure to use the same params as the ones defined below when calling the launch method,
# otherwise your app might not respond! Port 7860 matches the DKU_CODE_STUDIO_BROWSER_PATH_7860
# variable that browser_path was computed from.
demo.launch(server_port=7860, root_path=browser_path)
