import base64
import enum
import logging
import os

from requests import HTTPError, utils

from abc import ABC, abstractmethod
from typing import Optional, List, Dict, Sequence

from dataiku.langchain.multimodal_content import MultimodalContent


class ImageRetrieval(enum.Enum):
    """How images referenced by retrieved content are handed to the model.

    IMAGE_REF: pass a reference (folder id + path) without downloading the image.
    IMAGE_INLINE: download the image and embed its bytes (base64) in the part.
    """
    # Note: no trailing comma here -- `X = "X",` would make the value the tuple ("X",).
    IMAGE_REF = "IMAGE_REF"
    IMAGE_INLINE = "IMAGE_INLINE"


class MultipartContent(ABC):
    """Abstract base for a single part of a multipart message.

    Every part records its ``index`` and a ``type`` discriminator string. The base leaves
    ``type`` as None; concrete subclasses overwrite it and implement ``to_text``.
    """

    def __init__(self, index: Optional[int]):
        # Deliberately no default value: callers always pass an index, but that index may be
        # None (see `MetadataHandler.get_multipart_content`).
        self.index = index
        self.type: Optional[str] = None

    @abstractmethod
    def to_text(self) -> str:
        """Return a textual rendering of this part (subclasses define what that means)."""


class TextPart(MultipartContent):
    """A part made of plain text.

    Its ``type`` is ``"TEXT"`` and its textual rendering is the stored text itself.
    """

    def __init__(self, index: Optional[int], text: str):
        super(TextPart, self).__init__(index)
        self.type = "TEXT"
        self.text = text

    def to_text(self) -> str:
        """Hand back the stored text unchanged."""
        return self.text


class ImagePart(MultipartContent, ABC):
    """Abstract base shared by the image part flavours (inline bytes or folder reference).

    Construction is inherited unchanged from ``MultipartContent`` (a single ``index``
    argument); concrete subclasses set ``type`` and implement ``to_text``.
    """


class InlineImagePart(ImagePart):
    """Image part that embeds the picture itself.

    The raw bytes are stored base64-encoded (as a str) in ``inline_image``, alongside an
    optional ``image_mime_type``; ``type`` is ``"IMAGE_INLINE"``.
    """

    def __init__(self, index: Optional[int], image_bytes: bytes, mime_type: Optional[str]):
        super(InlineImagePart, self).__init__(index)
        self.type = "IMAGE_INLINE"
        encoded = base64.b64encode(image_bytes)
        self.inline_image = encoded.decode("utf8")
        self.image_mime_type = mime_type

    def to_text(self) -> str:
        """Return the base64 payload string."""
        return self.inline_image


class ImageRefPart(ImagePart):
    """Image part that only points at an image stored in a managed folder.

    Nothing is downloaded: the part keeps the folder identifier and the in-folder path,
    and ``type`` is ``"IMAGE_REF"``.
    """

    def __init__(self, index: Optional[int], full_folder_id: str, path: str):
        super(ImageRefPart, self).__init__(index)
        self.type = "IMAGE_REF"
        self.full_folder_id = full_folder_id
        self.path = path

    def to_text(self) -> str:
        """Return the folder id immediately followed by the path (no separator is inserted).

        NOTE(review): presumably ``path`` starts with "/" so the result reads like
        "<folder id>/<path>" -- confirm with callers.
        """
        folder_id, image_path = self.full_folder_id, self.path
        return folder_id + image_path


# Registered MIME subtypes for image extensions whose usual spelling differs from the subtype
# (e.g. "image/jpg" is not a registered media type, "image/jpeg" is).
_IMAGE_SUBTYPE_BY_EXTENSION = {"jpg": "jpeg", "tif": "tiff"}


def _guess_image_mime_type(image_path: str) -> str:
    """Return an ``image/<subtype>`` MIME type derived from the extension of ``image_path``.

    The extension is lower-cased and mapped through ``_IMAGE_SUBTYPE_BY_EXTENSION``, so
    ``photo.JPG`` yields ``image/jpeg``. A path without an extension yields ``"image/"``.
    """
    extension = os.path.splitext(image_path)[1][1:].lower()
    return "image/" + _IMAGE_SUBTYPE_BY_EXTENSION.get(extension, extension)


def _warn_skipped_image(image_path: str, full_folder_id: str, err_msg) -> None:
    """Log, without raising, that ``image_path`` could not be retrieved and will be skipped."""
    # We don't want to fail during augmentation, so the base model can still answer
    # even without the augmentation parts.
    logging.warning("Error when retrieving file {image_path} in folder {folder_id} : {err_msg} - skipping"
                    .format(image_path=image_path,
                            folder_id=full_folder_id,
                            err_msg=err_msg),
                    )


def _download_managed_folder_image(full_folder_id: str, image_path: str) -> Optional[bytes]:
    """Download ``image_path`` from the managed folder ``full_folder_id`` via the backend API.

    :param full_folder_id: ``"<projectKey>.<lookup>"``; only the first "." separates the two.
    :param image_path: path of the image inside the folder.
    :return: the raw image bytes, or None when the backend answers with a non-200 status or
        the call raises ``HTTPError`` (a warning is logged in both cases).
    :raises ValueError: if ``full_folder_id`` contains no ".".
    """
    from dataiku.core import intercom

    project_key, lookup = full_folder_id.split(".", 1)
    # Quote every query-string value (not only the path) so that special characters in the
    # project key or folder lookup cannot corrupt the query string.
    url = ("managed-folders/download-path?projectKey=" + utils.quote(project_key)
           + "&lookup=" + utils.quote(lookup)
           + "&path=" + utils.quote(image_path))
    try:
        download_response = intercom.backend_api_get_call(url, None)
    except HTTPError as e:
        _warn_skipped_image(image_path, full_folder_id, str(e))
        return None

    if download_response.status_code != 200:
        _warn_skipped_image(image_path, full_folder_id, download_response.content)
        return None
    return download_response.content


def get_image_parts(multimodal_content: MultimodalContent, index: Optional[int], image_retrieval: ImageRetrieval,
                    full_folder_id: str) -> Sequence[ImagePart]:
    """Build one image part per image path listed in ``multimodal_content.content``.

    :param multimodal_content: its ``content`` attribute is iterated as image paths inside
        the managed folder.
    :param index: index stored on every produced part (may be None).
    :param image_retrieval: ``IMAGE_REF`` produces ``ImageRefPart`` references without any
        download; any other value downloads each image and embeds it as an ``InlineImagePart``.
    :param full_folder_id: ``"<projectKey>.<lookup>"`` id of the managed folder holding the images.
    :return: the parts, in content order. Images whose download fails are skipped with a logged
        warning instead of failing the whole augmentation.
    """
    if image_retrieval == ImageRetrieval.IMAGE_REF:
        return [ImageRefPart(index, full_folder_id, image_path) for image_path in multimodal_content.content]

    parts: List[ImagePart] = []
    # Downloaded bytes keyed by path, so an image listed several times in this content is fetched
    # only once. The cache lives for this call only; failed downloads are not cached (retried).
    image_cache: Dict[str, bytes] = {}
    for image_path in multimodal_content.content:
        image_bytes = image_cache.get(image_path)
        if image_bytes is None:
            image_bytes = _download_managed_folder_image(full_folder_id, image_path)
            if image_bytes is None:
                continue
            image_cache[image_path] = image_bytes
        parts.append(InlineImagePart(index, image_bytes, _guess_image_mime_type(image_path)))
    return parts
