import logging
from abc import ABCMeta
from abc import abstractmethod
import numpy as np
import pandas as pd
import six
import torch

from dataiku.base.utils import safe_unicode_str
from dataiku.doctor.deephub.deephub_scoring import DeepHubScoringEngine
from dataiku.doctor.deephub.deephub_torch_datasets import ImageClassificationDataset
from dataiku.doctor.deephub.image_classification_explaining import ImageClassificationScoreExplainer

logger = logging.getLogger(__name__)


@six.add_metaclass(ABCMeta)
class AbstractImageClassificationScoringEngine(DeepHubScoringEngine):
    """
    Base scoring engine for image classification.

    Builds the dataset of images to score and turns the per-batch model outputs into a single predictions
    dataframe. Concrete engines only have to implement `predict_batch`.
    """
    TYPE = "DEEP_HUB_IMAGE_CLASSIFICATION"

    def build_dataset(self, df, files_reader, model):
        """
        Wrap the images referenced by `df` into an image classification dataset, using the image transforms
        provided by `model`.

        :param df: dataframe to be scored
        :type df: pd.DataFrame
        :type files_reader: dataiku.doctor.deephub.utils.file_utils.FilesReader
        :type model: dataiku.doctor.deephub.deephub_model.ComputerVisionDeepHubModel
        :rtype: dataiku.doctor.deephub.deephub_torch_datasets.DeepHubDataset
        """
        image_transforms = model.build_image_transforms_lists()
        return ImageClassificationDataset(files_reader, df, self.file_path_col, image_transforms)

    def score(self, model, device, data_loader, deephub_logger, score_explainer):
        """
        Predict every batch yielded by `data_loader` and gather the results into one dataframe.

        :param model: forwarded as-is to `predict_batch`
        :param device: device onto which each batch of images is moved before prediction
        :param data_loader: loader iterated through `deephub_logger.iter_over_data`; its dataset is used to
                            post-process the raw outputs of `predict_batch` into (predictions, probas)
        :param deephub_logger: object driving the iteration over `data_loader`
        :param score_explainer: context manager kept open during the whole prediction loop. When its `enabled`
                                attribute is truthy, it is also asked to compute the explanations columns.
        :return: dataframe indexed by "image_id", with a "prediction" column (category names), one
                 "proba_<category>" column per category and, when explanations are enabled AND at least one
                 batch was scored, one "explanations_<category>" column per category. Each image appears once.
        :rtype: pd.DataFrame
        """
        # One (predictions, probas, images ids, images sizes) tuple per scored batch
        scored_batches = []
        with score_explainer, torch.no_grad():
            for batch in deephub_logger.iter_over_data(data_loader, redraw_batch_if_empty=False):
                if batch is None:
                    logger.info("Got an empty batch, not predicting it")
                    continue

                infos, images = batch
                raw_outputs = self.predict_batch(model, images.to(device))
                preds, batch_probas = data_loader.dataset.post_process_model_data(raw_outputs)
                scored_batches.append((
                    preds,
                    batch_probas,
                    infos["image_id"].detach().cpu().numpy(),
                    infos["original_shape"].detach().cpu().numpy(),
                ))

        proba_columns = ["proba_{}".format(safe_unicode_str(category))
                         for category in self.target_remapping.list_categories()]

        if not scored_batches:
            # Nothing was scored: return an empty dataframe (note: without any explanations columns)
            result_df = pd.DataFrame(columns=["image_id", "prediction"] + proba_columns)
        else:
            # Regroup the batches field by field, then merge each field into one array holding all the data
            predictions, probas, images_ids, images_sizes = [
                np.concatenate(field_batches) for field_batches in zip(*scored_batches)
            ]

            base_df = pd.DataFrame({"image_id": images_ids, "prediction": predictions})
            probas_df = pd.DataFrame(data=probas, columns=proba_columns)
            result_df = pd.concat([base_df, probas_df], axis=1)

            if score_explainer.enabled:
                explanations_columns = ["explanations_{}".format(safe_unicode_str(category))
                                        for category in self.target_remapping.list_categories()]
                # The explainer receives its own copy of the probas array
                cams = score_explainer.compute_explanations_arrays(np.array(probas), images_sizes)
                cams_df = pd.DataFrame(columns=explanations_columns, data=cams)
                result_df = pd.concat([result_df, cams_df], axis=1)

        # Predictions are category indices at this point: turn them into category names
        result_df["prediction"] = result_df["prediction"].apply(self.target_remapping.get_category)

        # Some sampling strategies (e.g. distributed) can duplicate images to fill incomplete batches, so only
        # the first prediction of each image is kept
        deduplicated_df = result_df.drop_duplicates(subset=["image_id"])
        return deduplicated_df.set_index("image_id")

    @abstractmethod
    def predict_batch(self, model, images):
        """
        Compute the raw model outputs for a batch of images. Must be implemented by concrete engines.

        :param model: model given to `score`
        :param images: batch of images, already moved to the scoring device by `score`
        :return: raw outputs, later handed to the dataset's `post_process_model_data`
        """
        raise NotImplementedError()


class ImageClassificationScoringEngine(AbstractImageClassificationScoringEngine):
    """
    Scoring engine that gets its predictions from an actual neural network.
    """

    def score(self, model, device, data_loader, deephub_logger, score_explainer):
        """
        Put `model` in evaluation mode, then run the generic scoring loop of the parent class.

        :return: the predictions dataframe built by the parent `score`
        """
        model.eval()
        parent_engine = super(ImageClassificationScoringEngine, self)
        return parent_engine.score(model, device, data_loader, deephub_logger, score_explainer)

    def predict_batch(self, model, images):
        """
        Forward the batch of images through the model.

        :return: 2D tensor of shape (nb_images, nb_categories) giving for each image the categories' scores.
        """
        scores = model(images)
        return scores

    def build_score_explainer(self, deephub_model, nn_model, with_explanations, n_explanations):
        """
        Create the explainer used alongside scoring.

        :type deephub_model: ImageClassificationDeepHubModel
        :type nn_model: torch.nn.Module
        :type with_explanations: bool
        :type n_explanations: int
        :rtype: dataiku.doctor.deephub.image_classification_explaining.ImageClassificationScoreExplainer
        """
        explainer = ImageClassificationScoreExplainer(deephub_model, nn_model, with_explanations, n_explanations)
        return explainer


class DummyImageClassificationScoringEngine(AbstractImageClassificationScoringEngine):
    """
    Scoring engine that does not call any model: predictions are seeded random scores.
    """
    DUMMY = True

    def __init__(self, file_path_col, target_remapping, scoring_params):
        """
        Initialise the parent engine, then create random generators with a fixed seed.
        """
        super(DummyImageClassificationScoringEngine, self).__init__(file_path_col, target_remapping, scoring_params)
        seed = 1337
        # NOTE(review): `random_state` is not used anywhere in this class as visible here - confirm whether
        # other code relies on it
        self.random_state = np.random.RandomState(seed)
        generator = torch.Generator()
        self.torch_gen = generator.manual_seed(seed)

    def predict_batch(self, model, images):
        """
        Simulate model's output (to be able to score images & test the rest of the feature without gpus).
        `model` is ignored.

        :return: 2D tensor of shape (nb_images, nb_categories) with random scores.
        """
        nb_images = len(images)
        nb_categories = len(self.target_remapping)
        return torch.rand(nb_images, nb_categories, generator=self.torch_gen)
