import numpy as np
from matplotlib import cm

from dataiku.doctor.deephub.deephub_explaining import AbstractDeepHubScoreExplainer
from dataiku.doctor.deephub.utils.file_utils import img_array_to_base64


class ImageClassificationScoreExplainer(AbstractDeepHubScoreExplainer):
    """ Score explainer producing Class Activation Maps (CAMs) for image classification models.

        A forward hook on the model's last layer module captures its outputs during inference. Each category's CAM is
        the classifier weights of that category applied to those outputs, normalized to [0, 1], resized to the
        original image size, colored with a color map and encoded as a base64 JPEG.
    """

    def __init__(self, deephub_model, nn_model, with_explanations, n_explanations):
        super(ImageClassificationScoreExplainer, self).__init__(deephub_model, nn_model, with_explanations, n_explanations)

        # module whose forward outputs are captured by `hook_fn`
        self.last_layer = deephub_model.get_last_layer_module(nn_model)
        # classifier weights, squeezed so that `weight_fc.dot(feature_conv)` yields one row per category, i.e.
        # expected shape (nb_categories, nb_conv).
        # NOTE(review): a single-category classifier would be squeezed down to 1D — confirm this never happens.
        self.weight_fc = np.squeeze(deephub_model.get_classifier_parameters(nn_model))

        # set at run time:
        self.hook = None
        self.last_layer_outputs_by_batch = []
        self.last_layer_outputs_aggregated = None

    def initialize_model_for_explanations(self):
        """ if explanations are requested, register a hook in the last convolutions block of the model to store its
            outputs during inference. These outputs are needed to compute Class Activation Maps.
        """
        # a hook left over from a previous call would record every batch a second time
        if self.hook is not None:
            self.hook.remove()
        self.last_layer_outputs_by_batch = []
        self.last_layer_outputs_aggregated = None
        self.hook = self.last_layer.register_forward_hook(self.hook_fn)

    def free_model_from_explainer(self):
        """ Detach the model hook and aggregate the last layer outputs accumulated during inference.

            The aggregate is only set when explanations are requested (`n_explanations > 0`) and at least one batch
            was recorded; otherwise it is left unchanged (None after `initialize_model_for_explanations`).
        """
        # guard against a call without a registered hook (e.g. `initialize_model_for_explanations` never called)
        if self.hook is not None:
            self.hook.remove()
            self.hook = None
        if self.n_explanations > 0 and len(self.last_layer_outputs_by_batch) > 0:
            # concatenate list of all batches into a single matrix
            self.last_layer_outputs_aggregated = np.concatenate(self.last_layer_outputs_by_batch)

    def hook_fn(self, _module, _input, output):
        """ This hook stores the output of the module to which it was attached, detached and copied to a CPU numpy
            array, into the `last_layer_outputs_by_batch` list during inference.
        """
        self.last_layer_outputs_by_batch.append(output.detach().cpu().numpy())

    def compute_explanations_arrays(self, categories_probas, original_img_shapes):
        """ Compute the CAM images of the top `n_explanations` categories of every image scored during inference.

            :param categories_probas: 2D array (nb_imgs, nb_categories) of predicted probabilities.
            :param original_img_shapes: one shape per image; its first two dimensions give the size of the CAM image.
            :return: list with one entry per image, each being a list with one cell per category: the base64 CAM image
                     for the top `n_explanations` categories (by descending proba), NaN for the other categories.
            :raises Exception: if no last layer output was accumulated during inference.
        """
        outputs = self.last_layer_outputs_aggregated
        # `free_model_from_explainer` leaves the aggregate at None when nothing was accumulated: check it explicitly so
        # the caller gets this meaningful error rather than a `TypeError` from `len(None)`.
        if outputs is None or len(outputs) == 0:
            raise Exception("Did not accumulate any output during inference, cannot compute explanations")

        nb_imgs, nb_conv, h, w = outputs.shape
        feature_convs = outputs.reshape((nb_imgs, nb_conv, h * w))

        color_map = self._get_color_map('viridis')  # keep in sync with deephub-image-classification-prediction-widget.component.ts
        n_explanations = min(self.n_explanations, categories_probas.shape[1])

        # categories of each image, sorted by descending proba
        predicted_categories = np.fliplr(np.argsort(categories_probas, axis=1))

        cams = []
        for feature_conv_single_img, predicted_cat_single_img, original_img_shape in \
                zip(feature_convs, predicted_categories, original_img_shapes):

            img_cams = self.compute_image_cams(feature_conv_single_img, self.weight_fc, h, w)
            top_categories = set(predicted_cat_single_img[:n_explanations].tolist())

            # transform to base64 & store only the topN cams per image & fill the other cells with NaN.
            cams.append([self.transform_cam_to_base64_img(cam, img_size=original_img_shape[:2], color_map=color_map)
                         if category_idx in top_categories else np.nan
                         for category_idx, cam in enumerate(img_cams)])
        return cams

    @staticmethod
    def _get_color_map(name):
        """ Return the matplotlib color map registered under `name`.

            `matplotlib.cm.get_cmap` is deprecated since matplotlib 3.7 and removed in 3.9: prefer the
            `matplotlib.colormaps` registry (matplotlib >= 3.5) and only fall back to `cm.get_cmap` on older versions.
        """
        try:
            from matplotlib import colormaps
        except ImportError:
            return cm.get_cmap(name)
        return colormaps[name]

    @staticmethod
    def compute_image_cams(feature_conv, weights_fc, cam_height, cam_width):
        """ Compute the classes activation maps for a single image for all the categories.

            :param feature_conv: matrix (nb_conv, h * w) of the last layer outputs for one image.
            :param weights_fc: matrix (nb_categories, nb_conv) of classifier weights.
            :return: array of shape (nb_categories, cam_height, cam_width), 1 CAM per class, each normalized between
                     0 and 1. A constant CAM is mapped to all zeros.
        """
        cams_all_categories = weights_fc.dot(feature_conv)

        # normalize row by row:
        cams_all_categories = cams_all_categories - np.min(cams_all_categories, axis=1)[:, None]
        row_max = np.max(cams_all_categories, axis=1)[:, None]
        # after subtracting the min, a max of 0 means the whole row is 0: divide it by 1 instead of producing 0/0 = NaN
        safe_row_max = np.where(row_max > 0, row_max, 1)
        cams_all_categories = cams_all_categories / safe_row_max
        return cams_all_categories.reshape(-1, cam_height, cam_width)

    @staticmethod
    def transform_cam_to_base64_img(cam, img_size, color_map):
        """
            Transform the cam matrix (single value per pixel) into a RGB image with the color map palette applied
            and interpolated to the img_size given.
            Then transform this image into a base64 string.
            :return: base64 string representing the cam image created.
        """
        # local import to allow unit tests running on builtin env to import this file.
        import skimage.transform

        img = skimage.transform.resize(cam, img_size)  # perform interpolation
        img = color_map(img)[:, :, :3]  # grey scale to RGB colors (ignoring alpha channel, opacity is chosen in the UI)

        return img_array_to_base64(img_array=img * 255, img_format="JPEG")
