import logging
from abc import ABCMeta, abstractmethod

import six
import torch
from torch.utils.data.dataloader import default_collate
from torchvision.transforms import functional as F
import albumentations as A
import numpy as np
from PIL import Image, ImageOps

from dataiku.core import dkujson
from dataiku.doctor.deephub.deephub_invalid_data_filter import ObjectDetectionInvalidDataFilter
from dataiku.doctor.deephub.deephub_invalid_data_filter import ImageClassificationInvalidDataFilter

"""
Collection of torch datasets that will be consumed by torch models for training/scoring.

A torch dataset is simply a class that implements `__getitem__` and `__len__`. In the context of DSS, the base class
is `dataiku.doctor.deephub.deephub_torch_datasets.DeepHubDataset`.
The magic to then efficiently consolidate the data into batches is handled by the utility
`torch.utils.data.DataLoader`.

Note that the DataLoader usually runs multiple workers, which rely on multiprocessing under the hood, so
the dataset must be process-safe.
"""

# Module-level logger, named after this module
logger = logging.getLogger(__name__)


@six.add_metaclass(ABCMeta)
class DeepHubDataset(object):
    """
    Base class of the torch datasets defined in this module.

    Subclasses implement `_getitem` and `__len__`. Item retrieval errors are handled in `__getitem__` (the failing
    item becomes None) and None items are discarded when building a batch in `data_loader_collate_fn`.
    """

    def __getitem__(self, idx):
        """
        Returns the item associated with index `idx`, or None if building it failed (the error is logged).
        """
        try:
            return self._getitem(idx)
        except Exception:
            # Only regular errors are turned into a discarded item. A bare `except` would also swallow
            # KeyboardInterrupt / SystemExit, which must keep propagating to stop the process.
            logger.exception("couldn't get item at index %s", idx)
            return None

    @abstractmethod
    def _getitem(self, idx):
        """
        Returns the item associated with index `idx`
        Error handling is managed by __getitem__
        """
        raise NotImplementedError()

    @abstractmethod
    def __len__(self):
        """
        Returns the number of items in the dataset
        """
        raise NotImplementedError()

    @staticmethod
    def post_process_model_data(batch_model_data):
        """
        Postprocessing applied to model data, i.e. targets or predictions from a batch fed to the model
        :return: model data reformatted to DSS expected format
        """
        return batch_model_data

    def data_loader_collate_fn(self, batch):
        """
        Aggregates the items of `batch` into a single batch.
        Items that are None (i.e. items that `__getitem__` failed to build) are discarded first.

        :param batch: list of items, as returned by `__getitem__`
        :return: the aggregated batch, or None if no item is left after discarding the None ones
        """
        batch = [b for b in batch if b is not None]

        if len(batch) == 0:
            return None

        return self._data_loader_collate_fn(batch)

    def _data_loader_collate_fn(self, batch):
        """
        Aggregates a list of (not None) items into a batch, using torch default collation.
        See https://pytorch.org/docs/stable/data.html#automatic-batching-default for more information
        """
        return default_collate(batch)


class AbstractComputerVisionDataset(DeepHubDataset):
    """
    Base dataset for image tasks: each item is built from the image referenced by one row of a dataframe, after
    the configured transforms have been applied to it.
    """

    def __init__(self, files_reader, df, file_path_col, transforms_lists):
        """
        :type files_reader: dataiku.doctor.deephub.utils.file_utils.FilesReader
        :type df: pd.DataFrame
        :type file_path_col: str
        :type transforms_lists: dataiku.doctor.deephub.data_augmentation.image_transformer.ImageTransformsLists
        """
        self.files_reader = files_reader
        self.file_path_col = file_path_col
        # Keep the unfiltered dataframe around, `self.df` only holds the rows accepted by the invalid data filter
        self.original_df = df
        self.df = self.filter_invalid_data(df)
        self.transforms_list = transforms_lists

        augmented_transforms = transforms_lists.with_augmentation_transforms_list
        self.transform_functions = self.create_transform_functions(augmented_transforms)
        logger.info("Applying the following data transforms: {}".format(self.transform_functions))

    def __len__(self):
        """
        Number of rows left after invalid data filtering
        """
        return len(self.df)

    def filter_invalid_data(self, df):
        """
        Applies the invalid data filter of the dataset (if any) to `df`
        :return: the filtered dataframe, or `df` itself when the dataset has no filter
        :rtype: pandas.DataFrame
        """
        invalid_data_filter = self._get_invalid_data_filter()
        if invalid_data_filter is None:
            return df
        return invalid_data_filter.filter(df)

    def _get_invalid_data_filter(self):
        """
        Filter to use for the current dataset type. No filter by default, to be overridden when relevant.
        :return: the dedicated filter for the dataset type, or None
        :rtype: dataiku.doctor.deephub.deephub_invalid_data_filter.InvalidDataFilter
        """
        return None

    def _getitem(self, idx):
        """
        :return: tuple (image infos, image tensor), the infos being the image id and the shape of the image
                 before any transform is applied
        """
        row = self.get_row(idx)
        raw_image = self.get_image_as_np_array(row)
        shape_before_transforms = raw_image.shape
        transformed_image = self.transform_functions(image=raw_image)["image"]

        image_infos = {
            "image_id": torch.tensor(self.get_image_id(row)),
            "original_shape": torch.tensor(shape_before_transforms)
        }
        return image_infos, F.to_tensor(transformed_image)

    def get_row(self, idx):
        """
        Row at position `idx` in the filtered dataframe
        """
        return self.df.iloc[idx]

    def get_image_id(self, row):
        """
        The image id is the row label, i.e. the index in the original dataframe, before rows were dropped
        """
        return row.name

    def get_image_as_np_array(self, row):
        """
        Reads the image referenced by `row` and returns it as an RGB numpy array of shape (H, W, 3)
        """
        with self.files_reader.open_file(row[self.file_path_col]) as img_file:
            with Image.open(img_file) as raw_img:
                # if an image has an EXIF Orientation tag, transpose the image accordingly (which also removes
                # the orientation data), then force 3 channels
                rgb_img = ImageOps.exif_transpose(raw_img).convert("RGB")

        return np.array(rgb_img)

    def create_transform_functions(self, transforms_list):
        """
        Composes the transforms of `transforms_list` into a single callable
        """
        return A.Compose(transforms_list)


class ImageClassificationDataset(AbstractComputerVisionDataset):
    """
    Dataset for image classification, without target.
    """

    @staticmethod
    def post_process_model_data(batch_model_data):
        """
        Postprocessing applied to model outputs (by batch) to make them fit the format expected by DSS

        The model outputs scores for all the categories, they are turned into probabilities (softmax over
        dimension 1) that are usable to compute metrics and display predictions to the users.

        :return: tuple (predicted categories, predicted probabilities), both as numpy arrays
        """
        probas = torch.nn.functional.softmax(batch_model_data, dim=1)
        categories = torch.argmax(probas, dim=1)
        return categories.detach().cpu().numpy(), probas.detach().cpu().numpy()


class ImageClassificationWithTargetDataset(ImageClassificationDataset):
    """
    Dataset for image classification that also provides the target category of each image.
    """

    def __init__(self, files_reader, df, target_remapping, file_path_col, target_col, transforms_lists):
        """
        :type files_reader: dataiku.doctor.deephub.utils.file_utils.FilesReader
        :type df: pd.DataFrame
        :type target_remapping: dataiku.doctor.deephub.deephub_params.TargetRemapping
        :type file_path_col: str
        :type target_col: str
        :type transforms_lists: dataiku.doctor.deephub.data_augmentation.image_transformer.ImageTransformsLists
        """
        # Target attributes must be set before calling the parent constructor: it filters the dataframe, which
        # relies on `_get_invalid_data_filter`
        self.target_remapping = target_remapping
        self.target_col = target_col
        super(ImageClassificationWithTargetDataset, self).__init__(files_reader, df, file_path_col, transforms_lists)

    def _get_invalid_data_filter(self):
        """
        :return: filter built from the target column, the file path column and the set of known categories
        """
        known_categories = set(self.target_remapping.list_categories())
        return ImageClassificationInvalidDataFilter(self.target_col, self.file_path_col, known_categories)

    def _getitem(self, idx):
        """
        :return: tuple (image infos, image tensor), the infos being the image id, the index of the target category
                 and the shape of the image before any transform is applied
        """
        row = self.get_row(idx)
        raw_image = self.get_image_as_np_array(row)

        image_id = torch.tensor(self.get_image_id(row))
        category_index = self.target_remapping.get_category_index(row[self.target_col])
        category = torch.tensor(category_index, dtype=torch.int64)
        original_shape = torch.tensor(raw_image.shape)

        # apply normalisation & image augmentation:
        transformed_image = self.transform_functions(image=raw_image)["image"]

        image_infos = {"image_id": image_id, "category": category, "original_shape": original_shape}
        return image_infos, F.to_tensor(transformed_image)


class ObjectDetectionDataset(AbstractComputerVisionDataset):
    """
    Dataset for object detection, without target.
    """

    @staticmethod
    def to_dss_bbox_format(pytorch_box):
        """ Convert a box from the Pytorch format [x0, y0, x1, y1] to the DSS one (coco-like) [x0, y0, w, h]
            Note: coordinates origin is top left of image and coordinates increase:
              * from left to right
              * from top to bottom
        """
        left, top = pytorch_box[0], pytorch_box[1]
        return [left, top, pytorch_box[2] - left, pytorch_box[3] - top]

    @staticmethod
    def to_pytorch_box_format(dss_bbox):
        """ Convert a bbox from the DSS format (coco-like) [x0, y0, w, h] to the Pytorch one [x0, y0, x1, y1] """
        left, top = dss_bbox[0], dss_bbox[1]
        return [left, top, left + dss_bbox[2], top + dss_bbox[3]]

    @staticmethod
    def post_process_model_data(batch_model_data):
        """
        Converts targets/predictions from the PyTorch format to the DSS one.

        PyTorch side: for each image, a dict of tensors, each index in each tensor holding the information of one
        box. DSS side: for each image, a list of dicts, each dict holding the full information of one box.

        The following conversions are applied:
        * labels: PyTorch labels are 1-based for object detection, 0 being the background, while DSS categories
          are 0-based. So background boxes are filtered out, 1 is removed from the other labels and the result is
          stored under "category"
        * boxes: converted from [x0, y0, x1, y1] to the DSS (coco-like) bbox format, stored under "bbox"
        * scores: when available (i.e. for predictions), stored under "confidence"

        :param batch_model_data: tuple of dicts in which each value is a tensor
        :return: list (one element per image) of lists of box dicts
        """

        def as_numpy(tensor):
            return tensor.detach().cpu().numpy()

        batch_annotations = []

        for single_img_data in batch_model_data:
            boxes = as_numpy(single_img_data["boxes"])
            labels = as_numpy(single_img_data["labels"])
            # scores are only available for predictions
            scores = as_numpy(single_img_data["scores"]) if "scores" in single_img_data else None

            image_annotations = []
            for box_index, label in enumerate(labels):
                # label=0 is background, we don't want it for computing perf
                if label <= 0:
                    continue

                annotation = {
                    "bbox": ObjectDetectionDataset.to_dss_bbox_format(boxes[box_index, :].tolist()),
                    # re-aligning the 1-based label with the 0-based categories of the target_remapping
                    "category": label - 1
                }
                if scores is not None:
                    annotation["confidence"] = scores[box_index]

                image_annotations.append(annotation)

            batch_annotations.append(image_annotations)

        return batch_annotations

    def _getitem(self, idx):
        """
        :return: tuple (image infos, image tensor), the infos being the image id and the shape of the image
                 before any transform is applied
        """
        row = self.get_row(idx)
        raw_image = self.get_image_as_np_array(row)
        shape_before_transforms = raw_image.shape
        transformed_image = self.transform_functions(image=raw_image)["image"]

        image_infos = {
            "image_id": torch.tensor(self.get_image_id(row)),
            "original_shape": torch.tensor(shape_before_transforms)
        }
        return image_infos, F.to_tensor(transformed_image)

    def _data_loader_collate_fn(self, batch):
        """
        Builds a batch without stacking the items: the result holds one tuple per item component, each tuple
        containing the values of that component for all the items of the batch.
        """
        # By default, the DataLoader broadcasts every torch tensor across a new dimension and tries to do smart
        # things with other types (dict, tuple, ...). This works for most types of data.
        # However, for object detection, each image contains an arbitrary number of bboxes, so they cannot be
        # consolidated easily. Instead, the algorithm needs to receive them one by one, hence this custom
        # collation that simply transposes the list of items.
        return tuple(zip(*batch))


class ObjectDetectionWithTargetDataset(ObjectDetectionDataset):
    """
    Dataset for object detection that also provides the targets (boxes, labels, area, iscrowd) of each image.
    """
    MAX_TRY_AUGMENTATION = 10

    def __init__(self, files_reader, df, target_remapping, file_path_col, target_col, transforms_lists):
        """
        :type files_reader: dataiku.doctor.deephub.utils.file_utils.FilesReader
        :type df: pd.DataFrame
        :type target_remapping: dataiku.doctor.deephub.deephub_params.TargetRemapping
        :type file_path_col: str
        :type target_col: str
        :type transforms_lists: dataiku.doctor.deephub.data_augmentation.image_transformer.ImageTransformsLists
        """
        # Target attributes must be set before calling the parent constructor: it filters the dataframe, which
        # relies on `_get_invalid_data_filter`
        self.target_col = target_col
        self.target_remapping = target_remapping
        super(ObjectDetectionWithTargetDataset, self).__init__(files_reader, df, file_path_col, transforms_lists)

        # Required if some augmentation moves all the boxes out of the image
        self.transform_functions_without_augmentation = self.create_transform_functions(
            transforms_lists.without_augmentation_transforms_list)

    def create_transform_functions(self, transforms_list):
        """
        Composes the transforms of `transforms_list` into a single callable that also transforms the bboxes
        (given in coco format).

        Both `categories` and `iscrowd` are declared as label fields: when a transform drops a bbox, its
        category and its iscrowd flag are dropped with it, so the three lists stay aligned.
        The resulting callable must therefore be called with `image`, `bboxes`, `categories` and `iscrowd`.
        """
        bbox_params = A.BboxParams(format='coco', label_fields=["categories", "iscrowd"])
        return A.Compose(transforms_list, bbox_params=bbox_params)

    def _get_invalid_data_filter(self):
        """
        :return: filter built from the target column, the file path column and the set of known categories
        """
        return ObjectDetectionInvalidDataFilter(self.target_col, self.file_path_col,
                                                set(self.target_remapping.list_categories()))

    def _transform_with_targets(self, img_array, bboxes, categories, iscrowd):
        """
        Applies the transforms to the image and to its targets.

        For OD, some augmentation (crop, rotate) might push the boxes out of the image, and we can end-up
        with images without targets. This is not supported by our models.
        So we use a retry mechanism when this happens, with a TTL that just returns a not-augmented image
        after MAX_TRY_AUGMENTATION retries without being lucky.

        :param bboxes: list of bboxes, in DSS-coco format
        :param categories: list of category indices, one per bbox
        :param iscrowd: list of iscrowd flags, one per bbox
        :return: tuple (image, bboxes, categories, iscrowd) after transformation. The three last elements only
                 hold the boxes that were kept by the transforms, and are aligned with each other.
        """
        transform_inputs = dict(image=img_array, bboxes=bboxes, categories=categories, iscrowd=iscrowd)

        for _ in range(self.MAX_TRY_AUGMENTATION):
            transformed = self.transform_functions(**transform_inputs)

            if len(transformed["bboxes"]) == 0:
                logger.info("Augmentation removed all bboxes, trying again")
                continue

            return transformed["image"], transformed["bboxes"], transformed["categories"], transformed["iscrowd"]

        logger.info("Maximum number of augmentation trials reached, returning image without augmentation")

        transformed = self.transform_functions_without_augmentation(**transform_inputs)
        return transformed["image"], transformed["bboxes"], transformed["categories"], transformed["iscrowd"]

    def _transform(self, img_array, bboxes, categories, iscrowd=None):
        """
        Same as `_transform_with_targets`, but only returns the transformed image and bboxes.
        Kept for backward compatibility. Beware that the returned bboxes may be a subset of the input ones, so
        they are not guaranteed to be aligned with the input `categories` anymore.

        :param iscrowd: optional list of iscrowd flags, one per bbox (all bboxes are `not iscrowd` by default)
        :return: tuple (image, bboxes) after transformation
        """
        if iscrowd is None:
            iscrowd = [0] * len(bboxes)
        transformed = self._transform_with_targets(img_array, bboxes, categories, iscrowd)
        return transformed[0], transformed[1]

    def _getitem(self, idx):
        """
        :return: tuple (image infos, image tensor), the infos being the targets of the image ("boxes", "labels",
                 "area", "iscrowd"), the image id and the shape of the image before any transform is applied
        """
        row = self.get_row(idx)

        img_array = self.get_image_as_np_array(row)
        original_shape = img_array.shape
        target = dkujson.loads(row[self.target_col])

        bboxes = [a["bbox"] for a in target]  # DSS-coco format
        categories = [self.target_remapping.get_category_index(a["category"]) for a in target]  # DSS-coco format
        iscrowd = [a.get("iscrowd", 0) for a in target]  # no iscrowd flag is treated as `not iscrowd`

        # Augmentation may drop some of the boxes: categories and iscrowd must be the ones returned by the
        # transforms (and not the original ones), otherwise they would not match the remaining boxes anymore.
        img_array, bboxes, categories, iscrowd = self._transform_with_targets(img_array, bboxes, categories,
                                                                              iscrowd)

        # We are using Coco format, i.e. each bbox is [x0, y0, w, h] in input, but Pytorch expects [x0, y0, x1, y1]
        # Note: coordinates origin is top left of image and coordinates increase:
        #  * from left to right
        #  * from top to bottom
        area = [b[2] * b[3] for b in bboxes]
        boxes = [self.to_pytorch_box_format(b) for b in bboxes]

        # convert everything into a torch.Tensor
        area = torch.as_tensor(area, dtype=torch.float32)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.float32)
        boxes = torch.as_tensor(boxes, dtype=torch.float32)

        # PyTorch expects labels to be 1-based for object detection, and 0 is for the background, hence adding 1 to
        # the 0-based DSS categories. Re-alignment to DSS format should be performed after the prediction (in the
        # `post_process_model_data` function).
        labels = torch.as_tensor(categories, dtype=torch.int64) + 1

        image_infos = {
            "boxes": boxes, "labels": labels, "area": area, "iscrowd": iscrowd,
            "image_id": torch.tensor(self.get_image_id(row)),
            "original_shape": torch.tensor(original_shape)
        }
        return image_infos, F.to_tensor(img_array)
