#####################################################
################# USEFUL FUNCTIONS ##################
#####################################################

import os
import json
import base64
from PIL import Image
from io import BytesIO
from dash.exceptions import PreventUpdate
from lens import Lens, LensProcessor
import torch

from openai import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory


def answer_question(history_questions, history_answers, store_img, b64_images, model, client, openai_api_key, llm):
    """
    Generate a response to the last question of a conversation using the selected model.

    Parameters:
    - history_questions (list): Questions asked so far, oldest first. The last entry is the
      question to answer. For the "lens" model, entries that are not strings are treated as
      image components: their captions are looked up in ``store_img`` via
      ``entry['props']['src']`` and turned into a text prompt.
    - history_answers (list): Previous answers; ``history_answers[i]`` answers
      ``history_questions[i]``. It may be shorter than ``history_questions``.
    - store_img (dict): Maps an image ``src`` to its captions, which are inserted into the
      "lens" prompt.
    - b64_images (list): Base64-encoded images attached to the question ("gpt4v" model only).
    - model (str): Either "lens" or "gpt4v".
    - client (OpenAI client): Client used for the chat completions call ("gpt4v" model only).
    - openai_api_key (str): Currently unused; kept for interface compatibility.
    - llm: LangChain chat model that runs the conversation chain ("lens" model only).

    Returns:
    - answer (str): The generated response to the last question.

    Raises:
    - ValueError: If ``model`` is neither "lens" nor "gpt4v".

    Note:
    - The caller's ``history_questions`` list is never modified.
    - "lens" replays the full history through a LangChain conversation memory.
    - "gpt4v" sends only the last question plus the supplied images. Earlier turns are not sent.
    """
    if model == "lens":
        return _answer_with_lens(history_questions, history_answers, store_img, llm)
    if model == "gpt4v":
        return _answer_with_gpt4v(history_questions, b64_images, client)
    raise ValueError('Unknown model {!r}; expected "lens" or "gpt4v".'.format(model))


# OpenAI model used by the "gpt4v" path. The "lens" path uses the caller-supplied ``llm``.
_GPT4V_MODEL = "gpt-4-vision-preview"


def _captions_prompt(question, store_img):
    """Return a text question unchanged; turn an image entry into a prompt built from its stored captions."""
    if isinstance(question, str):
        return question
    captions = store_img[question['props']['src']]
    # Implicit concatenation (not a backslash continuation) keeps stray
    # indentation out of the prompt text.
    return (
        "You are provided with a series of captions from a specific image. "
        "Here is the list of captions: {}. Don't mention the captions in the conversation.".format(captions)
    )


def _answer_with_lens(history_questions, history_answers, store_img, llm):
    """Answer the last question with a LangChain LLMChain whose memory is primed with the earlier turns."""
    prompt = ChatPromptTemplate(
        messages=[
            SystemMessagePromptTemplate.from_template("You are a helpful assistant."),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
    )
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

    # Convert on a copy so the caller's history (e.g. the displayed chat) is left untouched.
    # The current question is converted too, so an image never reaches the prompt as a raw dict.
    questions = [_captions_prompt(question, store_img) for question in history_questions]

    # Replay every earlier turn. Answers may lag behind questions, so a
    # question is paired with an answer only when one exists.
    for i, question in enumerate(questions[:-1]):
        memory.chat_memory.add_user_message(question)
        if i < len(history_answers):
            memory.chat_memory.add_ai_message(history_answers[i])

    conversation_chain = LLMChain(
        llm=llm,
        prompt=prompt,
        verbose=True,
        memory=memory,
    )
    result = conversation_chain({"question": questions[-1]})
    return result["text"]


def _answer_with_gpt4v(history_questions, b64_images, client):
    """Answer the last question with GPT-4 Vision, attaching every base64 image as a low-detail JPEG data URL."""
    content = []
    if history_questions:
        content.append({"type": "text", "text": history_questions[-1]})
    for b64_image in b64_images or []:
        content.append({
            "type": "image_url",
            "image_url": {
                "url": "data:image/jpeg;base64,{}".format(b64_image),
                "detail": "low",
            },
        })

    # A single user message carries both the text and the image parts.
    user_message = {"role": "user", "content": content}
    result = client.chat.completions.create(
        model=_GPT4V_MODEL,
        messages=[user_message],
        temperature=0,
        max_tokens=300,
    )
    return result.choices[0].message.content

def get_captions(lens, processor, content, select):
    """
    Decode an uploaded image and, if requested, extract captions for it with the LENS model.

    Parameters:
    - lens: The LENS model. It is called on the processor output, and its result is indexed
      as ``output["prompts"][0]``.
    - processor: The processor paired with ``lens`` (presumably a ``LensProcessor``). It is
      called with a one-element list of PIL images.
    - content (str): A string of the form "<header>,<base64 data>" (presumably a data URL
      from an upload component — confirm with the caller).
    - select (str): "lens" to extract captions; any other value skips captioning.

    Raises:
    - PreventUpdate: If content is None.
    - ValueError: If content contains no comma separating the header from the data.

    Returns:
    - captions: ``output["prompts"][0]`` from the LENS model if select is "lens", otherwise None.
    - content_string (str): The base64-encoded image data (everything after the first comma).
    """
    # Nothing uploaded yet: tell Dash not to update the callback outputs.
    if content is None:
        raise PreventUpdate

    # Split only on the first comma: base64 data never contains commas, and
    # a comma in the header must not break the unpacking.
    _header, content_string = content.split(',', 1)

    # Decode the payload and open it with PIL. This happens even when captioning
    # is skipped, so undecodable image data is still rejected here.
    stream = base64.b64decode(content_string)
    img = Image.open(BytesIO(stream))

    if select == "lens":
        # Inference only: disable autograd bookkeeping.
        with torch.no_grad():
            samples = processor([img])
            output = lens(samples)
            captions = output["prompts"][0]
    else:
        captions = None

    return captions, content_string

def encode_image(image_path):
    """
    Read a file from disk and return its contents as a base64 text string.

    Parameters:
    - image_path (str): Path of the image file to read.

    Returns:
    - str: The file's raw bytes, base64-encoded and decoded to UTF-8 text.
    """
    # Read the whole file as bytes; the handle is closed before encoding.
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode('utf-8')

