import base64
import mimetypes
import tempfile
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional, Union

import puremagic
import pypdfium2 as pdfium  # type: ignore
from common.backend.constants import (
    BLACKLISTED_MIME_TYPES,
    CHAIN_MIME_TYPES,
    DOCUMENT_EXTENSIONS,
    IMAGE_EXTENSIONS,
    PROMPT_SEPARATOR_LENGTH,
)
from common.backend.models.base import MediaSummary, UploadFileError
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.file_extraction.pptx.pptx_utils import load_pptx_slides_from_summary
from common.backend.utils.picture_utils import b64encode_image_from_path
from common.backend.utils.upload_utils import get_checked_config
from common.llm_assist.logging import logger
from dataikuapi.dss.llm import DSSLLMCompletionQueryMultipartMessage
from dataikuapi.utils import DataikuException
from werkzeug.datastructures import FileStorage

# Module-level reference to the webapp configuration, bound once at import time.
# NOTE(review): the functions visible in this file read
# dataiku_api.webapp_config directly rather than this alias — confirm it is
# still used elsewhere before relying on it.
webapp_config: Dict[str, str] = dataiku_api.webapp_config


def load_pdf_images_from_bytes(pdf_bytes: bytes) -> List[str]:
    """Render every page of an in-memory PDF to a base64-encoded PNG.

    The render scale is read from the ``pdf_as_image_resize_ratio`` webapp
    setting and defaults to 1 when the setting is absent.

    Args:
        pdf_bytes: Raw content of the PDF document.

    Returns:
        One base64 string per page, in page order. The list is empty when
        the document yields no pages.

    Raises:
        Exception: If an ``IOError`` is raised while parsing or rendering;
            the original error is attached as the cause.
    """
    ratio = dataiku_api.webapp_config.get("pdf_as_image_resize_ratio", 1)
    b64_images: List[str] = []
    try:
        pdf_document = pdfium.PdfDocument(pdf_bytes)
        for page in pdf_document:
            pil_image = page.render(scale=ratio).to_pil()
            with BytesIO() as buffered:
                pil_image.save(buffered, format="PNG")
                b64_images.append(
                    base64.b64encode(buffered.getvalue()).decode("utf-8")
                )
        # No explicit `del` of the loop variables: they are unbound when the
        # document has no pages, and they go out of scope on return anyway.
        return b64_images
    except IOError as e:
        logger.exception(f"Unable to parse document to image: {e}")
        raise Exception(f"Unable to parse document to image: {e}") from e


def load_pdf_images_from_file(json_file_path: str) -> List[str]:
    """Download a PDF through the folder handle and render each page to PNG.

    The render scale is read from the ``pdf_as_image_resize_ratio`` webapp
    setting and defaults to 1 when the setting is absent.

    Args:
        json_file_path: Path passed to
            ``dataiku_api.folder_handle.get_download_stream``. Despite the
            parameter name, the downloaded bytes are parsed as a PDF.

    Returns:
        One base64 string per page, in page order. The list is empty when
        the document yields no pages.

    Raises:
        Exception: If an ``IOError`` is raised while downloading, parsing or
            rendering; the original error is attached as the cause.
    """
    ratio = dataiku_api.webapp_config.get("pdf_as_image_resize_ratio", 1)
    b64_images: List[str] = []
    try:
        with dataiku_api.folder_handle.get_download_stream(json_file_path) as f:
            pdf_bytes = f.read()
            pdf_document = pdfium.PdfDocument(pdf_bytes)
        for page in pdf_document:
            pil_image = page.render(scale=ratio).to_pil()
            with BytesIO() as buffered:
                pil_image.save(buffered, format="PNG")
                b64_images.append(
                    base64.b64encode(buffered.getvalue()).decode("utf-8")
                )
        # No explicit `del` of the loop variables: they are unbound when the
        # document has no pages, and they go out of scope on return anyway.
        return b64_images
    except IOError as e:
        logger.exception(f"Unable to parse document to image: {e}")
        raise Exception(f"Unable to parse document to image: {e}") from e


def file_path_to_image_parts(
    summary: MediaSummary, msg: DSSLLMCompletionQueryMultipartMessage
) -> DSSLLMCompletionQueryMultipartMessage:
    """Append the file referenced by ``summary`` to ``msg`` as inline images.

    Documents (extension of ``file_path`` in ``DOCUMENT_EXTENSIONS``) are
    converted to one image per page or slide, each preceded by a text part
    naming the page. Images (extension in ``IMAGE_EXTENSIONS``) are inlined
    as a single image preceded by a text part naming the file.

    Args:
        summary: Must provide ``file_path`` and ``original_file_name``.
        msg: Multipart message the parts are appended to.

    Returns:
        The same ``msg`` object, after the parts were appended.

    Raises:
        Exception: When a required key is missing, the extension is not
            supported, or any error occurs while loading the file.
    """
    file_path: Union[str, None] = summary.get("file_path")
    original_file_name: Union[str, None] = summary.get("original_file_name")
    if file_path is None or original_file_name is None:
        raise Exception("No file path or original file name provided")

    try:
        # The stored path's extension selects the branch.
        has_suffix = "." in file_path
        path_suffix = file_path.rsplit(".", 1)[1].lower() if has_suffix else None

        if has_suffix and path_suffix in DOCUMENT_EXTENSIONS:
            # The original file name's extension selects the loader.
            extension = "" if not original_file_name else Path(original_file_name).suffix
            logger.debug(f" extension: {extension}")

            b64_images: List[str]
            if "pptx" in extension:
                page_or_slide = "Slide"
                b64_images = load_pptx_slides_from_summary(summary)
            else:
                page_or_slide = "Page"
                b64_images = load_pdf_images_from_file(file_path)

            n_page: int = len(b64_images)
            for number, img_b64 in enumerate(b64_images, start=1):
                caption = f"{original_file_name}: {page_or_slide} {number} of {n_page}"
                msg.with_text(caption).with_inline_image(img_b64)
            logger.debug(f"Document {original_file_name} converted to image and inlined in completion")
        elif has_suffix and path_suffix in IMAGE_EXTENSIONS:
            encoded_image = b64encode_image_from_path(file_path)
            msg.with_text(f"Image: {original_file_name}").with_inline_image(encoded_image)
            logger.debug(f"Image {original_file_name} inlined in completion")
        else:
            raise Exception(f"Unknown file extension type for file {file_path}")
    except DataikuException as e:
        logger.exception(f"Dataiku API Error: {e}")
        raise Exception(f"Dataiku API Error: {e}")
    except FileNotFoundError:
        logger.exception("File not found in the managed folder.")
        raise Exception("File not found in the managed folder.")
    except IOError as e:
        logger.exception(f"I/O Error: {e}")
        raise Exception(f"I/O Error: {e}")
    except Exception as e:
        logger.exception(f"An unexpected error occurred: {e}")
        raise Exception(f"An unexpected error occurred: {e}")

    logger.debug(f"File {original_file_name} converted to image")
    return msg


def file_path_text_parts(
    summary: MediaSummary, msg: DSSLLMCompletionQueryMultipartMessage
) -> DSSLLMCompletionQueryMultipartMessage:
    """Append the extracted text of a document to ``msg`` as one text part.

    The extraction metadata is read as JSON from ``summary["metadata_path"]``
    through the folder handle. The text part wraps the content between
    START/END separators and includes the topics when the metadata has any.

    Args:
        summary: Must provide ``metadata_path``; ``original_file_name`` is
            used for labelling.
        msg: Multipart message the text part is appended to.

    Returns:
        The same ``msg`` object, after the text part was appended.

    Raises:
        Exception: When ``metadata_path`` is missing or any error occurs
            while reading the metadata.
    """
    original_file_name: Union[str, None] = summary.get("original_file_name")

    managed_folder = dataiku_api.folder_handle
    try:
        metadata_path = summary.get("metadata_path")  # type: ignore
        if not metadata_path:
            logger.error("metadata_path is not provided for document")
            raise Exception("metadata_path is not provided for document")

        extract_summary = managed_folder.read_json(metadata_path)
        logger.debug(f"Extracting text from {metadata_path}")
        extracted_text = extract_summary.get("full_extracted_text", "No text extracted")

        # The prompt literal is whitespace-sensitive: keep it untouched.
        msg.with_text(f"""{'-'*PROMPT_SEPARATOR_LENGTH} START OF DOCUMENT: {original_file_name} {'-'*PROMPT_SEPARATOR_LENGTH}
        Document Name: {original_file_name}
        {'Main Topics: '+str(extract_summary.get("topics"))if extract_summary.get("topics") else ''}
        Full Extracted Text:
        {extracted_text}
        {'-'*PROMPT_SEPARATOR_LENGTH} END OF DOCUMENT: {original_file_name}{'-'*PROMPT_SEPARATOR_LENGTH}
        """)

        # Long texts are logged as head + tail only to keep the debug log readable.
        if len(extracted_text) <= 4000:
            logger.debug(f"""
                Text extracted from {original_file_name}:
                Extracted Text: {extracted_text}
            """)
        else:
            logger.debug(f"""Text extracted from {original_file_name}:
                Extracted Text (TRUNCATED):{extracted_text[:2000]}
                ...
                ...
                ...
                {extracted_text[-2000:]}""")
    except DataikuException as e:
        logger.exception(f"Dataiku API Error: {e}")
        raise Exception(f"Dataiku API Error: {e}")
    except FileNotFoundError:
        logger.exception("File not found in the managed folder.")
        raise Exception("File not found in the managed folder.")
    except IOError as e:
        logger.exception(f"I/O Error: {e}")
        raise Exception(f"I/O Error: {e}")
    except Exception as e:
        logger.exception(f"An unexpected error occurred: {e}")
        raise Exception(f"An unexpected error occurred: {e}")
    else:
        return msg


def allowed_file(file: FileStorage, multi_modal: bool) -> str:
    """Validate an uploaded file's extension and MIME type.

    Document extensions and MIME types are always accepted; image ones are
    accepted only when ``multi_modal`` is true. A MIME type listed in
    ``BLACKLISTED_MIME_TYPES`` is always rejected.

    Args:
        file: The uploaded file.
        multi_modal: Whether image uploads are allowed in addition to
            documents.

    Returns:
        The lower-cased file extension, without the leading dot.

    Raises:
        Exception: With ``UploadFileError.INVALID_FILE_TYPE.value`` when the
            MIME type cannot be determined, the filename has no extension,
            or the extension / MIME type is not permitted.
    """
    filename = str(file.filename)
    mimetype: Optional[str] = file.mimetype
    logger.debug(f"File name: {filename}, MIME type: {mimetype}")

    # When called from ask endpoint mime_type can be application/octet-stream.
    # In this case we detect the real MIME type for security.
    # Exact comparison: a substring test (`in`) would also match partial
    # values such as "application" or "octet".
    if mimetype == "application/octet-stream":
        mimetype = get_file_mime_type(filename)
        logger.info(f"Detected actual MIME type for '{filename}': {mimetype}")

    if not mimetype:
        raise Exception(UploadFileError.INVALID_FILE_TYPE.value)

    allowed_extensions = DOCUMENT_EXTENSIONS
    allowed_mimetypes = CHAIN_MIME_TYPES["DOCUMENT"]
    if multi_modal:
        allowed_extensions = allowed_extensions.union(IMAGE_EXTENSIONS)
        allowed_mimetypes = allowed_mimetypes.union(CHAIN_MIME_TYPES["IMAGE"])
    if "." not in filename:
        raise Exception(UploadFileError.INVALID_FILE_TYPE.value)
    extension = filename.rsplit(".", 1)[1].lower()

    extension_permitted = extension in allowed_extensions
    mimetype_permitted = mimetype in allowed_mimetypes
    not_blacklisted_mimetype = mimetype not in BLACKLISTED_MIME_TYPES
    if not all((extension_permitted, mimetype_permitted, not_blacklisted_mimetype)):
        logger.debug(f"extension_permitted: {extension_permitted}, mimetype_permitted: {mimetype_permitted}, not_blacklisted_mimetype: {not_blacklisted_mimetype}")
        raise Exception(UploadFileError.INVALID_FILE_TYPE.value)
    return extension


def get_file_data(file: FileStorage) -> bytes:
    """Read the whole upload and enforce the configured size limit.

    The limit comes from the ``max_upload_size_mb`` setting (in MiB). On
    success the stream is rewound so the file can be read again.

    Raises:
        Exception: With ``UploadFileError.NO_SELECTED_FILE.value`` for an
            empty upload, or ``UploadFileError.FILE_TOO_LARGE.value`` when
            the content exceeds the limit.
    """
    size_limit_bytes = int(get_checked_config("max_upload_size_mb")) * 1024 * 1024

    payload: bytes = file.read()
    payload_size = len(payload)
    if payload_size == 0:
        raise Exception(UploadFileError.NO_SELECTED_FILE.value)
    if payload_size > size_limit_bytes:
        raise Exception(UploadFileError.FILE_TOO_LARGE.value)

    file.seek(0)
    return payload


def delete_files(file_paths: List[str]) -> None:
    """Delete each given path through the folder handle, best effort.

    A failure on one path is logged and does not stop the remaining
    deletions; nothing is raised to the caller.
    """
    managed_folder = dataiku_api.folder_handle
    for target in file_paths:
        try:
            managed_folder.delete_path(target)
        except Exception as e:
            logger.exception(f"Error occurred while deleting file: {e}")
            continue


def get_file_mime_type(file_name: str) -> Optional[str]:
    """Guess a file's MIME type from its name alone.

    ``.jfif`` and ``.md`` are registered before guessing because they are
    not known by default. Returns ``None`` (after logging a warning) when
    the type cannot be determined.
    """
    extra_types = (
        ("image/jpeg", ".jfif"),
        ("text/markdown", ".md"),
    )
    for mime, suffix in extra_types:
        mimetypes.add_type(mime, suffix)

    guessed = mimetypes.guess_type(file_name)[0]
    if guessed is not None:
        return guessed

    logger.warn(f"MIME type could not be determined for the file {file_name}")
    return None


def normalize_extension(extension: str) -> str:
    """Return ``extension`` lower-cased and prefixed with a dot.

    An empty extension is returned unchanged (``""``) instead of becoming
    ``"."``, so callers can still detect "no extension" with a truthiness
    check.

    Note: Sometimes puremagic can retrieve the extension without the dot
        (e.g.: 'pptx' instead of '.pptx')
    """
    if not extension:
        return ""
    lowered = extension.lower()
    return lowered if lowered.startswith(".") else "." + lowered


def file_extensions_are_equivalent(ext_a: str, ext_b: str) -> bool:
    """Tell whether two file extensions belong to the same category.

    Extensions are equivalent when they map to the same MIME type (for
    example ``.jpg`` and ``.jfif``). When a MIME type cannot be determined,
    the extensions are only equivalent if they are identical
    (case-insensitive), so two unrelated unknown extensions never match.

    Args:
        ext_a: First extension, including the leading dot.
        ext_b: Second extension, including the leading dot.

    Returns:
        True when the extensions are equivalent; False otherwise, including
        when either extension is empty.
    """
    # Either side being empty makes the comparison meaningless.
    if not ext_a or not ext_b:
        logger.debug(f"Cannot compare extensions : {ext_a} and {ext_b}")
        return False
    mime_type_a = get_file_mime_type(f"file{ext_a}")
    mime_type_b = get_file_mime_type(f"file{ext_b}")
    logger.debug(f"{ext_a} have MIME type {mime_type_a} and {ext_b} have MIME type {mime_type_b}")
    if mime_type_a is None or mime_type_b is None:
        # `None == None` would wrongly report unknown extensions as
        # equivalent; fall back to comparing the extensions themselves.
        return ext_a.lower() == ext_b.lower()
    return mime_type_a == mime_type_b


def get_file_format(filename: str) -> str:
    """Return the final suffix of ``filename`` (dot included), lower-cased."""
    suffix = Path(filename).suffix
    return suffix.lower()


def file_is_likely_plain_text(
    file: FileStorage,
    allow_partial_utf8: bool = True,
    non_printable_threshold: float = 0.3,
    max_check_size: int = 2048
) -> bool:
    """
    Heuristically decide whether an uploaded file looks like safe plain text.

    The whole content is read (and the stream rewound) to attempt a UTF-8
    decode; the share of non-printable bytes is then measured on the first
    ``max_check_size`` bytes only.

    Args:
        file (FileStorage): The uploaded file to inspect.
        allow_partial_utf8 (bool):
            If True, a UTF-8 decoding failure is tolerated (logged as a warning).
            If False, a UTF-8 decoding failure rejects the file.
        non_printable_threshold (float): Highest accepted ratio (0.0–1.0) of
            non-printable bytes in the sample, e.g. 0.3 tolerates up to 30%.
        max_check_size (int): Size in bytes of the sample taken from the file start.

    Returns:
        bool: True if the file is probably safe text, False otherwise.
    """
    if not file:
        logger.error("File input is empty or invalid.")
        return False

    raw = file.read()
    display_name = file.filename or "unknown_file"
    file.seek(0)

    # Step 1: the content must decode as UTF-8, fully or (if allowed) partially.
    utf8_status = 'ok'
    try:
        raw.decode("utf-8")
    except UnicodeDecodeError:
        if not allow_partial_utf8:
            logger.error(f"File '{display_name}' failed UTF-8 decoding and partial decoding is not allowed.")
            return False
        raw.decode("utf-8", errors="ignore")
        utf8_status = "partial"
        logger.warn(f"File '{display_name}' contains characters that could not be fully decoded to UTF-8. Proceeding with partial decoding.")

    # Step 2: measure the share of non-printable bytes in the leading sample.
    # A byte counts as printable when it is a whitespace control (9-13: tab,
    # line feed, vertical tab, form feed, carriage return) or in the visible
    # ASCII range (32-126). Everything else is counted as non-printable:
    # other control codes, DEL (127) and every byte above 127, even when it
    # is part of a valid UTF-8 sequence — this check is deliberately strict.
    sample = raw[:max_check_size]
    suspicious = [
        byte for byte in sample
        if not (9 <= byte <= 13 or 32 <= byte <= 126)
    ]
    sample_size = max(len(sample), 1)  # avoid dividing by zero on empty files
    ratio = len(suspicious) / sample_size

    if ratio > non_printable_threshold:
        logger.error(f"File '{display_name}' exceeds non-printable character threshold. Ratio: {ratio:.2f} (Allowed: {non_printable_threshold:.2f}).")
        return False

    logger.debug(f"File '{display_name}' assessed as a safe plain text file. UTF-8 status: {utf8_status}, Non-printable ratio: {ratio:.2f}.")
    return True


def is_file_coherent(file: FileStorage) -> bool:
    """
    Check if the file content is aligned with its extension.

    The first 8192 bytes of the file are written to a temporary file and
    handed to puremagic, which identifies formats from their "magic bytes"
    (unique binary signatures located at the beginning of the file, such as
    0x89504E47 for PNG or 0x25504446 for PDF). The detected extension is
    then compared with the extension of the file name.

    If puremagic cannot identify the content and the file name has a plain
    text extension (md, txt, ...), the content is checked with
    file_is_likely_plain_text instead, since such files have no header.

    :param file: Werkzeug FileStorage object
    :return: True if file content matches the extension, False otherwise
    """
    ALLOWED_EQUIVALENTS = {
        ".docx": [".docx", ".pptx", ".xlsx"],
        ".pptx": [".pptx", ".docx", ".xlsx"],
        ".xlsx": [".xlsx", ".docx", ".pptx"],
    }

    PLAIN_TEXT_EXTENSIONS = [".md", ".txt", ".js", ".py", ".html"]

    if file is None or file.filename is None:
        logger.warn("File is None or has no filename.")
        return False

    # Bound before the try block so the except handler can always read it,
    # even when reading the stream is what failed.
    actual_extension = ""

    try:
        actual_extension = normalize_extension(Path(file.filename).suffix)
        content = file.read(8192)
        file.stream.seek(0)  # reset stream pointer

        with tempfile.NamedTemporaryFile(delete=True) as tmp:
            tmp.write(content)
            tmp.flush()
            raw_detected_extension = puremagic.from_file(tmp.name)

        # Test the raw value: once normalized, an empty extension may no
        # longer be falsy and this guard would never trigger.
        if not raw_detected_extension:
            logger.warn(f"No extension has been detected for {file.filename}")
            return actual_extension in PLAIN_TEXT_EXTENSIONS and file_is_likely_plain_text(file)

        detected_extension = normalize_extension(raw_detected_extension)

        # A simple string comparison fails for aliases. This checks for semantic equivalence
        # by comparing the extensions MIME types (e.g., .jpg and .jfif are both 'image/jpeg').
        if file_extensions_are_equivalent(detected_extension, actual_extension):
            logger.debug(f"Coherent file {file.filename} has been detected with extension {detected_extension}")
            return True

        # Some MS Office formats (docx, pptx, xlsx) share the same ZIP header — allow them to be interchangeable here.
        if actual_extension in ALLOWED_EQUIVALENTS.get(detected_extension, [detected_extension]):
            logger.debug(
                f"Allowed equivalent for Office file: detected {detected_extension}, "
                f"but actual extension is {actual_extension} — accepted as valid."
            )
            return True

        logger.warn(
            f"Incoherent file {file.filename}: extension {actual_extension} "
            f"doesn't match detected {detected_extension}"
        )

    except Exception as e:
        error_message = str(e)
        if error_message == 'Could not identify file':
            logger.warn("File could not be identified by puremagic.")
            if actual_extension in PLAIN_TEXT_EXTENSIONS:
                # Textual files don't have headers to check, so we verify
                # there is nothing unexpected in their content instead.
                return file_is_likely_plain_text(file)
        logger.warn(f"Something went wrong while checking coherence of {file.filename} : {e}")
    return False