import logging
import re

import numpy as np
import pandas as pd

import dataiku

from six.moves import xrange

from dataiku.base.utils import safe_unicode_str, RaiseWithTraceback
from dataiku.doctor import is_global_embedding_cache_enabled
from dataiku.doctor.preprocessing.dataframe_preprocessing import Step

logger = logging.getLogger(__name__)

# Base64-encoded PNG holding a single white RGB pixel, i.e. the output of
# Image.new('RGB', (1, 1), (255, 255, 255)). It is sent to the embedding API as
# a probe image whenever the embedding size of a model must be discovered.
ONE_WHITE_PIXEL_IMAGE_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"

# Upper bound on how many images are held in memory at the same time, which
# also caps how many images go into a single embedding API query.
MAX_CONCURRENT_IMAGES_IN_MEM = 50

# Module-level cache shared by every extractor in this process:
# "<model ref>::<image path>" -> embedding vector.
IMG_TO_EMBEDDINGS_CACHE = dict()


class ImageEmbeddingExtractor(Step):
    """
    Preprocessing step turning a column of image paths into image embeddings.

    Embeddings are computed through the embedding API of an LLM connection
    (``model_structured_ref``) and appended to the multiframe as a numpy block
    named ``image_embedding:<input_col>``.
    """

    def __init__(self, input_col, file_reader, model_structured_ref, impute_missing_values=False, impute_invalid_paths=False):
        """
        :param input_col: name of the column holding the image paths
        :param file_reader: reader handed to the ImageLoader to fetch the image files
        :param model_structured_ref: reference of the LLM used to compute the embeddings
        :param impute_missing_values: if False, an empty/missing path raises; if True, it gets an all-zero embedding
        :param impute_invalid_paths: if True, the ImageLoader does not fail on invalid paths and the
            corresponding rows keep an all-zero embedding
        """
        super(ImageEmbeddingExtractor, self).__init__()
        self._input_col = input_col
        self._project_handle = dataiku.api_client().get_default_project()
        self._llm_model = self._project_handle.get_llm(model_structured_ref)
        self._model_ref = model_structured_ref
        self._impute_missing_values = impute_missing_values
        self._impute_invalid_paths = impute_invalid_paths
        self._embedding_size = None  # lazily set by _get_embedding_size()
        # NOTE(review): not referenced by the methods of this class; presumably kept for external users.
        self._base64_regexp = re.compile("^([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{4}|[A-Za-z0-9+/]{3}=|[A-Za-z0-9+/]{2}==)$")

        # import here so it does not get imported when this file is imported by Python 2.7 during preprocessing
        from dataiku.core.image_loader import ImageLoader
        self._image_loader = ImageLoader(fail_if_invalid_path=not self._impute_invalid_paths, file_reader=file_reader)

    def __str__(self):
        return "Step:%s (%s)" % (self.__class__.__name__, self._input_col)

    def process(self, input_df, current_mf, output_ppr, generated_features_mapping):
        """
        Vectorize a column calling the embedding API using LLM connections to compute the image embeddings.

        :param input_df: Contains path to the image to be processed
        :param current_mf: Multiframe to which to append processed data
        :param output_ppr: Current output preprocessing result
        :param generated_features_mapping: Holds mappings of block<->column
        """
        embeddings = self.extract_embeddings(input_df)

        # Note: When used on an input column called "paths", generated features will be named:
        # img_emb:paths:0, img_emb:paths:1 ... img_emb:paths:embedding_size-1  (where embedding size depends on the chosen model)
        feat_names = [u"img_emb:{}:{}".format(safe_unicode_str(self._input_col), i) for i in xrange(self._embedding_size)]
        block_name = u"image_embedding:{}".format(safe_unicode_str(self._input_col))
        generated_features_mapping.add_features_to_block(feat_names, block_name)
        generated_features_mapping.add_whole_block_mapping(block_name, [self._input_col])
        current_mf.append_np_block(block_name, embeddings, feat_names)

    def extract_embeddings(self, input_df):
        """
        Compute the embeddings matrix for the input column of ``input_df``.

        :param input_df: dataframe holding the image paths in ``self._input_col``
        :return: array of shape (len(input_df), embedding_size). It is all zeros when ``input_df``
            is flagged as a dummy dataframe (``_dku_is_dummy_df``), so no API call is made for images.
        """
        file_paths = input_df[self._input_col].values
        is_dummy_df = getattr(input_df, '_dku_is_dummy_df', False)

        if len(file_paths) == 0:
            # scoring an empty dataset should generate a consistent number of features.
            # note: DSS default dtype is float64, better stay consistent there:
            # NOTE(review): non-empty results are float16 (see _init_embeddings_array) - confirm the mismatch is intended.
            return np.empty((0, self._get_embedding_size()), dtype=np.float64)
        if is_dummy_df:
            return self._init_embeddings_array(len(file_paths))
        return self._images_to_embeddings(file_paths)

    def _init_embeddings_array(self, size):
        """Return a zero-filled float16 array of shape (size, embedding_size)."""
        return np.zeros((size, self._get_embedding_size()), dtype=np.float16)

    def _get_embedding_size(self):
        """Return the model's embedding size, probing the backend once and memoizing the result."""
        if self._embedding_size is not None:
            return self._embedding_size

        # we query the backend for a dummy image embeddings just to get the size
        backend_query = self._llm_model.new_embeddings()
        backend_query.add_image(ONE_WHITE_PIXEL_IMAGE_B64)
        self._embedding_size = len(backend_query.execute().get_embeddings()[0])
        return self._embedding_size

    def _images_to_embeddings(self, img_paths):
        """
        Compute embeddings for every path; rows with a missing path stay all-zero.

        :param img_paths: array of image paths (may contain None/NaN/blank strings)
        :raises Exception: if a path is missing and missing values are not imputed
        """
        output = self._init_embeddings_array(len(img_paths))

        not_missing_img_paths_indices = [index for (index, img_path) in enumerate(img_paths) if
                                         (img_path and not pd.isna(img_path) and not img_path.isspace())]

        if len(not_missing_img_paths_indices) != len(img_paths) and not self._impute_missing_values:
            raise Exception(u"Missing image path for feature '{}'. "
                            u"Could not process the file".format(safe_unicode_str(self._input_col)))

        update_cache = is_global_embedding_cache_enabled()
        logger.info("Caching preprocessing image embeddings in memory" if update_cache else "Not caching image embeddings in memory")
        output[not_missing_img_paths_indices] = self._get_images_embeddings(img_paths[not_missing_img_paths_indices],
                                                                            update_cache=update_cache)
        return output

    # We must use a key including the model in case this code is called twice if
    # different models are used for the same preprocessing
    def _get_image_key(self, img_path):
        return self._model_ref + "::" + img_path

    # Important note: if this is called from the API Node, the update cache flag should be set to false
    def _get_images_embeddings(self, not_missing_img_paths, update_cache):
        """
        Compute embeddings for non-missing image paths, batch by batch.

        For each batch of at most MAX_CONCURRENT_IMAGES_IN_MEM paths, embeddings found in
        IMG_TO_EMBEDDINGS_CACHE are reused. The remaining images are loaded and sent to the embedding
        API in a single query. Rows whose image could not be loaded (the loader returned None) are
        left all-zero and are NOT cached, so a later call retries them.

        :param not_missing_img_paths: array of image paths, indexable by a list of positions
        :param update_cache: whether to store newly computed embeddings in IMG_TO_EMBEDDINGS_CACHE
        :return: float16 array of shape (len(not_missing_img_paths), embedding_size)
        :raises Exception: if the embedding API returns a different number of embeddings than images sent
        """
        num_paths = len(not_missing_img_paths)
        output = self._init_embeddings_array(num_paths)

        num_extracted_embeddings = 0  # for logging purposes only
        # No more than MAX_CONCURRENT_IMAGES_IN_MEM images will be loaded in memory at the same time.
        # Intrinsically this means we don't send more than MAX_CONCURRENT_IMAGES_IN_MEM to the embedding API.
        # NOTE(review): download concurrency is presumably bounded inside ImageLoader - not visible here.
        for batch_start in range(0, num_paths, MAX_CONCURRENT_IMAGES_IN_MEM):
            batch_indices = range(batch_start, min(batch_start + MAX_CONCURRENT_IMAGES_IN_MEM, num_paths))

            not_computed_img_paths_indices = []
            for index in batch_indices:
                key = self._get_image_key(not_missing_img_paths[index])
                if key in IMG_TO_EMBEDDINGS_CACHE:
                    output[index, :] = IMG_TO_EMBEDDINGS_CACHE[key]
                else:
                    not_computed_img_paths_indices.append(index)

            if not not_computed_img_paths_indices:
                continue

            with RaiseWithTraceback(u"Error while loading images "
                                    u"from feature '{}'".format(safe_unicode_str(self._input_col))):
                b64_images = self._image_loader.load_images(
                    img_paths=not_missing_img_paths[not_computed_img_paths_indices])

            backend_query = self._llm_model.new_embeddings()
            indices_to_embed = []
            for b64_img, index in zip(b64_images, not_computed_img_paths_indices):
                if b64_img is not None:
                    backend_query.add_image(image=b64_img)
                    indices_to_embed.append(index)

            # no image to embed in this batch (paths were invalid so none was loaded)
            if not indices_to_embed:
                continue

            logger.info("Sending {num_images} images to the embedding "
                        "api.".format(num_images=len(indices_to_embed)))

            embeddings = backend_query.execute().get_embeddings()
            # Explicit check rather than `assert`, which is stripped when Python runs with -O.
            if len(embeddings) != len(indices_to_embed):
                raise Exception(u"Embedding API returned {} embeddings for {} images "
                                u"from feature '{}'".format(len(embeddings), len(indices_to_embed),
                                                            safe_unicode_str(self._input_col)))
            num_extracted_embeddings += len(embeddings)

            for index, img_embeddings in zip(indices_to_embed, embeddings):
                np_img_embeddings = np.array(img_embeddings, dtype=np.float16)  # 'downcast' to save space
                output[index, :] = np_img_embeddings

            if update_cache:
                # Only cache images that were actually embedded: imputed zero rows for unloadable
                # images must not be pinned in the cache.
                # Store a copy, not the `output[index, :]` view: a view would keep the whole `output`
                # array alive for as long as the module-level cache references it.
                for index in indices_to_embed:
                    img_path = not_missing_img_paths[index]
                    IMG_TO_EMBEDDINGS_CACHE[self._get_image_key(img_path)] = output[index, :].copy()
                logger.info(u"Image embeddings cache size: {}".format(len(IMG_TO_EMBEDDINGS_CACHE)))

        logger.info("Extracted embeddings for {num_images} images from "
                    "the embedding api.".format(num_images=num_extracted_embeddings))

        return output
