# -*- coding: utf-8 -*-
import dataiku
from PIL import Image
import io
import math
import numpy as np
import cv2
import matplotlib.pyplot as plt
from utils import load_image

# Inputs
# ------
# Managed folder holding the source images; it is addressed by its Dataiku
# folder id. Images are read from it with load_image() further down.
folder = dataiku.Folder("FAUwCSvY")

# Layout-analysis results, one bounding box per row. The code below reads the
# columns "file_path", "x1", "y1", "x2", "y2" and "type".
df = dataiku.Dataset("output_region_unstructured").get_dataframe()

# PART 1: Crop each detected region out of its source image, deskew it, and save it in the folder "cropped_region_OCR"
# --------------------------------------------------------------------------------------------

# Destination folder for the cropped images
cropped_region_OCR = dataiku.Folder("Tz1CoYYW")


def correct_skew(image, delta=1, limit=5):
    """
    Deskew an image by rotating it so that its near-horizontal lines become level.

    Line segments are found with Canny edge detection followed by a
    probabilistic Hough transform. The skew is the median angle of the
    segments that lie within ``limit`` degrees of horizontal; the image is
    rotated by that angle on a canvas enlarged to hold the whole rotated image.

    Args:
        image: PIL image object.
        delta: Step size for the angle. Currently unused; kept so that callers
            passing it keep working.
        limit: Maximum allowable skew angle, in degrees. Segments further from
            horizontal than this (for example vertical table borders) are
            ignored, so they can no longer cause a large spurious rotation.

    Returns:
        The deskewed PIL image. ``image`` itself is returned unchanged when it
        is empty, has an unsupported number of dimensions, contains no usable
        line segment, or is already level.
    """
    image_array = np.array(image)

    # An empty crop has nothing to analyse and would make OpenCV raise.
    if image_array.size == 0:
        return image

    if image_array.ndim == 2:  # Already single-channel
        gray = image_array
    elif image_array.ndim == 3:
        # PIL hands colour data to NumPy in RGB order, not OpenCV's BGR.
        gray = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
    else:
        # Unsupported layout: leave the image alone rather than fail later
        # on an undefined grayscale array.
        return image

    # Detect edges, then straight segments along those edges
    edges = cv2.Canny(gray, 50, 150, apertureSize=3)
    lines = cv2.HoughLinesP(
        edges, 1, np.pi / 180, 100, minLineLength=100, maxLineGap=10
    )
    if lines is None:
        return image  # No lines detected

    # One row per segment: x1, y1, x2, y2
    segments = np.asarray(lines, dtype=float).reshape(-1, 4)
    angles = np.degrees(
        np.arctan2(
            segments[:, 3] - segments[:, 1], segments[:, 2] - segments[:, 0]
        )
    )

    # A segment has no direction: fold every angle into (-90, 90] so that the
    # order of its two endpoints does not matter.
    angles = np.where(angles > 90, angles - 180, angles)
    angles = np.where(angles <= -90, angles + 180, angles)

    # Only segments close to horizontal say something about the skew.
    near_horizontal = angles[np.abs(angles) <= limit]
    if near_horizontal.size == 0:
        return image

    rot_angle = float(np.median(near_horizontal))
    if rot_angle == 0:
        return image  # Already level; avoid a pointless resampling

    # Size of the canvas needed to contain the whole rotated image
    height, width = image_array.shape[:2]
    center = (width // 2, height // 2)
    cos_angle = abs(math.cos(math.radians(rot_angle)))
    sin_angle = abs(math.sin(math.radians(rot_angle)))
    new_width = int((height * sin_angle) + (width * cos_angle))
    new_height = int((height * cos_angle) + (width * sin_angle))

    # Rotate about the image centre, then shift so that the result is centred
    # on the enlarged canvas.
    rot_mat = cv2.getRotationMatrix2D(center, rot_angle, 1.0)
    rot_mat[0, 2] += (new_width / 2) - center[0]
    rot_mat[1, 2] += (new_height / 2) - center[1]

    rotated = cv2.warpAffine(
        image_array, rot_mat, (new_width, new_height), flags=cv2.INTER_LINEAR
    )

    # Convert back to PIL image
    return Image.fromarray(rotated)


# Number of crops produced so far for each source file; used to number the outputs
image_path_counter = {}

# Most recently loaded source image. Consecutive rows that refer to the same
# file reuse it instead of reading the file from the folder again.
loaded_path = None
source_image = None

# Iterate through each row in the DataFrame (one bounding box per row)
for index, row in df.iterrows():
    image_path = row["file_path"]

    # Keep track of how many crops have been made for each file
    image_path_counter[image_path] = image_path_counter.get(image_path, 0) + 1

    # Load the image from the folder, unless it is the one already in memory
    if source_image is None or image_path != loaded_path:
        source_image = load_image(folder, image_path)
        loaded_path = image_path

    # Crop the image based on the bounding-box coordinates
    box = (row["x1"], row["y1"], row["x2"], row["y2"])
    cropped_image = source_image.crop(box)

    # Correct skew in the cropped image
    deskewed_image = correct_skew(cropped_image)

    # Output path: "<source file name without extension>/cropped_<n>.png".
    # The extension is removed at the last dot rather than by dropping a fixed
    # four characters, so ".jpeg" or ".tiff" sources get a clean name too.
    # NOTE(review): only the base name is used, so two source files with the
    # same name in different sub-folders write to the same output paths.
    base_name = image_path.split("/")[-1]
    stem = base_name.rsplit(".", 1)[0]
    new_image_path = f"{stem}/cropped_{image_path_counter[image_path]}.png"

    # Encode the crop as PNG in memory
    with io.BytesIO() as buf:
        deskewed_image.save(buf, format="PNG")
        image_data = buf.getvalue()

    # Upload the image data to the folder
    cropped_region_OCR.upload_data(new_image_path, image_data)


# PART 2: Draw the layout-analysis bounding boxes on each source image and save the annotated images
# --------------------------------------------------------------------------------------------

# Destination folder for images with bounding boxes drawn on them; each
# annotated image is uploaded under the same path as its source image
bbox_output = dataiku.Folder("NRjaWtPA")


# Box colour (RGB, 0-255) for each layout category drawn by save_unst_predictions
_DOCLAYNET_COLORS_RGB = {
    "Title": (31, 119, 180),
    "Footnote": (255, 127, 14),
    "Caption": (44, 160, 44),
    "Formula": (214, 39, 40),
    "Table": (148, 103, 189),
    "Picture": (140, 86, 75),
    "Page-header": (227, 119, 194),
    "Page-footer": (127, 127, 127),
    "Section-header": (188, 189, 34),
    "List-item": (23, 190, 207),
    "Text": (174, 199, 232),
}

# Colour for a category that is missing from the table above
_FALLBACK_COLOR_RGB = (0, 0, 0)


def save_unst_predictions(file, df_Unst):
    """
    Draw the layout-analysis bounding boxes of one image and upload the result.

    The image is loaded from the input folder, every box recorded for it in
    ``df_Unst`` is drawn with its category label, and the annotated picture is
    uploaded to the ``bbox_output`` folder under the same path as the source.
    The uploaded bytes are always PNG-encoded, whatever the extension of
    ``file``.

    Args:
        file: Path of the image inside the input folder.
        df_Unst: The DataFrame containing bounding box information
            (file_path, x1, y1, x2, y2) and category labels (type).
    """
    # Keep only the boxes that belong to this file
    rows = df_Unst[df_Unst["file_path"] == file]
    labels = rows["type"].tolist()
    boxes = rows[["x1", "y1", "x2", "y2"]].values.tolist()

    # Load the image from the folder
    pil_img = load_image(folder, file)

    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(pil_img)

        for label, (xmin, ymin, xmax, ymax) in zip(labels, boxes):
            # Matplotlib expects colour components in [0, 1]. A category that
            # is not in the table is drawn in the fallback colour instead of
            # aborting the whole run with a KeyError.
            rgb = _DOCLAYNET_COLORS_RGB.get(label, _FALLBACK_COLOR_RGB)
            color = tuple(c / 255 for c in rgb)

            # Add bounding box
            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color=color,
                    linewidth=1.5,
                )
            )
            # Add label text at the top-left corner of the box
            ax.text(
                xmin,
                ymin,
                f"{label}",
                fontsize=15,
                bbox=dict(facecolor="red", alpha=0.5),
            )

        ax.axis("off")  # Hide axes

        # Render the figure to PNG in memory
        with io.BytesIO() as buf:
            fig.savefig(buf, format="PNG", bbox_inches="tight", pad_inches=0)
            image_data = buf.getvalue()
    finally:
        # Always release this figure, even when drawing or saving fails;
        # otherwise figures pile up across the files being processed.
        plt.close(fig)

    # Upload the annotated image to the folder
    bbox_output.upload_data(file, image_data)


# Annotate every file found in the input folder with its bounding boxes and save the result
for file in folder.list_paths_in_partition():
    save_unst_predictions(file=file, df_Unst=df)
