import json
import logging
from typing import List, Optional, Sequence, TYPE_CHECKING

from langchain_core.documents import Document
from pydantic import BaseModel
from requests import HTTPError

from dataiku.langchain.base_rag_handler import AugmentationFallbackStrategy
from dataiku.langchain.content_part_types import ImageRetrieval, ImageRefPart, get_image_content, InlineImagePart, CaptionedImageRefPart, \
    InlineCaptionedImagePart, TextPart, MultipartContent
from dataiku.langchain.metadata_generator import DKU_MULTIMODAL_CONTENT

if TYPE_CHECKING:
    from dataiku.llm.types import TrustedObject

def from_doc(doc: Document) -> Optional['MultimodalContent']:
    """Build the multimodal content item described in a document's metadata.

    The metadata entry stored under ``DKU_MULTIMODAL_CONTENT`` is parsed as JSON
    and dispatched on its ``"type"`` field (``"text"``, ``"images"`` or
    ``"captioned_images"``).

    :param doc: document whose ``metadata`` may hold a ``DKU_MULTIMODAL_CONTENT`` entry.
    :return: the matching content item, or ``None`` when the metadata entry is absent.
    :raises ValueError: if the entry cannot be parsed as JSON, is not a JSON object,
        or carries an unknown ``"type"``.
    """
    raw_content = doc.metadata.get(DKU_MULTIMODAL_CONTENT)

    if raw_content is None:
        # `DKU_MULTIMODAL_CONTENT` not found in metadata. Assuming we aren't in the multimodal case.
        return None

    try:
        multimodal_content = json.loads(raw_content)
    except Exception as e:
        # Single message argument (so str(exc) stays readable) and explicit chaining
        # so the original parsing error remains available as __cause__.
        raise ValueError(f"Metadata {raw_content} is not a valid json: {e}") from e

    if not isinstance(multimodal_content, dict):
        # Valid JSON that is not an object (list, string, number, ...) has no "type"
        # to dispatch on; report it like any other unsupported payload.
        raise ValueError(f"Metadata {multimodal_content} doesn't have a valid type")

    _type = multimodal_content.get("type")
    if _type == "text":
        return TextMultimodalContentItem(**multimodal_content)
    if _type == "images":
        return ImageMultimodalContentItem(**multimodal_content)
    if _type == "captioned_images":
        return CaptionedImageMultimodalContentItem(**multimodal_content)
    raise ValueError(f"Metadata {multimodal_content} doesn't have a valid type")


class MultimodalContent(BaseModel):
    # Common base for the multimodal content items defined in this module.
    # Documented with comments rather than a class docstring so the pydantic
    # model's own description is left untouched.

    def get_parts(self, _index, _1, _2, _3 = None) -> Sequence[MultipartContent]:
        """Return the content parts of this item.

        The base implementation is only a placeholder and always raises
        ``NotImplementedError``; subclasses provide the real behaviour.
        """
        raise NotImplementedError


class TextMultimodalContentItem(MultimodalContent):
    # Text-only item: `content` is the text itself and `type` defaults to "text".
    # Documented with comments rather than a class docstring so the pydantic
    # model's own description is left untouched.
    content: str
    type: str = "text"

    @property
    def images(self):
        """Images attached to this item: always an empty list for text."""
        return list()

    @property
    def text(self):
        """Text of this item, i.e. ``content``."""
        return self.content

    def get_parts(self, index: Optional[int], _1: ImageRetrieval, _2: str, _3: Optional["TrustedObject"] = None) -> Sequence[TextPart]:
        """Return a one-element list holding a ``TextPart`` built from ``index`` and ``content``.

        The ``_1``, ``_2`` and ``_3`` arguments are accepted but not used here.
        """
        text_part = TextPart(index, self.content)
        return [text_part]


class ImageMultimodalContentItem(MultimodalContent):
    # Image item: `content` is a list of image paths and `type` defaults to "images".
    # Documented with comments rather than a class docstring so the pydantic
    # model's own description is left untouched.
    content: List[str]
    type: str = "images"

    @property
    def images(self):
        """Image paths of this item, i.e. ``content``."""
        return self.content

    @property
    def text(self):
        """Text of this item: always the empty string for plain images."""
        return ""

    def _new_ref_part(self, index, full_folder_id, image_path, error=None):
        """Build the by-reference part for one image (overridable by subclasses)."""
        return ImageRefPart(index, full_folder_id, image_path, error)

    def _new_inline_part(self, index, image_bytes, mime_type, error=None):
        """Build the inline part for one image (overridable by subclasses)."""
        return InlineImagePart(index, image_bytes, mime_type, error)

    def get_parts(self, index: Optional[int], image_retrieval: ImageRetrieval, full_folder_id: str, trusted_object: Optional["TrustedObject"] = None) -> Sequence[MultipartContent]:
        """Return one part per image path in ``content``, in order.

        Each image is fetched with ``get_image_content`` regardless of the retrieval
        mode. With ``ImageRetrieval.IMAGE_REF`` a reference part is produced, otherwise
        an inline part carrying the fetched bytes and MIME type. An ``HTTPError`` or
        ``ValueError`` raised while handling an image does not abort the loop: it is
        recorded as the ``error`` of that image's part instead.
        """
        by_reference = image_retrieval == ImageRetrieval.IMAGE_REF

        def part_for(path):
            # Part construction stays inside the try so that a ValueError raised
            # while building the part is reported the same way as a fetch failure.
            try:
                payload, mime = get_image_content(path, full_folder_id, trusted_object)
                if by_reference:
                    return self._new_ref_part(index, full_folder_id, path)
                return self._new_inline_part(index, payload, mime)
            except (HTTPError, ValueError) as failure:
                if by_reference:
                    return self._new_ref_part(index, full_folder_id, path, str(failure))
                return self._new_inline_part(index, None, None, str(failure))

        return [part_for(image_path) for image_path in self.content]


class CaptionedImageMultimodalContentItem(ImageMultimodalContentItem):
    # Image item with an accompanying caption; `type` defaults to "captioned_images"
    # and `caption` to the empty string.
    # Documented with comments rather than a class docstring so the pydantic
    # model's own description is left untouched.
    type: str = "captioned_images"
    caption: str = ""

    @property
    def text(self):
        """Text of this item, i.e. its ``caption``."""
        return self.caption

    def _new_ref_part(self, index, full_folder_id, image_path, error=None):
        """Build the by-reference part for one image, with the caption attached."""
        return CaptionedImageRefPart(self.caption, index, full_folder_id, image_path, error)

    def _new_inline_part(self, index, image_bytes, mime_type, error=None):
        """Build the inline part for one image, with the caption attached.

        Parameter names mirror the parent's ``_new_inline_part``: the values passed
        here are the image bytes and MIME type, not a folder id and a path.
        """
        return InlineCaptionedImagePart(self.caption, index, image_bytes, mime_type, error)


