import dataiku
import os
import io
import re
import logging
import json
import tempfile
import base64
import fitz
from PIL import Image, ImageDraw
import pytesseract
from project_utils import load_image, save_image


# --- OCR / rendering settings ------------------------------------------------
# Language code handed to Tesseract when the text of an image is re-extracted
LANGUAGE = "eng"
# Resolution (dots per inch) used to convert bounding-box coordinates to PDF points
TESSERACT_DPI = 200  # Default resolution used by `unstructured` for OCR
# Scaling factor applied when a PDF region is rendered to an image
zoom = 2
# Transformation matrix applying that scaling on both axes
mat = fitz.Matrix(zoom, zoom)

# --- Dataiku handles ---------------------------------------------------------
# Multimodal LLM, identified by the "LLM_ID" custom variable
LLM_ID = dataiku.get_custom_variables()["LLM_ID"]
project = dataiku.api_client().get_default_project()
llm = project.get_llm(LLM_ID)

# One row per extracted image, keyed by the "index" column
df = (
    dataiku.Dataset("images")
    .get_dataframe()
    .set_index("index")
)
# Managed folder holding the extracted images
folder = dataiku.Folder("vOjkXoGz")
# Managed folder holding the source documents
documents_folder = dataiku.Folder("Zcjq6uFf")

def extract_json(string):
    """
    Parse the output of the multimodal LLM (which is expected to contain a JSON object).

    The JSON object is taken as the span from the first '{' to the last '}' of the
    string, so that any text surrounding the object is ignored.

    Parameters:
    string (str): The string output from the LLM, expected to contain JSON data.

    Returns:
    dict: The parsed JSON data. The 'type' and 'caption' keys are always present;
    they are set to an empty string when missing or null. When the input is not a
    string, contains no JSON object or cannot be decoded, both keys are empty strings.
    """
    # Non-string input (e.g. None) cannot contain a JSON object
    if not isinstance(string, str):
        return {"type": "", "caption": ""}

    # Locate the outermost braces; `find`/`rfind` return -1 when they are absent
    start = string.find("{")
    end = string.rfind("}") + 1
    if start == -1 or end == 0:
        return {"type": "", "caption": ""}

    try:
        result = json.loads(string[start:end])
    except json.JSONDecodeError:
        logging.warning("Error when decoding GPT4 response: %s", string[start:end])
        return {"type": "", "caption": ""}

    # Guarantee usable values: a JSON `null` would otherwise reach callers as None
    for k in ["type", "caption"]:
        if result.get(k) is None:
            result[k] = ""
    return result

# JSON schema that the classification answer of the LLM must comply with
SCHEMA = """{
    "type": "object",
    "properties": {
        "type": {
            "type": "string",
            "enum": ["chart", "diagram", "map", "photograph", "illustration", "paragraph"],
            "description": "The type of the image outlined in red"
        },
        "caption": {
            "type": "string",
            "description": "The caption of the image outlined in red. It should be empty if the type is `photograph` or `illustration`"
        }
    },
    "required": ["type", "caption"]
}"""


# System prompt asking for both a class and a caption, answered as JSON
SYSTEM_PROMPT_IMAGE = (
    "Classify the figure outlined in red.\n"
    "If the figure includes some text, caption it. The caption should be precise enough so that end users interested in the content of the figure could retrieve the image through a search engine.\n"
    "Return only a JSON string complying with the following schema:\n"
    + SCHEMA
)

# System prompt asking for a plain-text caption only
SYSTEM_PROMPT_TABLE = (
    "Caption the figure outlined in red.\n"
    "The caption should be precise enough so that end users interested in the content of the figure could retrieve the image through a search engine.\n"
    "Directly provide the caption, without any comment."
)

def get_caption(img, classify=True):
    """
    Get a caption (and possibly an image class) from the multimodal LLM.

    With `classify=True` the LLM is asked for a JSON answer holding a type and a
    caption. If that caption is empty, the image is captioned a second time with
    the plain captioning prompt and the type is returned as an empty string.
    With `classify=False` the LLM is asked for a caption only and the type is
    always "table".

    Parameters:
    img (PIL.Image.Image): The input image.
    classify (bool): Whether to classify the image or directly caption it.

    Returns:
    tuple: A `(caption, image_type)` pair of strings.
    """
    # Encode the image as a base64 JPEG string
    buf = io.BytesIO()
    img.save(buf, format="JPEG")
    img_str = base64.b64encode(buf.getvalue()).decode('utf-8')

    # System message carries the instructions, user message carries the image
    messages = [
        {
            "role": "system",
            "parts": [
                {
                    "type": "TEXT",
                    "text": SYSTEM_PROMPT_IMAGE if classify else SYSTEM_PROMPT_TABLE
                }
            ]
        },
        {
            "role": "user",
            "parts": [
                {
                    "type": "IMAGE_INLINE",
                    "inlineImage": img_str
                }
            ]
        }
    ]

    # Initialize a new completion with the LLM
    completion = llm.new_completion()
    completion.cq["messages"] = messages
    completion.settings["maxOutputTokens"] = 300
    completion.settings["temperature"] = 0
    # Execute the completion and get the response
    resp = completion.execute()

    if not classify:
        # Direct captioning: the raw answer is the caption
        return resp.text, "table"

    llm_output = extract_json(resp.text)
    if not llm_output["caption"]:
        # No caption came back with the classification: caption the image directly.
        # The recursive call returns a (caption, type) pair, so only its caption is
        # kept; returning the pair itself would nest a tuple inside the result.
        fallback_caption, _ = get_caption(img, classify=False)
        return fallback_caption, ""
    return llm_output["caption"], llm_output["type"]


# Indices of the images discarded because their bounding box is too small
to_remove = []

# Process every image listed in the DataFrame
for f in df.index:
    # 'points' is parsed as 8 numbers; the 1st and 3rd (x, y) pairs are used as
    # opposite corners of the bounding box
    x1, y1, _, _, x2, y2, _, _ = re.findall(r"[\d\.]+", df.at[f, "points"])
    # Convert the coordinates from pixels at TESSERACT_DPI to points (72 per inch)
    x1, y1, x2, y2 = [float(x)/TESSERACT_DPI*72 for x in [x1, y1, x2, y2]]

    # Dimensions of the bounding box
    width, height = x2 - x1, y2 - y1

    # Discard the image when its bounding box is too small
    if min(height, width) < 100:
        to_remove.append(f)
        continue

    modified = False  # Whether the stored image must be saved again
    page_number = int(df.at[f, "page"])  # 1-based page number of the image
    source_filename = df.at[f, "filename"]  # Path of the source document in the documents folder

    # Download the source document to a temporary directory and render the region
    # of interest while the file still exists; the document is closed afterwards
    with tempfile.TemporaryDirectory() as temp_dir:
        # Only the base name is used locally so that a path containing directories
        # (or a leading slash) cannot point outside of the temporary directory
        filepath = os.path.join(temp_dir, os.path.basename(source_filename))
        with documents_folder.get_download_stream(source_filename) as stream, open(filepath, "wb") as f2:
            f2.write(stream.read())

        pdf_file = fitz.open(filepath)
        try:
            page = pdf_file[page_number-1]  # Page containing the image

            # Margins added around the bounding box: a fifth of its size on each
            # side, limited to the space available on the page
            dx1 = min(width/5, x1)
            dx2 = min(width/5, page.rect.width - x2)
            dy1 = min(height/5, y1)
            dy2 = min(height/5, page.rect.height - y2)

            # Render the bounding box and its surroundings
            pixmap = page.get_pixmap(matrix=mat, clip=fitz.Rect(x1-dx1, y1-dy1, x2+dx2, y2+dy2))
            img = Image.frombytes("RGB", [pixmap.width, pixmap.height], pixmap.samples)
        finally:
            pdf_file.close()

    # Outline the initial bounding box in red
    draw = ImageDraw.Draw(img)
    draw.rectangle(((zoom*dx1-2, zoom*dy1-2), (zoom*(dx1+x2-x1)+4, zoom*(dy1+y2-y1)+4)), fill=None, width=2, outline="red")

    img2 = load_image(folder, f)  # Load the image from the images folder

    try:
        # Rotate both images according to the orientation detected for the text
        orientation = pytesseract.image_to_osd(img, output_type=pytesseract.Output.DICT)["orientation"]
        if orientation != 0:
            img = img.rotate(orientation, Image.NEAREST, expand=1)
            img2 = img2.rotate(orientation, Image.NEAREST, expand=1)
            # Re-extract the text content from the rotated image
            df.at[f, "content"] = pytesseract.image_to_string(img2, lang=LANGUAGE)
            logging.info(f"New content for image {f}: {df.at[f, 'content']}")
            modified = True
    except pytesseract.TesseractError:
        logging.info(f"PyTesseract error when processing {f}")

    # Get a caption (and possibly an image class) for the image
    caption, image_type = get_caption(
        img,
        classify=df.at[f, "category"] == "Image"
    )
    df.at[f, "caption"] = caption
    df.at[f, "category"] = image_type
    logging.info(f"Caption for image {f}: {df.at[f, 'caption']}")

    # Downscale the stored image when its short side exceeds 768 pixels or its
    # long side exceeds 2000 pixels (limits allowed with GPT-4V)
    img_width, img_height = img2.size  # PIL reports the size as (width, height)
    short_side, long_side = min(img_width, img_height), max(img_width, img_height)
    if short_side > 768 or long_side > 2000:
        ratio = min(768/short_side, 2000/long_side)
        img2 = img2.resize((int(ratio*img_width), int(ratio*img_height)))
        modified = True

    if modified:
        save_image(folder, f, img2)  # Save the modified image to the images folder

# Remove the discarded images, then turn the index back into a regular column
df = df.drop(to_remove).reset_index()

# Split the rows on their category: tables on one side, everything else on the other
is_table = df["category"] == "table"
# The category column is constant for tables, so it is left out of that output
tables_df = df[is_table].drop(columns=["category"])
figures_df = df[~is_table]

dataiku.Dataset("figures2").write_with_schema(figures_df)
dataiku.Dataset("tables2").write_with_schema(tables_df)