# Citation: 
# Berrios, W., Mittal, G., Thrush, T., Kiela, D., & Singh, A. (2023). Towards Language Models 
# That Can See: Computer Vision Through the LENS of Natural Language. arXiv preprint arXiv:2306.16410.
# The code is from the GitHub repository: https://github.com/ContextualAI/lens
# Minor modifications have been applied and comments have been added.

import os
from pathlib import Path
from typing import Any, List, Optional

import huggingface_hub
import open_clip
import torch
import torch.nn as nn
from datasets import Dataset, load_dataset
from transformers import (
    AutoProcessor,
    BlipForConditionalGeneration,
    CLIPModel,
    CLIPProcessor,
)

from .utils import (
    create_dataloader,
    create_prompt_sample,
    create_sampler,
    default_device,
)

def flatten(l):
    """
    Concatenate the elements of every inner iterable into one flat list.

    Only a single level of nesting is removed; the order of the elements is
    preserved.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

class Lens(nn.Module):
    """
    Image description model combining BLIP (captions) with, optionally, CLIP
    (tags and attributes).

    NOTE: CLIP is currently disabled in ``__init__`` (``self.clip_name`` is
    forced to ``None``), so only the BLIP captioning paths work out of the box.
    The CLIP-based methods raise a ``RuntimeError`` until CLIP is re-enabled.
    """

    # Number of leading characters dropped from every decoded caption. It
    # equals len("a picture of"), the text prompt LensProcessor passes to BLIP;
    # the decoded text is presumably prefixed with that prompt.
    _CAPTION_PROMPT_CHARS = 12

    def __init__(
        self,
        clip_name: str = "openai/clip-vit-large-patch14",
        blip_name: str = "Salesforce/blip-image-captioning-large",
        attributes_weights: str = "zw_attributes_laion_ViT_H_14_2B_descriptors_text_davinci_003_full.pt",
        tags_weights: str = "zw_tags_laion_ViT_H_14_2B_vocab_lens.pt",
        vocab_attributes: str = "llm-lens/descriptors-text-davinci-003",
        vocab_tags: str = "llm-lens/vocab_tags",
        split_attributes: str = "full",
        split_tags: str = "train",
        load_8bit: bool = False,
        device: torch.device = default_device,
        hf_transformers_home_dir: Optional[str] = os.getenv("HF_HOME"),
    ):
        """
        Initialize the Lens model.

        Parameters:
            clip_name (str, optional): CLIP model name. Currently ignored, see the NOTE in the body.
            blip_name (str, optional): BLIP model name. Pass None to skip loading BLIP.
            attributes_weights (str, optional): File name of the attribute weights inside "<hf_transformers_home_dir>/weights".
            tags_weights (str, optional): File name of the tag weights inside "<hf_transformers_home_dir>/weights".
            vocab_attributes (str, optional): Dataset name of the attribute vocabulary.
            vocab_tags (str, optional): Dataset name of the tag vocabulary.
            split_attributes (str, optional): Dataset split of the attribute vocabulary.
            split_tags (str, optional): Dataset split of the tag vocabulary.
            load_8bit (bool, optional): Whether to load BLIP in 8-bit. Defaults to False.
            device (torch.device, optional): Device to use. Defaults to default_device.
            hf_transformers_home_dir (str, optional): Directory containing the "weights" folder. Defaults to os.getenv("HF_HOME"), evaluated once at import time.
        """
        super().__init__()

        self.hf_transformers_home_dir = hf_transformers_home_dir
        self.device = device
        # NOTE: CLIP is deliberately switched off here: the ``clip_name``
        # argument is accepted for interface compatibility but not used, so the
        # branch below does not run. Assign ``clip_name`` instead of ``None``
        # to re-enable tags and attributes.
        self.clip_name = None
        self.blip_name = blip_name

        if self.clip_name is not None:
            self.clip_model = self.load_clip_model(self.clip_name, self.device)
            self.attributes_weights = self._load_weights(attributes_weights)
            self.tags_weights = self._load_weights(tags_weights)
            # forward_tags / forward_attributes index these vocabularies.
            # NOTE(review): the "prompt_descriptions" column name follows the
            # upstream LENS repository -- confirm against the dataset cards.
            self.vocab_tags = load_dataset(vocab_tags, split=split_tags)[
                "prompt_descriptions"
            ]
            self.vocab_attributes = flatten(
                load_dataset(vocab_attributes, split=split_attributes)[
                    "prompt_descriptions"
                ]
            )

        if self.blip_name is not None:
            self.blip_model = self.load_caption_model(
                self.blip_name, load_8bit, self.device
            )
            self.blip_processor = AutoProcessor.from_pretrained(self.blip_name)

    def _load_weights(self, filename: str):
        """Load a weight tensor from "<hf_transformers_home_dir>/weights/<filename>" as float32."""
        if self.hf_transformers_home_dir is None:
            raise ValueError(
                "hf_transformers_home_dir is not set (is HF_HOME defined?); "
                f"cannot locate weights/{filename}"
            )
        path = Path(self.hf_transformers_home_dir) / "weights" / filename
        # torch.load unpickles the file: only point this at trusted weights.
        return torch.load(str(path), map_location=self.device).float()

    def _clip_component(self, name: str):
        """Return the CLIP-related attribute ``name`` or raise if it was never loaded."""
        value = getattr(self, name, None)
        if value is None:
            raise RuntimeError(
                f"Lens.{name} is not loaded; tags and attributes need the CLIP "
                "model, weights and vocabularies, which are only loaded when "
                "clip_name is not None."
            )
        return value

    def _clip_image_features(self, samples: dict):
        """Encode samples["clip_image"] with the CLIP model and L2-normalise the result."""
        clip_model = self._clip_component("clip_model")
        pixel_values = samples["clip_image"].to(self.device)
        # open_clip models expose encode_image; Hugging Face CLIPModel exposes
        # get_image_features. Checking explicitly keeps real errors visible.
        if hasattr(clip_model, "encode_image"):
            features = clip_model.encode_image(pixel_values)
        else:
            features = clip_model.get_image_features(pixel_values=pixel_values)
        return features / features.norm(dim=-1, keepdim=True)

    def _rank_vocabulary(self, image_features, weights, vocab, k: int, threshold: float):
        """Return, per image, the vocabulary entries among the top-k scores that reach ``threshold``."""
        scores = (image_features @ weights).float().cpu()
        # topk raises when k exceeds the vocabulary size, so clamp it.
        top_scores, top_indexes = scores.topk(k=min(k, scores.shape[-1]), dim=-1)
        picked = []
        for row_scores, row_indexes in zip(top_scores, top_indexes):
            kept = row_indexes[row_scores >= threshold].tolist()
            picked.append([vocab[index] for index in kept])
        return picked

    def load_caption_model(
        self, model_name: str, load_8bit: bool, device: torch.device
    ):
        """
        Load the caption model.

        Parameters:
            model_name (str): The name of the caption model to load.
            load_8bit (bool): Whether to load the model in 8-bit format.
            device (torch.device): The device to place the model on (a torch.device or a device string).

        Returns:
            torch.nn.Module: The loaded caption model, in evaluation mode.
        """
        # Compare the device *type*: a torch.device object never equals the
        # string "cpu", which would select float16 on CPU.
        on_cpu = torch.device(device).type == "cpu"
        dtype = torch.float32 if on_cpu else torch.float16

        if load_8bit:
            model = BlipForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=dtype,
                device_map={"": device},
                load_in_8bit=True,
            )
            # device_map already placed the quantized model; transformers
            # rejects .to() on 8-bit models, so do not move it again.
            return model.eval()

        model = BlipForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=dtype,
        )
        return model.eval().to(device)

    def load_clip_model(
        self, model_name: str, device: torch.device
    ):
        """
        Load the CLIP model.

        Parameters:
            model_name (str): The name of the CLIP model to load. Must contain "openai" or "laion".
            device (torch.device): The device to move the model to.

        Returns:
            torch.nn.Module: The loaded CLIP model.

        Raises:
            ValueError: If the model name matches neither supported family.
        """
        if "openai" in model_name:
            return CLIPModel.from_pretrained(model_name).to(device)
        if "laion" in model_name:
            return open_clip.create_model_and_transforms(model_name)[0].to(device)
        raise ValueError(
            f"Unsupported CLIP model name {model_name!r}: expected it to contain "
            "'openai' or 'laion'."
        )

    def __call__(
        self,
        samples: dict,
        num_tags: int = 5,
        num_attributes: int = 5,
        contrastive_th: float = 0.2,
        num_beams: int = 5,  # For beam search
        max_length: int = 30,
        min_length: int = 10,
        top_k: int = 50,
        num_captions: int = 10,
        return_tags: bool = False,
        return_attributes: bool = False,
        return_global_caption: bool = True,
        return_intensive_captions: bool = True,
        return_complete_prompt: bool = True,
    ):
        """
        Process the samples using the Lens model.

        Each enabled stage adds its result to ``samples`` and the same dict is
        returned.

        Parameters:
            samples (dict): A dictionary containing the samples to be processed.
            num_tags (int): The number of tags to generate.
            num_attributes (int): The number of attributes to generate.
            contrastive_th (float): The threshold for contrastive filtering.
            num_beams (int): The number of beams for beam search.
            max_length (int): The maximum length for generated sequences.
            min_length (int): The minimum length for generated sequences.
            top_k (int): The top-k value for sampling during intensive caption generation.
            num_captions (int): The number of intensive captions to generate.
            return_tags (bool): Whether to include tags in the processed samples.
            return_attributes (bool): Whether to include attributes in the processed samples.
            return_global_caption (bool): Whether to include a global caption in the processed samples.
            return_intensive_captions (bool): Whether to include intensive captions in the processed samples.
            return_complete_prompt (bool): Whether to include a complete prompt in the processed samples.

        Returns:
            dict: The processed samples.
        """
        if return_tags:
            samples = self.forward_tags(
                samples, num_tags=num_tags, contrastive_th=contrastive_th
            )
        if return_attributes:
            samples = self.forward_attributes(
                samples, num_attributes=num_attributes, contrastive_th=contrastive_th
            )
        if return_global_caption:
            samples = self.forward_caption(
                samples,
                num_beams=num_beams,
                max_length=max_length,
                min_length=min_length,
            )
        if return_intensive_captions:
            samples = self.forward_intensive_caption(
                samples,
                max_length=max_length,
                min_length=min_length,
                top_k=top_k,
                num_captions=num_captions,
            )
        if return_complete_prompt:
            samples = self.create_prompt_from_samples(samples)

        return samples

    def forward_tags(
        self, samples: dict, num_tags: int = 5, contrastive_th: float = 0.2
    ):
        """
        Forward pass for tags recognition.

        Parameters:
            samples (dict): A dictionary containing "clip_image".
            num_tags (int): The maximum number of tags per image.
            contrastive_th (float): Minimum score a tag needs to be kept.

        Returns:
            dict: The samples with a "tags" entry (one list of tags per image).

        Raises:
            RuntimeError: If the CLIP model, tag weights or tag vocabulary are not loaded.
        """
        image_features = self._clip_image_features(samples)
        samples["tags"] = self._rank_vocabulary(
            image_features,
            self._clip_component("tags_weights"),
            self._clip_component("vocab_tags"),
            num_tags,
            contrastive_th,
        )
        return samples

    def forward_attributes(
        self, samples: dict, num_attributes: int = 5, contrastive_th: float = 0.2
    ):
        """
        Forward pass for attributes recognition.

        Parameters:
            samples (dict): A dictionary containing "clip_image".
            num_attributes (int): The maximum number of attributes per image.
            contrastive_th (float): Minimum score an attribute needs to be kept.

        Returns:
            dict: The samples with an "attributes" entry (one list per image).

        Raises:
            RuntimeError: If the CLIP model, attribute weights or attribute vocabulary are not loaded.
        """
        image_features = self._clip_image_features(samples)
        samples["attributes"] = self._rank_vocabulary(
            image_features,
            self._clip_component("attributes_weights"),
            self._clip_component("vocab_attributes"),
            num_attributes,
            contrastive_th,
        )
        return samples

    def forward_caption(
        self,
        samples: dict,
        num_beams: int = 5,
        max_length: int = 30,
        min_length: int = 10,
    ):
        """
        Generate one caption per image with beam search.

        Parameters:
            samples (dict): A dictionary containing "blip_image" and "blip_input_ids".
            num_beams (int): Number of beams for beam search.
            max_length (int): Maximum length of generated captions.
            min_length (int): Minimum length of generated captions.

        Returns:
            dict: The samples with a "caption" entry (one string per image).
        """
        pixel_values = samples["blip_image"].to(self.device, self.blip_model.dtype)
        input_ids = samples["blip_input_ids"].to(self.device)
        captions_ids = self.blip_model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            do_sample=False,
            num_beams=num_beams,
            top_p=1,
            max_length=max_length,
            min_length=min_length,
        )

        captions = self.blip_processor.batch_decode(
            captions_ids, skip_special_tokens=True
        )
        samples["caption"] = [
            caption[self._CAPTION_PROMPT_CHARS:].strip() for caption in captions
        ]
        return samples

    def forward_intensive_caption(
        self,
        samples: dict,
        max_length: int = 30,
        min_length: int = 10,
        top_k: int = 50,
        num_captions: int = 10,
    ):
        """
        Generate several sampled captions per image.

        Parameters:
            samples (dict): A dictionary containing "blip_image" and "blip_input_ids".
            max_length (int): Maximum length of generated captions.
            min_length (int): Minimum length of generated captions.
            top_k (int): Number of top-k tokens to sample from during generation.
            num_captions (int): Number of captions to sample per image.

        Returns:
            dict: The samples with an "intensive_captions" entry (a list of caption lists, grouped in chunks of num_captions).
        """
        pixel_values = samples["blip_image"].to(self.device, self.blip_model.dtype)
        input_ids = samples["blip_input_ids"].to(self.device)
        caption_ids = self.blip_model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            max_length=max_length,
            min_length=min_length,
            do_sample=True,
            top_p=1,
            top_k=top_k,
            repetition_penalty=1,
            num_return_sequences=num_captions,
        )

        decoded = self.blip_processor.batch_decode(
            caption_ids, skip_special_tokens=True
        )
        captions_text = [
            caption[self._CAPTION_PROMPT_CHARS:].strip() for caption in decoded
        ]
        # Regroup the flat list into consecutive chunks of num_captions.
        samples["intensive_captions"] = [
            captions_text[start : start + num_captions]
            for start in range(0, len(captions_text), num_captions)
        ]
        return samples

    # This function could be more efficient
    def create_prompt_from_samples(
        self,
        samples: dict,
        mode: str = "all",  # vqa or vision or hm or all
    ):
        """
        Generate prompts from samples.

        One prompt is built per image with ``create_prompt_sample``; the batch
        size is read from the first dimension of samples["clip_image"].

        Parameters:
            samples (dict): A dictionary containing the samples.
            mode (str, optional): Mode forwarded to create_prompt_sample. Defaults to "all".

        Returns:
            dict: The samples with a "prompts" entry.
        """
        num_samples = samples["clip_image"].shape[0]
        samples["prompts"] = [
            create_prompt_sample(samples, idx, mode=mode)
            for idx in range(num_samples)
        ]
        return samples

    def hf_dataset_transform(
        self,
        ds: Dataset,
        processor: "LensProcessor",
        num_tags: int = 5,
        num_attributes: int = 5,
        contrastive_th: float = 0.2,
        num_beams: int = 5,  # For beam search
        max_length: int = 30,
        min_length: int = 10,
        top_k: int = 50,
        num_captions: int = 10,
        return_tags: bool = True,
        return_attributes: bool = True,
        return_global_caption: bool = True,
        return_intensive_captions: bool = True,
        distributed_sampling: bool = False,
        batch_size: int = 8,
        num_workers: int = 0,
    ):
        """
        Transform an input dataset using the Lens model.

        Runs the model over the dataset in batches and collects tags,
        attributes, captions and intensive captions for every "id".

        Parameters:
            ds (Dataset): Input dataset; rows need "image" and "id".
            processor (LensProcessor): Lens processor instance.
            num_tags (int, optional): Number of tags to generate. Defaults to 5.
            num_attributes (int, optional): Number of attributes to generate. Defaults to 5.
            contrastive_th (float, optional): Contrastive threshold. Defaults to 0.2.
            num_beams (int, optional): Number of beams for beam search. Defaults to 5.
            max_length (int, optional): Maximum length of generated sequences. Defaults to 30.
            min_length (int, optional): Minimum length of generated sequences. Defaults to 10.
            top_k (int, optional): Top-k value for sampled (intensive) captions. Defaults to 50.
            num_captions (int, optional): Number of captions to generate. Defaults to 10.
            return_tags (bool, optional): Whether to return generated tags. Defaults to True (requires CLIP to be loaded).
            return_attributes (bool, optional): Whether to return generated attributes. Defaults to True (requires CLIP to be loaded).
            return_global_caption (bool, optional): Whether to return global captions. Defaults to True.
            return_intensive_captions (bool, optional): Whether to return intensive captions. Defaults to True.
            distributed_sampling (bool, optional): Whether to use distributed sampling. Defaults to False.
            batch_size (int, optional): Batch size. Defaults to 8.
            num_workers (int, optional): Number of workers for data loading. Defaults to 0.

        Returns:
            Dataset: ``ds`` with the new columns added, or, with distributed sampling, a new dataset holding only "id" and the generated columns.
        """
        dataset = LensDataset(ds, None, processor)
        sampler = create_sampler(dataset, distributed=distributed_sampling)
        dataloader = create_dataloader(
            dataset, sampler, batch_size=batch_size, num_workers=num_workers
        )

        components = ("tags", "attributes", "caption", "intensive_captions")
        wanted = ("id",) + components

        result = []
        for batch in dataloader:
            with torch.no_grad():
                batch = self(
                    batch,
                    num_tags=num_tags,
                    num_attributes=num_attributes,
                    contrastive_th=contrastive_th,
                    num_beams=num_beams,  # For beam search
                    max_length=max_length,
                    min_length=min_length,
                    top_k=top_k,
                    num_captions=num_captions,
                    return_tags=return_tags,
                    return_attributes=return_attributes,
                    return_global_caption=return_global_caption,
                    return_intensive_captions=return_intensive_captions,
                    # Prompts are never read below, so do not build them.
                    return_complete_prompt=False,
                )

            keys = [key for key in batch if key in wanted]
            # Turn the column-wise batch into one record per sample.
            for values in zip(*[batch[key] for key in keys]):
                result.append(
                    {
                        k: (v.item() if k == "id" else v)
                        for k, v in zip(keys, values)
                    }
                )

        if not distributed_sampling:
            # To-Do: Add new columns to the huggingface dataset
            by_id = {
                res["id"]: {k: v for k, v in res.items() if k != "id"}
                for res in result
            }

            # Map new columns would be faster
            def add_info(example):
                generated = by_id.get(example["id"], {})
                for component in components:
                    # Only copy what was actually produced for this id.
                    if component in generated:
                        example[component] = generated[component]
                return example

            return ds.map(add_info, batched=False)

        # Only return the new components
        if not result:
            return Dataset.from_dict({})
        return Dataset.from_dict(
            {key: [d[key] for d in result] for key in result[0]}
        )


class LensProcessor:
    """Turn raw images into the tensors the CLIP and BLIP models consume."""

    def __init__(
        self,
        clip_name: str = "openai/clip-vit-large-patch14",
        blip_name: str = "Salesforce/blip-image-captioning-large",
        hf_transformers_home_dir: Optional[str] = os.getenv("HF_HOME"),
    ):
        """
        Initialize the Lens Processor.

        Parameters:
            clip_name (str, optional): Name of the CLIP model. Defaults to "openai/clip-vit-large-patch14".
            blip_name (str, optional): Name of the BLIP model. Defaults to "Salesforce/blip-image-captioning-large".
            hf_transformers_home_dir (str, optional): Directory for HF Transformers. Defaults to os.getenv("HF_HOME"), evaluated once at import time.
        """
        self.hf_transformers_home_dir = hf_transformers_home_dir
        self.clip_processor = self.load_clip_transform(clip_name)
        self.blip_processor = AutoProcessor.from_pretrained(blip_name)

    def load_clip_transform(self, model_name: str):
        """
        Load the CLIP Processor.

        Parameters:
            model_name (str): Name of the CLIP model. Must contain "openai" or "laion".

        Returns:
            A CLIPProcessor for "openai" names, or the third element returned
            by open_clip.create_model_and_transforms for "laion" names.

        Raises:
            ValueError: If the model name matches neither supported family.
        """
        if "openai" in model_name:
            return CLIPProcessor.from_pretrained(model_name)
        if "laion" in model_name:
            return open_clip.create_model_and_transforms(model_name)[2]
        raise ValueError(
            f"Unsupported CLIP model name {model_name!r}: expected it to contain "
            "'openai' or 'laion'."
        )

    def __call__(self, images: Any, questions: Any = None):
        """
        Process images using the Lens Processor.

        Parameters:
            images (Any): Sequence of images to process.
            questions (Any, optional): Question(s) that go with the images. When given, they are returned unchanged under "questions".

        Returns:
            dict: "clip_image", "blip_image" and "blip_input_ids" tensors, plus "questions" when questions were supplied.
        """
        try:
            # Per-image transform (the callable loaded for "laion" models).
            clip_image = torch.stack([self.clip_processor(image) for image in images])
        except Exception:
            # A processor taking the whole batch by keyword, such as the
            # CLIPProcessor loaded for "openai" models.
            clip_image = self.clip_processor(images=images, return_tensors="pt")[
                "pixel_values"
            ]

        outputs = self.blip_processor(
            images=images, text=["a picture of"] * len(images), return_tensors="pt"
        )
        processed = {
            "clip_image": clip_image,
            "blip_image": outputs["pixel_values"],
            "blip_input_ids": outputs["input_ids"],
        }
        # Callers that pass only images get exactly the three keys above.
        if questions is not None:
            processed["questions"] = questions
        return processed


class LensDataset:
    """Map-style dataset that runs a Lens processor on each row of ``ds``."""

    def __init__(
        self,
        ds: "Dataset",
        questions: Optional[List[str]] = None,
        processor: "Optional[LensProcessor]" = None,
    ):
        """
        Initialize the Lens Dataset.

        Parameters:
            ds (Dataset): Input dataset; rows need "image" and "id" and may carry "question".
            questions (Optional[List[str]], optional): Questions indexed like ``ds``; they take precedence over the rows' own "question". Defaults to None.
            processor (Optional[LensProcessor], optional): Lens Processor. Defaults to None.
        """
        self.ds = ds
        self.processor = processor
        self.questions = questions

    def _question_for(self, idx, row):
        """Pick the question for row ``idx``: explicit list first, then the row, else ""."""
        if self.questions is not None:
            try:
                return self.questions[idx]
            except (IndexError, KeyError, TypeError):
                # No explicit question for this index: fall back to the row.
                pass
        try:
            return row["question"]
        except (KeyError, IndexError, TypeError):
            return ""

    def __getitem__(self, idx):
        """
        Get item from the dataset.

        Parameters:
            idx (int): Index of the item.

        Returns:
            dict: "id" (int32 tensor), "clip_image", "blip_image" and "blip_input_ids" with the processor's leading dimension squeezed out, and "questions" (the question for this row).
        """
        # Fetch the row once rather than once per field.
        row = self.ds[idx]
        image = row["image"]
        sample_id = row["id"]
        question = self._question_for(idx, row)

        # The processor takes the images only; the question is attached here.
        outputs = self.processor([image])
        return {
            "id": torch.tensor(sample_id, dtype=torch.int32),
            "clip_image": outputs["clip_image"].squeeze(0),
            "blip_image": outputs["blip_image"].squeeze(0),
            "blip_input_ids": outputs["blip_input_ids"].squeeze(0),
            "questions": question,
        }

    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: Length of the dataset.
        """
        return len(self.ds)
