# -*- coding: utf-8 -*-
import transformers
import pandas as pd
import dataiku
from utils import load_image
import outlines
from qwen_vl_utils import process_vision_info
import json

# Handle on the Dataiku managed folder the images are read from
folder = dataiku.Folder("iJFvJqJA")

# Hugging Face identifier of the vision-language model used for extraction
model_name = "Qwen/Qwen2-VL-7B-Instruct"

# 8-bit bitsandbytes quantization settings passed to the model loader
bnb_config = transformers.BitsAndBytesConfig(
    load_in_8bit=True,
)

# Processor matching the model; used further down to render the chat template
processor = transformers.AutoProcessor.from_pretrained(
    model_name,
)

# JSON schema describing the fields to extract from each receipt.
# The same string is parsed with json.loads to derive the output columns and
# handed to outlines to constrain generation, so it must stay valid JSON.
# The date pattern restricts the day to 01-31 (the earlier form also accepted
# "00"), the month to 01-12 and the year to 2000-2029.
schema = """{
  "type": "object",
  "properties": {
    "company": {
      "type": "string",
      "description": "Company name on the receipt"
    },
    "address": {
      "type": "string",
      "description": "Address on the receipt"
    },
    "total": {
      "type": "number",
      "description": "Total amount written on the receipt (without currency symbol)"
    },
    "date": {
      "type": "string",
      "pattern": "(0[1-9]|[12][0-9]|3[01])/(0[1-9]|1[0-2])/20[0-2][0-9]",
      "description": "Date on the receipt (dd/mm/yyyy format)"
    }
  },
  "required": ["company", "address", "date", "total"]
}"""

# Instruction sent along with every image: it names the four fields to
# extract and sketches the shape of the expected JSON answer.
prompt_example = "".join(
    [
        "Extract company, address, date, and total from the image provided. \n",
        "Comply with the following schema: \n",
        '{ "company" : {company name}, "address" : {address}, "date" : {dd/mm/yyyy}, '
        '"total" : {amount including decimals without currency symbol} }',
    ]
)

# Keyword arguments forwarded to the underlying model loader: quantized
# weights and a device map.
# NOTE(review): both `device="cuda"` and `"device_map": "auto"` are supplied;
# outlines may let `device` take precedence over the device map - confirm.
qwen_load_kwargs = {
    "quantization_config": bnb_config,
    "device_map": "auto",
}

# Wrap Qwen2-VL in an outlines vision model so generation can be constrained
# to a schema
model = outlines.models.transformers_vision(
    model_name,
    model_class=transformers.Qwen2VLForConditionalGeneration,
    processor_class=transformers.AutoProcessor,
    device="cuda",
    model_kwargs=qwen_load_kwargs,
)


# Parse the schema once; its property names drive the output columns
schema_dict = json.loads(schema)

# One empty list per schema field, in schema order, followed by a column that
# keeps the complete model output for each image
data_dict = {field_name: [] for field_name in schema_dict["properties"]}
data_dict["raw_output"] = []

# Build the schema-constrained generator once: it depends only on the model
# and the schema, so recreating it for every image is wasted work.
description_generator = outlines.generate.json(model, schema)

# Iterate through the sample and extract information from each image
filenames = []
for file in folder.list_paths_in_partition():
    path_parts = file.split("/")

    # Only files whose second "/"-separated component is SROIE_test_images are
    # processed; anything else (including paths too short to have one) is skipped.
    if len(path_parts) < 2 or path_parts[1] != "SROIE_test_images":
        continue

    image = load_image(folder, file)  # Load image from the SROIE folder

    # Build input message with image and prompt
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt_example},
            ],
        }
    ]

    # Render the conversation into the text prompt the model expects
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Prepare vision inputs (first element returned by process_vision_info)
    image_inputs = process_vision_info(messages)[0]

    # Generate extracted data from the model, constrained by the schema
    result = description_generator(text, image_inputs)

    # Append the extracted information to the respective lists dynamically from the schema
    for field in schema_dict["properties"]:
        data_dict[field].append(result[field])

    data_dict["raw_output"].append(result)

    # Record the filename without its extension. Splitting on the last dot
    # (instead of dropping a fixed four characters) also handles extensions
    # that are not exactly three letters long.
    filenames.append(path_parts[-1].rsplit(".", 1)[0])

# Attach the image identifiers as an extra column next to the extracted fields
data_dict["ID"] = filenames

# Assemble everything into a single table
df = pd.DataFrame(data_dict)

# Turn the dd/mm/yyyy strings into datetimes; values that cannot be parsed
# become NaT instead of raising
parsed_dates = pd.to_datetime(df["date"], errors="coerce", dayfirst=True)
df["date"] = parsed_dates

# Persist the table to the Dataiku output dataset
results_dataset = dataiku.Dataset("results_KIE_QwenVL2")
results_dataset.write_with_schema(df)