import asyncio
import base64
import logging
import os
import traceback
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, asdict
from io import BytesIO
from typing import List, Optional, Tuple, Iterator

import docx
import pptx
from docling.datamodel.pipeline_options import OcrOptions
from docx.oxml import CT_P, CT_Tbl
from docx.table import _Cell, Table
from pypdfium2 import PdfDocument

from dataiku.base.batcher import Batcher
from dataiku.llm.docextraction import TextNode, RootNode, build_message_log_for_document, ImageNode, SectionNode, AbstractDocumentNode, SlideNode
from dataiku.llm.docextraction.ocr import get_ocr_config, process_images_with_ocr, pdf_to_pil_images_iterator, convert_image_to_greyscale_bytes

logger = logging.getLogger(__name__)


@dataclass
class RawRequest:
    """One document extraction request, as queued on the batcher.

    Built by ``RawExtractorPipeline.process_document`` from the incoming
    command dictionary.
    """
    # Name of the file; its extension selects the extraction path.
    file_name: str
    # Document payload. Non-image documents are base64-decoded before extraction;
    # image documents are handed as-is to the OCR helper.
    document_content: str
    # True when the command's "imageHandlingMode" is "OCR".
    do_ocr: bool
    # Value of "ocrEngine" from the OCR settings ("AUTO" when absent).
    ocr_engine: str
    # Value of "ocrLanguages" from the OCR settings (an empty list when absent).
    lang: List[str]


@dataclass
class RawResponse:
    """Outcome of extracting a single document.

    ``ok`` tells whether extraction succeeded, ``resp`` carries the serialized
    node tree (an empty dict on failure) and ``error`` carries the failure
    description (``None`` on success).
    """
    ok: bool
    resp: dict
    # Defaults to None so that successful responses do not have to spell it out.
    error: Optional[str] = None

    def to_dict(self):
        """Return the response as a plain dictionary (recursively, via ``dataclasses.asdict``)."""
        return asdict(self)


class RawExtractorPipeline:
    """Batch raw text extraction requests and run them on a thread pool.

    Requests are grouped by the batcher so that every batch shares the same file
    extension and the same OCR settings; each batch is then processed
    synchronously in a worker thread by ``batch_raw_extract``.
    """
    batcher: Batcher[RawRequest, dict]
    image_formats: List[str] = ["png", "jpg", "jpeg"]
    supported_formats: List[str] = ["pdf", "docx", "pptx", "html"] + image_formats
    document_batch_size: int = 4
    # Only overridden when the kernel settings provide "pageBatchSize". The class-level
    # default guarantees the attribute exists when the setting is absent.
    page_batch_size: Optional[int] = None
    # Forwarded unchanged to the PDF-to-image iterator used by the OCR path.
    memory_limit_per_document: int = 50

    def __init__(self, kernel_settings):
        """
        Create the thread pool and the request batcher.

        :param kernel_settings: mapping that may contain "documentBatchSize", "pageBatchSize"
                                and "memoryLimitPerDocument" to override the class defaults.
        """
        self.executor = ThreadPoolExecutor()

        if "documentBatchSize" in kernel_settings:
            self.document_batch_size = kernel_settings.get("documentBatchSize")
        if "pageBatchSize" in kernel_settings:
            self.page_batch_size = kernel_settings.get("pageBatchSize")
        if "memoryLimitPerDocument" in kernel_settings:
            self.memory_limit_per_document = kernel_settings.get("memoryLimitPerDocument")

        self.batcher = Batcher[RawRequest, dict](
            batch_size=self.document_batch_size,
            timeout=1,
            process_batch=self._process_batch_async,
            group_by=self._group_key
        )
        # %s rather than %d: page_batch_size is None when the setting is not provided.
        logger.info("Raw extractor pipeline started with settings: doc_batch_size=%s, page_batch_size=%s",
                    self.document_batch_size, self.page_batch_size)

    @staticmethod
    def _group_key(request: RawRequest) -> int:
        """
        Compute the batching key of a request.

        A batch is processed with the extension and OCR settings of its first request, so every
        setting read from that first request (including the OCR engine) must be part of the key.
        """
        return hash((os.path.splitext(request.file_name)[1], request.do_ocr, request.ocr_engine, tuple(request.lang)))

    def _run_batch_sync(self, requests: List[RawRequest]) -> List[dict]:
        """Synchronous entry point executed in the thread pool for one batch."""
        logger.info("Processing a batch of %s document extraction requests", len(requests))
        return self.batch_raw_extract(requests)

    async def _process_batch_async(self, requests: List[RawRequest]) -> List[dict]:
        """Run one batch in the executor so the event loop is not blocked by the extraction."""
        return await asyncio.get_running_loop().run_in_executor(self.executor, self._run_batch_sync, requests)

    async def process_document(self, process_document_command) -> dict:
        """
        Queue a single document on the batcher and wait for its response.

        :param process_document_command: mapping with "fileName" and "documentContent", and optionally
                                         "imageHandlingMode" (OCR is enabled when it equals "OCR") and
                                         "ocrSettings" ("ocrEngine", "ocrLanguages").
        :return: the response dict produced for this document
        """
        logger.info("Processing a document request")
        ocr_settings = process_document_command.get("ocrSettings", {})
        return await self.batcher.process(RawRequest(process_document_command["fileName"], process_document_command["documentContent"],
                                                     process_document_command.get("imageHandlingMode", "IGNORE") == "OCR",
                                                     ocr_settings.get("ocrEngine", "AUTO"), ocr_settings.get("ocrLanguages", [])))

    def batch_raw_extract(self, requests: List[RawRequest]) -> List[dict]:
        """
        Extract text from a list of requests

        If OCR is activated:
         - For image files, we run OCR.
         - PDFs are converted into images and processed with OCR.

        If OCR is not activated:
         - Image files are not processed.
         - PDFs are processed with pypdfium2 to extract text.

        In all cases, DOCX and PPTX files are processed with python libraries.

        A failure on one non-image document (invalid base64, unreadable file, ...) only produces a
        failed response for that document; the other documents of the batch are still processed.

        :param requests: a list of documents with the same extension. They also share in common the same settings for OCR.
        :return: a list of responses transformed into dicts, one per request and in the same order
        :raises ValueError: if the batch contains image files while OCR is not activated
        """
        if not requests:
            return []

        extension = os.path.splitext(requests[0].file_name)[1].lower().lstrip(".")

        needs_ocr = any(req.do_ocr for req in requests)

        if needs_ocr:
            ocr_options = get_ocr_config(requests[0].lang, requests[0].ocr_engine, False)
        else:
            ocr_options = None

        # Handle image formats without docling:
        if extension in self.image_formats:
            res = []
            if ocr_options is None:
                raise ValueError("OCR options must be provided for image extraction")
            text_results = process_images_with_ocr(ocr_options, [(idx, request.document_content) for idx, request in enumerate(requests)])
            for idx, request in enumerate(requests):
                if idx in text_results:
                    res.append(RawResponse(True, ImageNode(node_id=request.file_name,
                                                           children=[],
                                                           label="image",
                                                           content=text_results[idx],
                                                           level=0,
                                                           # The field name is misleading but it is used for file name in docling_extraction
                                                           image_base64=request.file_name,
                                                           ).to_dict(), None).to_dict())
                else:
                    # something went wrong with the OCR processing
                    res.append(RawResponse(False, {}, f"Error processing image {request.file_name} with OCR").to_dict())
            return res

        # Any other documents (pdf, docx, pptx, html)
        res = []
        for document in requests:
            filename = document.file_name
            try:
                # Decoding and extraction both happen inside the try block: an exception raised while
                # iterating a generator would otherwise escape the per-document handler and end the batch.
                file_bytes = base64.b64decode(document.document_content)
                results = extract_text_chunks(file_bytes, self.memory_limit_per_document, ocr_options, extension, True)
                root = RootNode(
                    node_id="document",
                    children=results,
                    label=filename,
                    content="",
                    level=0,
                    page_provenance=[]
                )
                res.append(
                    RawResponse(True, root.to_dict(), None).to_dict())
                logger.info(build_message_log_for_document(filename, "Done processing document"))
            except Exception as e:
                logger.exception("An error occurred during raw text extraction")
                res.append(RawResponse(False, {}, ''.join(traceback.format_exception(type(e), e, e.__traceback__))).to_dict())
        return res


def extract_all(files: List[Tuple[str, bytes]], memory_limit_per_document, ocr_options: Optional[OcrOptions], extension: str):
    """
    Lazily extract the chunks of every file, one file per iteration.

    :param files: (file name, raw file bytes) pairs
    :param memory_limit_per_document: forwarded to extract_text_chunks
    :param ocr_options: OCR configuration, or None when OCR is disabled
    :param extension: extension shared by all the files, without the leading dot
    :return: a generator of (file name, extracted nodes) pairs; PDF bookmarks are always used
    """
    use_pdf_bookmarks = True
    for name, content in files:
        nodes = extract_text_chunks(content, memory_limit_per_document, ocr_options, extension, use_pdf_bookmarks)
        yield name, nodes


def extract_text_chunks(file_bytes: bytes, memory_limit_per_document, ocr_options: Optional[OcrOptions], extension: str, use_pdf_bookmarks: bool):
    """
    Extract the content of one document as a list of document nodes.

    - pdf with OCR: one ImageNode per page, whose content is the OCR text of the rendered page.
    - pdf without OCR: a section hierarchy built from the bookmarks when the PDF has some and
      ``use_pdf_bookmarks`` is set, otherwise one TextNode per page.
    - docx: a single TextNode holding every non-empty paragraph (table cells included).
    - pptx: one SlideNode per slide.
    - anything else: a single TextNode holding the decoded bytes.

    :param file_bytes: raw content of the document
    :param memory_limit_per_document: forwarded to the PDF-to-image iterator (OCR path only)
    :param ocr_options: OCR configuration, or None when OCR is disabled
    :param extension: lower-case file extension without the leading dot
    :param use_pdf_bookmarks: whether PDF bookmarks may be used to structure the output
    :return: a list of document nodes
    :raises ValueError: for legacy "doc" and "ppt" files
    """
    if extension in ["doc", "ppt"]:
        raise ValueError(f"'{extension}' files are not supported, try to convert them to {extension}x.")

    if extension == "pdf":
        if ocr_options is not None:
            # The OCR path works from rendered page images only, so the PDF is not opened with pypdfium2 here.
            chunks = []
            for j, img in enumerate(pdf_to_pil_images_iterator(file_bytes, memory_limit_per_document)):
                pil_image = convert_image_to_greyscale_bytes(img)
                label = "Page {}".format(j + 1)
                img_text = process_images_with_ocr(ocr_options, images_ref=[(label, pil_image)])[label]
                page_provenance = [j + 1]
                chunks.append(ImageNode(
                    node_id=str(j + 1),
                    children=[],  # List of children DocumentNode
                    content=img_text,  # Content, text or base64 for images
                    label=label,
                    level=1,  # Level within the hierarchy
                    page_provenance=page_provenance,
                ))
            return chunks

        pdf_pages = PdfDocument(file_bytes)
        try:
            # The table of contents is only read when bookmarks are actually going to be used.
            bookmarks = list(pdf_pages.get_toc()) if use_pdf_bookmarks else []
            if not bookmarks:
                # only extract page numbers when no bookmarks are found
                return [
                    TextNode(
                        node_id=str(page_id + 1),
                        children=[],
                        content=page.get_textpage().get_text_range(),
                        label="Page {}".format(page_id + 1),
                        level=1,
                        page_provenance=[page_id + 1]
                    )
                    for page_id, page in enumerate(pdf_pages)
                ]
            return _extract_pdf_chunks(pdf_pages, bookmarks)
        finally:
            # All text has been copied into Python strings at this point: release the document now
            # instead of waiting for garbage collection.
            pdf_pages.close()
    elif extension == "docx":
        doc = docx.Document(BytesIO(file_bytes))
        return [
            TextNode(
                node_id="fulltext",
                children=[],
                content="\n".join([paragraph.text for paragraph in extract_text_from_docx(doc) if paragraph.text != ""]),
                label="Text",
                level=1,
                page_provenance=[],
            )
        ]
    elif extension == "pptx":
        doc = pptx.Presentation(BytesIO(file_bytes))
        return extract_text_from_pptx(doc)
    else:
        # Fallback (e.g. html): the bytes are decoded with the default codec (UTF-8, strict).
        return [TextNode(
            node_id="0",
            children=[],
            content=file_bytes.decode(),
            label="",
            level=1,
            page_provenance=[],
        )]


def _extract_text_from_pdf_bound(pdf_pages, start_page, start_vertical_position, end_page, end_vertical_position):
    """
    Collect the text located between a start position and an end position of a PDF.

    The start vertical position only bounds the first page that is read and the end vertical
    position only bounds the last one; pages in between are taken without bounds.
    """
    pieces = []
    top = start_vertical_position
    for page_index in range(start_page, end_page):
        pieces.append(pdf_pages[page_index].get_textpage().get_text_bounded(top=top))
        # The upper bound only applies to the very first page.
        top = None

    # When the start page is not before the end page, nothing was read above and the
    # start page itself is the only page to read.
    final_page = end_page if start_page < end_page else start_page
    pieces.append(pdf_pages[final_page].get_textpage().get_text_bounded(top=top, bottom=end_vertical_position))
    return "".join(pieces)


def _extract_pdf_chunks(pdf_pages: PdfDocument, bookmarks: List):
    """
    Extract chunks from a PDF bookmarks, creating a nested hierarchy
    where deeper levels become children of previous parent levels.

    The text located before the first bookmark is returned as a "header" section. Then each
    bookmark produces a SectionNode (its title) holding a TextNode with the text found between
    this bookmark and the next one (or the end of the document for the last bookmark).

    :param pdf_pages: the opened PDF document
    :param bookmarks: the bookmarks of the document, in reading order
    :return: the list of top-level nodes
    """

    def _vertical_position(bookmark):
        # The vertical position is read from the second view coordinate of the bookmark.
        # Not every bookmark carries one (or there may be no bookmark at all): fall back to None.
        try:
            return bookmark.view_pos[1]
        except (IndexError, AttributeError, TypeError):
            return None

    chunks: List[AbstractDocumentNode] = []

    # --- 1. Header Extraction (Pre-bookmark content) ---
    first_bookmark = bookmarks[0] if bookmarks else None
    first_bookmark_page = first_bookmark.page_index if first_bookmark is not None else 0

    text = _extract_text_from_pdf_bound(
        pdf_pages,
        start_page=0,
        start_vertical_position=None,
        end_page=first_bookmark_page,
        # The Y-pos of the first bookmark determines the header boundary
        end_vertical_position=_vertical_position(first_bookmark)
    )

    if text.strip():
        header_node = SectionNode(
            node_id="header",
            children=[TextNode(
                node_id="1",
                content=text,
                level=2,
                # 1-based and inclusive, like the sections below: the header text is read from the
                # first page up to (and including) the page of the first bookmark.
                page_provenance=list(range(1, first_bookmark_page + 2)),
                children=[],
                label="Header",
            )],
            level=1,
            label="header",
        )
        chunks.append(header_node)

    # --- 2. Bookmark Hierarchy Construction ---
    # Stack to keep track of the current branch [parent, child, grandchild...]
    stack: List[SectionNode] = []

    for bookmark_id, bookmark in enumerate(bookmarks):
        title = bookmark.title
        level = bookmark.level
        start_page = bookmark.page_index
        start_vertical_position = _vertical_position(bookmark)

        # A section ends where the next bookmark starts, or at the end of the document.
        last_header = (bookmark_id == len(bookmarks) - 1)
        if last_header:
            end_page = len(pdf_pages) - 1
            end_vertical_position = None
        else:
            next_bookmark = bookmarks[bookmark_id + 1]
            end_page = next_bookmark.page_index
            end_vertical_position = _vertical_position(next_bookmark)

        text = _extract_text_from_pdf_bound(
            pdf_pages, start_page, start_vertical_position, end_page, end_vertical_position
        )

        # The title is carried by the section node: do not repeat it at the start of the paragraph.
        if text.startswith(title):
            text = text[len(title):]

        text_node = TextNode(
            node_id=str(bookmark_id + 1),
            children=[],
            content=text.strip(),
            label="paragraph",
            level=level + 2,
            page_provenance=list(range(start_page + 1, end_page + 2)),
        )
        new_section = SectionNode(
            node_id="section-{}".format(bookmark_id + 1),
            children=[text_node],
            content=title,
            level=level + 1,
            page_provenance=[],
            label="section",
        )

        # --- Hierarchy Logic ---

        # 1. Backtrack: If the new section is not deeper than the current top of stack,
        # we have finished the previous subsection(s). Pop until we find the parent.
        # (e.g. if Stack is [Lvl1, Lvl2], and New is Lvl2, pop Lvl2. Parent is Lvl1)
        while stack and stack[-1].level and new_section.level and stack[-1].level >= new_section.level:
            stack.pop()

        # 2. Append: Add the new section to the hierarchy
        if stack:
            # We found a parent in the stack (stack[-1].level < new_section.level)
            stack[-1].children.append(new_section)
        else:
            # Stack empty or all popped implies this is a top-level root node
            chunks.append(new_section)

        # 3. Push: This new section is now the active parent for subsequent deeper nodes
        stack.append(new_section)

    return chunks


def iter_visible_row_cells(row) -> Iterator[_Cell]:
    """Yield a cell for each `tc` element of the row, left to right.

    A `tc` element whose vertical merge is "continue" (the continuation of a
    vertically spanned cell) is skipped.
    """
    for tc in row._tr.tc_lst:
        if tc.vMerge == "continue":
            continue
        yield _Cell(tc, row)


def extract_text_from_docx(parent):
    """
    Yield the paragraphs of a docx document or table cell, in document order.

    Tables are walked row by row and cell by cell, and the paragraphs of each visible cell
    are yielded recursively (nested tables included).

    :param parent: a docx Document or a table cell
    :raises ValueError: when parent is neither of those (raised when iteration starts)
    """
    if isinstance(parent, docx.document.Document):
        container = parent.element.body
    elif isinstance(parent, _Cell):
        container = parent._tc
    else:
        raise ValueError("something's not right")

    for element in container.iterchildren():
        if isinstance(element, CT_P):
            yield docx.text.paragraph.Paragraph(element, parent)
            continue
        if not isinstance(element, CT_Tbl):
            # Neither a paragraph nor a table: nothing to extract.
            continue
        for row in Table(element, parent).rows:
            for cell in iter_visible_row_cells(row):
                yield from extract_text_from_docx(cell)


def extract_text_from_pptx(prs) -> List[SlideNode]:
    """
    Build one SlideNode per slide of a presentation.

    Each non-empty text found on a slide (a shape with a text frame, or a table cell)
    becomes a TextNode child of the slide node. Slides without text get no children.
    """
    slide_nodes = []

    for slide_number, slide in enumerate(prs.slides, start=1):
        node = SlideNode(
            node_id=f"slide-{slide_number}",
            children=[],
            label="slide",
            level=1,
            page_provenance=[slide_number],
        )

        # Gather the texts of the slide in shape order.
        texts = []
        for shape in slide.shapes:
            if shape.has_text_frame and shape.text:
                texts.append(shape.text)
            elif shape.has_table:
                texts.extend(
                    cell.text_frame.text
                    for row in shape.table.rows
                    for cell in row.cells
                    if cell.text_frame.text
                )

        if texts:
            node.children = [
                TextNode(
                    node_id=f"slide-{slide_number}-text-{position}",
                    content=text,
                    level=2,
                    children=[],
                    label="text",
                    page_provenance=[slide_number],
                )
                for position, text in enumerate(texts, start=1)
            ]

        slide_nodes.append(node)

    return slide_nodes
