import base64
import io
import logging
import timm
import torch
import os

from PIL import Image
from typing import List
from typing import Dict

from dataiku.huggingface.pipeline_batching import ModelPipelineBatching
from dataiku.huggingface.types import ProcessSingleEmbeddingCommand
from dataiku.huggingface.types import ProcessSingleEmbeddingResponse


class ModelPipelineImageEmbeddingExtraction(ModelPipelineBatching[ProcessSingleEmbeddingCommand, ProcessSingleEmbeddingResponse]):
    """
    Batched image-embedding pipeline backed by a timm model.

    Each request carries a base64-encoded image under ``request["query"]["inlineImage"]``. Images are decoded,
    converted to RGB, preprocessed with the transform resolved from the model's pretrained config, and run
    through the model with its classifier head removed, so the model output is returned as the embedding.
    """

    def __init__(self, model_path: str, hf_model_name: str, use_dss_model_cache: bool, batch_size: int):
        """
        :param model_path: directory searched for the weight file (only read when use_dss_model_cache is True)
        :param hf_model_name: model name passed to timm.create_model
        :param use_dss_model_cache: load weights from model_path instead of letting timm fetch them
        :param batch_size: batch size forwarded to ModelPipelineBatching
        """
        super().__init__(batch_size=batch_size)
        self.transform_func, self.embedding_extraction_model = self._initialize_model_and_transform(hf_model_name, model_path, use_dss_model_cache)

    def _initialize_model_and_transform(self, hf_model_name, model_path, use_dss_model_cache):
        """
        Load the timm model onto the selected device & create the transform function associated to the model.

        Also records the engine name and model tracking metadata on the instance.

        :param hf_model_name: model name passed to timm.create_model
        :param model_path: directory searched for the weight file when use_dss_model_cache is True
        :param use_dss_model_cache: if True, load local weights (model.safetensors, falling back to
                                    pytorch_model.bin) instead of letting timm download them
        :return: (transform function, timm model in eval mode on the selected device)
        :raises FileNotFoundError: if use_dss_model_cache is True and neither weight file exists in model_path
        """
        # setting pretrained to True will result in timm fetching weights and their config from HF
        # when using the DSS model cache we want to support air-gapped instances so we set it to False
        # and point timm at the local checkpoint file instead
        pretrained = True
        checkpoint_path = ""
        if use_dss_model_cache:
            pretrained = False
            # preference order: safetensors first, then the legacy pickle-based checkpoint
            for weight_file_name in ("model.safetensors", "pytorch_model.bin"):
                candidate_path = os.path.join(model_path, weight_file_name)
                if os.path.isfile(candidate_path):
                    checkpoint_path = candidate_path
                    break
            else:
                raise FileNotFoundError(
                    "Model weight file not found in the dss model cache "
                    "(looked for model.safetensors and pytorch_model.bin in '{}').".format(model_path))

        embedding_extraction_model = timm.create_model(hf_model_name,
                                                       pretrained=pretrained,
                                                       checkpoint_path=checkpoint_path)
        embedding_extraction_model = embedding_extraction_model.eval()
        embedding_extraction_model.reset_classifier(0)  # remove classifier to output the features not the predictions
        logging.info("Loaded pretrained timm image embedding extraction model with model name: %s", hf_model_name)

        # create appropriate transform function for the selected model:
        # this will handle the preprocessing of the images (including resizing - & normalisation)
        model_config = timm.data.resolve_data_config(embedding_extraction_model.pretrained_cfg)
        logging.info("Resolved preprocessing Transform function config: %s", model_config)
        transform_func = timm.data.create_transform(**model_config)

        device = self._get_device()
        logging.info("Setting embedding extraction model to device '%s'", device)
        # todo @Multimodal: put transform_func onto device as well?
        # (it is currently applied to PIL images, before the batch tensor is moved to the device)
        embedding_extraction_model.to(device)

        self.used_engine = "timm"
        self.model_tracking_data["task"] = "image-embedding"
        self.model_tracking_data["model_architecture"] = embedding_extraction_model.pretrained_cfg.get('architecture')

        return transform_func, embedding_extraction_model

    def _get_device(self):
        """
        Get the torch device string to run the model on: the first CUDA device if torch sees any, else the CPU.

        :rtype: str
        """
        if torch.cuda.device_count() > 0:
            return "cuda:0"
        return "cpu"

    def _get_inputs(self, requests: List[ProcessSingleEmbeddingCommand]) -> List[str]:
        """Extract the base64-encoded image payload of each request, preserving order."""
        return [request["query"]["inlineImage"] for request in requests]

    def _run_inference(self, image64_batch: List[str], params: Dict) -> List:
        """
        Compute one embedding per base64-encoded image.

        :param image64_batch: base64-encoded image payloads, one per request
        :param params: not used by this pipeline
        :return: the model outputs converted to Python lists, in the same order as image64_batch
                 (an empty list for an empty batch)
        """
        if not image64_batch:
            # torch.stack refuses an empty list; there is nothing to embed anyway
            return []

        # For each image:
        #  1 - decode the base 64 encoded image (and put the decoded bytes in a buffer)
        #  2 - get a PIL image instance (closed once we are done with it)
        #  3 - convert the PIL image to the RGB format (convert returns a new image, safe to use after close)
        #  4 - apply the transform functions (Normalization & resizing depending on model chosen)
        images_list = []
        for image64 in image64_batch:
            with Image.open(io.BytesIO(base64.b64decode(image64))) as image:
                images_list.append(self.transform_func(image.convert("RGB")))

        images_tensor = torch.stack(images_list)  # stack all images from the batch list into a single tensor
        images_tensor = images_tensor.to(self._get_device())  # moving data onto the selected device (cpu/gpu)

        with torch.no_grad():
            embeddings = self.embedding_extraction_model(images_tensor)  # the embedding extraction operation
        # moving data back to the cpu and converting it to plain Python lists
        return embeddings.cpu().numpy().tolist()

    def _parse_response(self, response: List[float], request: ProcessSingleEmbeddingCommand) -> ProcessSingleEmbeddingResponse:
        """Wrap a single embedding into the response payload expected by the caller."""
        return {"embedding": response}
