import json
import logging
from typing import List, Optional, Sequence

from langchain_core.documents import Document
from pydantic import BaseModel
from requests import HTTPError

from dataiku.langchain.content_part_types import ImageRetrieval, ImageRefPart, get_image_content, InlineImagePart, CaptionedImageRefPart, \
    InlineCaptionedImagePart, TextPart
from dataiku.langchain.metadata_generator import DKU_MULTIMODAL_CONTENT


def from_doc(doc: Document) -> Optional['MultimodalContent']:
    """Decode the multimodal content item stored in a document's metadata.

    The metadata entry under ``DKU_MULTIMODAL_CONTENT`` must be a JSON-encoded
    object whose ``type`` field selects the model to build: ``"text"``,
    ``"images"`` or ``"captioned_images"``.

    :param doc: LangChain document whose ``metadata`` may carry the entry.
    :returns: the decoded content item, or ``None`` when the entry is absent
        (the document is treated as non-multimodal).
    :raises ValueError: if the entry cannot be decoded as JSON, does not decode
        to a JSON object, or has a missing/unknown ``type``. Errors raised while
        constructing the selected model propagate unchanged.
    """
    raw_content = doc.metadata.get(DKU_MULTIMODAL_CONTENT)

    if raw_content is None:
        logging.info(
            "%s Not found in metadata. Available fields are: %s. Assuming we aren't in the multimodal case.",
            DKU_MULTIMODAL_CONTENT, doc.metadata.keys())
        return None

    try:
        multimodal_content = json.loads(raw_content)
    except Exception as e:
        # Any decoding failure (bad JSON, non str/bytes value, ...) is surfaced as a ValueError,
        # with the original exception chained for debugging.
        raise ValueError(f"Metadata {raw_content} is not a valid json: {e}") from e

    if not isinstance(multimodal_content, dict):
        # A JSON list or scalar has no "type" field to dispatch on; previously this
        # crashed with an AttributeError on `.get`.
        raise ValueError(f"Metadata {multimodal_content} doesn't have a valid type")

    # Built here (not at module level) because the model classes are defined later in the file.
    content_classes = {
        "text": TextMultimodalContentItem,
        "images": ImageMultimodalContentItem,
        "captioned_images": CaptionedImageMultimodalContentItem,
    }
    _type = multimodal_content.get("type")
    # Guard the lookup: an unhashable "type" (e.g. a list) would make dict.get raise TypeError.
    content_class = content_classes.get(_type) if isinstance(_type, str) else None
    if content_class is None:
        raise ValueError(f"Metadata {multimodal_content} doesn't have a valid type")
    return content_class(**multimodal_content)



class MultimodalContent(BaseModel):
    """Common base of the content items decoded by ``from_doc``.

    Concrete subclasses convert their content into a sequence of parts through
    :meth:`get_parts`.
    """

    def get_parts(self, _index, _1, _2) -> Sequence:
        """Return the parts for this content item; subclasses must override.

        In the subclasses of this file, the three positional arguments are the
        part index, the image retrieval mode and the full folder id.

        :raises NotImplementedError: always, on this base class.
        """
        raise NotImplementedError


class TextMultimodalContentItem(MultimodalContent):
    """Content item holding a single block of text and no images."""
    content: str  # the text itself
    type: str = "text"

    @property
    def images(self):
        """Text items carry no images; a new empty list on every access."""
        return list()

    @property
    def text(self):
        """The raw text content."""
        return self.content

    def get_parts(self, index: Optional[int], _1: ImageRetrieval, _2: str) -> Sequence[TextPart]:
        """Wrap the text in one ``TextPart``; the retrieval mode and folder id are unused."""
        only_part = TextPart(index, self.content)
        return [only_part]

class ImageMultimodalContentItem(MultimodalContent):
    """Content item made of one or more images located in a folder.

    Each entry of ``content`` is an image path that, together with the folder id
    given to :meth:`get_parts`, identifies one image.
    """
    content: List[str]  # image paths, resolved against the folder id passed to get_parts
    type: str = "images"

    @property
    def images(self) -> List[str]:
        """The image paths carried by this item."""
        return self.content

    @property
    def text(self) -> str:
        """Plain image items have no text."""
        return ""

    def _new_ref_part(self, index, full_folder_id, image_path):
        """Build a part that references the image by folder id and path (content is not fetched here)."""
        return ImageRefPart(index, full_folder_id, image_path)

    def _new_inline_part(self, index, image_content, mime_type):
        """Build a part embedding image data already fetched by ``get_image_content``."""
        return InlineImagePart(index, image_content, mime_type)

    def get_parts(self, index: Optional[int], image_retrieval: ImageRetrieval, full_folder_id: str) -> Sequence:
        """Turn every image path into a content part, preserving ``content`` order.

        With ``ImageRetrieval.IMAGE_REF`` each image becomes a reference part.
        With any other mode each image is fetched with ``get_image_content`` and
        inlined. An image whose retrieval (or inline-part creation) raises
        ``HTTPError`` or ``ValueError`` is logged and skipped, so the result may
        contain fewer parts than ``content`` has paths.

        :param index: index forwarded to every created part (may be None).
        :param image_retrieval: how images should be exposed in the parts.
        :param full_folder_id: id of the folder holding the images.
        :returns: list of created parts.
        """
        # Loop-invariant: the retrieval mode does not change between images.
        as_reference = image_retrieval == ImageRetrieval.IMAGE_REF
        parts = []
        for image_path in self.content:
            if as_reference:
                parts.append(self._new_ref_part(index, full_folder_id, image_path))
                continue
            try:
                image_content, mime_type = get_image_content(image_path, full_folder_id)
                parts.append(self._new_inline_part(index, image_content, mime_type))
            except (HTTPError, ValueError) as e:
                # we don't want to fail during augmentation to allow the base model to answer even without the augmentation parts.
                logging.warning("Error when retrieving file %s in folder %s : %s - skipping",
                                image_path, full_folder_id, e)
        return parts

class CaptionedImageMultimodalContentItem(ImageMultimodalContentItem):
    """Image content item with one caption shared by all its images.

    The caption is exposed as the item's ``text`` and attached to every part
    built for its images.
    """
    type: str = "captioned_images"
    caption: str = ""  # shared caption; empty when none was provided

    @property
    def text(self) -> str:
        """The caption shared by the images."""
        return self.caption

    def _new_ref_part(self, index, full_folder_id, image_path):
        """Build a captioned reference part (folder id + path) for one image."""
        return CaptionedImageRefPart(self.caption, index, full_folder_id, image_path)

    def _new_inline_part(self, index, image_content, mime_type):
        """Build a captioned part embedding image data already fetched by ``get_image_content``."""
        return InlineCaptionedImagePart(self.caption, index, image_content, mime_type)


