# -*- coding: utf-8 -*-
import dataiku
import tempfile
import os
import io
import pandas as pd
from unstructured.partition.pdf import partition_pdf  # For extracting elements from PDF files

# Managed-folder IDs: the source folder is read for input PDFs, the images
# folder receives the image/table files extracted from them.
SOURCE_FOLDER_ID = "Zcjq6uFf"
IMAGES_FOLDER_ID = "vOjkXoGz"

folder = dataiku.Folder(SOURCE_FOLDER_ID)
images_folder = dataiku.Folder(IMAGES_FOLDER_ID)

def extract_pdf_elements(source, target):
    """
    Partition one PDF with `unstructured` and return the chunked elements.

    The file is parsed with the "hi_res" strategy and table structure
    inference. Text is chunked by title. Image and Table blocks are also
    written out as files into `target`.

    Parameters:
    - source: str, filesystem path of the PDF to partition
    - target: str, directory that receives the extracted image/table files

    Returns:
    - list, the elements produced by `partition_pdf`
    """
    # How the extracted text is grouped into chunks.
    chunking_options = dict(
        chunking_strategy="by_title",     # Chunk text by title
        max_characters=4000,              # Upper bound on characters per chunk
        combine_text_under_n_chars=2000,  # Combine text blocks shorter than this
    )
    # Which block types are saved as image files, and where.
    image_options = dict(
        extract_image_block_types=["Image", "Table"],
        extract_image_block_output_dir=target,
    )
    return partition_pdf(
        filename=source,
        strategy="hi_res",           # High-resolution extraction strategy
        infer_table_structure=True,  # Infer the structure of tables
        **chunking_options,
        **image_options,
    )

def _bbox(points):
    """Return ``(x1, x2, y1, y2)`` from an element's coordinate points.

    Only ``points[0]`` and ``points[2]`` are read; they are treated as the two
    opposite corners of the box. NOTE(review): assumes unstructured's point
    order (top-left, bottom-left, bottom-right, top-right) -- confirm.
    """
    return points[0][0], points[2][0], points[0][1], points[2][1]


def format_elements(elements, source_filename, max_distance=400, tol=50):
    """
    Turn `partition_pdf` output into a texts DataFrame and a figures DataFrame.

    Every pre-chunking element (``metadata.fields["orig_elements"]`` of each
    chunk) becomes one row. Image/Table elements go to the figures frame, with
    one row per image file name. Everything else goes to the texts frame.
    Text boxes of category FigureCaption, NarrativeText or UncategorizedText
    are then compared with the figures on the same page:

    - If a figure's horizontal extent (widened by `tol`) covers the text and
      the vertical gap between them is below `max_distance`, that figure is a
      caption candidate. The text is appended to the ``caption`` of its
      closest candidate.
    - If the text also lies vertically inside such a figure, the text is
      appended to that figure's ``content``.

    Parameters:
    - elements: list, chunked elements returned by `partition_pdf`
    - source_filename: str, name of the source file, used to build row indices
    - max_distance: int, maximum vertical gap for a caption-figure association
    - tol: int, horizontal tolerance when checking bounding-box alignment

    Returns:
    - texts_df: DataFrame with columns index, category, points, text, page,
      container (index of the enclosing figure, or ""), filename
    - figures_df: DataFrame with columns index, category, points, page,
      content, caption, filename
    """
    caption_categories = ("FigureCaption", "NarrativeText", "UncategorizedText")
    indices = set()  # figure indices already recorded -> one row per image file

    figures = {"index": [], "category": [], "points": [], "page": [], "content": []}
    texts = {"index": [], "category": [], "points": [], "text": [], "page": []}

    for i, chunk in enumerate(elements):
        # Chunks without `orig_elements` contribute nothing.
        for j, sub_element in enumerate(chunk.metadata.fields.get("orig_elements") or []):
            fields = sub_element.metadata.fields
            category = sub_element.category
            if category in ("Image", "Table"):
                # NOTE(review): assumes every Image/Table element carries an
                # `image_path`; a missing one raises AttributeError here.
                image_url = fields.get("image_path").split("/")[-1]
                index = f"{source_filename}_{image_url}"
                if index in indices:
                    continue
                indices.add(index)
                figures["index"].append(index)
                figures["category"].append(category)
                figures["points"].append(fields.get("coordinates").points)
                figures["page"].append(fields.get("page_number"))
                figures["content"].append(sub_element.text)
            else:
                texts["index"].append(f"{source_filename}_text-{i}-{j}")
                texts["category"].append(category)
                texts["points"].append(fields.get("coordinates").points)
                texts["text"].append(sub_element.text)
                texts["page"].append(fields.get("page_number"))

    texts_df = pd.DataFrame.from_dict(texts).set_index("index")
    figures_df = pd.DataFrame.from_dict(figures).set_index("index")

    # Create the association columns up front. They must exist even when
    # there are no text rows, because "neighbor" is dropped unconditionally below.
    texts_df["neighbor"] = ""   # index of the closest figure, or ""
    texts_df["container"] = ""  # index of the figure enclosing the text, or ""
    figures_df["caption"] = ""

    # Pass 1: for each candidate text box, find its closest figure and any
    # figure that encloses it, among the figures on the same page.
    for label in texts_df.index:
        if texts_df.at[label, "category"] not in caption_categories:
            continue
        x11, x12, y11, y12 = _bbox(texts_df.at[label, "points"])
        same_page = figures_df[figures_df["page"] == texts_df.at[label, "page"]]
        neighbors = []  # (figure index, vertical gap) pairs
        container = None
        for fig_label, fig_points in zip(same_page.index, same_page["points"]):
            x21, x22, y21, y22 = _bbox(fig_points)
            # Only figures whose horizontal extent (widened by `tol`) covers the text.
            if not (x11 > x21 - tol and x12 < x22 + tol):
                continue
            distance = max_distance + 1  # sentinel meaning "not close enough"
            gap = y21 - y12  # > 0 when the text box ends (in y) before the figure's begins
            if 0 < gap < max_distance:
                distance = gap
            gap = y11 - y22  # > 0 when the text box begins (in y) after the figure's ends
            if 0 < gap < max_distance:
                distance = min(gap, distance)
            if distance <= max_distance:
                neighbors.append((fig_label, distance))
            if y11 > y21 and y12 < y22:
                container = fig_label
        if neighbors:
            texts_df.at[label, "neighbor"] = min(neighbors, key=lambda pair: pair[1])[0]
        if container is not None:
            texts_df.at[label, "container"] = container

    # Pass 2: fold enclosed text into figure content, nearby text into captions.
    for label in texts_df.index:
        text = texts_df.at[label, "text"]
        container = texts_df.at[label, "container"]
        if container:
            current = figures_df.at[container, "content"]
            # An empty string (or NaN/None) counts as "no content yet", so the
            # first piece of text is not prefixed with ", ".
            if isinstance(current, str) and current:
                figures_df.at[container, "content"] = current + ", " + text
            else:
                figures_df.at[container, "content"] = text
        neighbor = texts_df.at[label, "neighbor"]
        if neighbor:
            caption = figures_df.at[neighbor, "caption"]
            figures_df.at[neighbor, "caption"] = (caption + ", " + text) if caption else text

    # Drop the helper column and expose the index as a regular column.
    texts_df = texts_df.drop(columns="neighbor").reset_index()
    figures_df = figures_df.reset_index()
    texts_df["filename"] = source_filename
    figures_df["filename"] = source_filename

    return texts_df, figures_df


# Collect one DataFrame per source file and concatenate once at the end.
# Concatenating inside the loop would re-copy all earlier rows on every file.
text_frames = []
figure_frames = []

# `pdf_dir` holds each downloaded PDF while it is processed; `images_dir`
# receives the image/table files written by `extract_pdf_elements`.
with tempfile.TemporaryDirectory() as pdf_dir, tempfile.TemporaryDirectory() as images_dir:
    for path in folder.list_paths_in_partition():
        source_filename = os.path.basename(path)
        filepath = os.path.join(pdf_dir, source_filename)
        # Copy the PDF out of the managed folder so it can be parsed from local disk.
        with folder.get_download_stream(path) as stream, open(filepath, "wb") as local_pdf:
            local_pdf.write(stream.read())

        elements = extract_pdf_elements(filepath, images_dir)
        new_texts_df, new_figures_df = format_elements(elements, source_filename)
        text_frames.append(new_texts_df)
        figure_frames.append(new_figures_df)

        os.remove(filepath)

        # Upload each extracted file as "<pdf name>_<file name>", the same
        # pattern format_elements uses for figure indices. Then delete it so
        # the next PDF starts with an empty images_dir.
        for root, _dirs, files in os.walk(images_dir):
            for name in files:
                local_path = os.path.join(root, name)
                images_folder.upload_file(f"{source_filename}_{name}", local_path)
                os.remove(local_path)

if not text_frames:
    raise RuntimeError("No files found in the source folder; nothing to write.")

texts_df = pd.concat(text_frames, ignore_index=True)
figures_df = pd.concat(figure_frames, ignore_index=True)

dataiku.Dataset("texts").write_with_schema(texts_df)
dataiku.Dataset("images").write_with_schema(figures_df)
