import logging
from abc import abstractmethod, ABCMeta

import pandas as pd
import six
import torch

from dataiku.core import dkujson
from dataiku.doctor.deephub.deephub_scoring import DeepHubScoringEngine
from dataiku.doctor.deephub.deephub_torch_datasets import ObjectDetectionDataset

# Module-level logger, named after this module, used by the scoring engines defined below
logger = logging.getLogger(__name__)


@six.add_metaclass(ABCMeta)
class AbstractObjectDetectionScoringEngine(DeepHubScoringEngine):
    """
    Common scoring logic for object detection engines.

    This class builds the dataset, iterates over the batches, filters the predicted objects on their
    confidence, remaps their categories and serializes the predictions. Subclasses must implement
    `predict_batch`, which produces the raw outputs for a batch of images.
    """
    TYPE = "DEEP_HUB_IMAGE_OBJECT_DETECTION"

    def __init__(self, file_path_col, target_remapping, scoring_params):
        """
        :param file_path_col: forwarded to the parent engine; used here as the file path column of the dataset
        :param target_remapping: forwarded to the parent engine; its `get_category` is used to remap categories
        :param scoring_params: forwarded to the parent engine; its "confidence_threshold" entry is used to filter
                               the predicted objects
        """
        super(AbstractObjectDetectionScoringEngine, self).__init__(file_path_col, target_remapping, scoring_params)

    def build_dataset(self, df, files_reader, model):
        """
        Build the torch dataset wrapping the images to score.

        :param df: dataframe to be scored
        :type df: pd.DataFrame
        :type files_reader: dataiku.doctor.deephub.utils.file_utils.FilesReader
        :type model: dataiku.doctor.deephub.deephub_model.ComputerVisionDeepHubModel
        :rtype: dataiku.doctor.deephub.deephub_torch_datasets.DeepHubDataset
        """
        return ObjectDetectionDataset(files_reader, df, self.file_path_col, model.build_image_transforms_lists())

    def score(self, model, device, data_loader, deephub_logger, score_explainer):
        """
        Predict every batch yielded by the data loader and gather the results in a dataframe.

        :param model: model handed over to `predict_batch`
        :param device: device on which the images are moved before prediction
        :param data_loader: loader of the batches; its dataset provides `post_process_model_data`
        :param deephub_logger: object whose `iter_over_data` is used to iterate over the data loader
        :param score_explainer: context manager wrapping the whole scoring loop
        :return: dataframe indexed by "image_id" with a single "prediction" column (one row per image)
        :rtype: pd.DataFrame
        """
        records = []
        with score_explainer, torch.no_grad():
            for batch in deephub_logger.iter_over_data(data_loader, redraw_batch_if_empty=False):

                if batch is None:
                    logger.info("Got an empty batch, not predicting it")
                    continue

                images_infos, images = batch

                device_images = [img.to(device) for img in images]
                raw_outputs = self.predict_batch(model, device_images)
                batch_predictions = self.outputs_to_prediction(raw_outputs,
                                                               data_loader.dataset.post_process_model_data)

                # One record per image of the batch, pairing its id with its list of predicted objects
                for image_infos, image_predictions in zip(images_infos, batch_predictions):
                    image_id = image_infos["image_id"].detach().cpu().item()
                    records.append({"image_id": image_id, "prediction": image_predictions})

        all_predictions_df = pd.DataFrame(records, columns=["image_id", "prediction"])
        # Some sampling strategy (e.g. distributed) can duplicate images to fill incomplete batches. Keep only 1
        # occurrence of prediction for each image
        unique_predictions_df = all_predictions_df.drop_duplicates(subset=['image_id'])
        return unique_predictions_df.set_index("image_id")

    def outputs_to_prediction(self, outputs, post_process_model_data_func):
        """
        Turn the raw outputs of a batch into the list of kept objects for each image:

        * Apply the dataset post processing to get the list of objects of each image
        * Keep only the objects whose confidence reaches the "confidence_threshold" scoring param
        * Convert back the category index of each kept object to its actual category

        :param outputs: raw outputs, as returned by `predict_batch`
        :param post_process_model_data_func: callable converting `outputs` into one list of objects per image
        :return: for each image, the list of kept objects
        :rtype: list
        """
        batch_predictions = []
        for image_bboxes in post_process_model_data_func(outputs):
            kept_bboxes = []
            for bbox in image_bboxes:
                if bbox["confidence"] >= self.scoring_params["confidence_threshold"]:
                    kept_bboxes.append(self.add_category(bbox))
            batch_predictions.append(kept_bboxes)
        return batch_predictions

    def add_category(self, bbox):
        """
        Replace, in place, the "category" of the object by the one given by the target remapping.

        :param bbox: object to update (modified in place)
        :return: the same object, once updated
        """
        category_index = bbox["category"]
        bbox["category"] = self.target_remapping.get_category(category_index)
        return bbox

    @staticmethod
    def serialize_prediction_df(prediction_df):
        """
        Serialize, in place, each value of the "prediction" column with `dkujson.dumps`.

        :param prediction_df: dataframe holding a "prediction" column (modified in place, nothing is returned)
        :type prediction_df: pd.DataFrame
        """
        serialized_predictions = prediction_df["prediction"].apply(dkujson.dumps)
        prediction_df["prediction"] = serialized_predictions

    @abstractmethod
    def predict_batch(self, model, images):
        """
        Compute the raw outputs for a batch of images. Must be implemented by subclasses.

        :param model: model received by `score`
        :param images: images of the batch, already moved to the scoring device
        """
        raise NotImplementedError()


class ObjectDetectionScoringEngine(AbstractObjectDetectionScoringEngine):
    """Scoring engine getting its predictions by calling the model on each batch of images."""

    def predict_batch(self, model, images):
        """
        Call the model on the batch of images.

        :param model: callable model
        :param images: images of the batch
        :return: whatever the model returns for these images
        """
        outputs = model(images)
        return outputs

    def score(self, model, device, data_loader, deephub_logger, score_explainer):
        """
        Call `model.eval()`, then run the scoring loop of the parent class.

        :return: the dataframe of predictions built by the parent class
        """
        # NOTE(review): presumably switches a torch module to evaluation mode - confirm the model type
        model.eval()
        parent = super(ObjectDetectionScoringEngine, self)
        return parent.score(model, device, data_loader, deephub_logger, score_explainer)


class DummyObjectDetectionScoringEngine(AbstractObjectDetectionScoringEngine):
    """Scoring engine that never calls the model: every image gets an empty prediction."""
    DUMMY = True

    def score(self, model, device, data_loader, deephub_logger, score_explainer):
        """
        Log that the scoring is a dummy one, then run the scoring loop of the parent class.

        :return: the dataframe of predictions built by the parent class
        """
        logger.info("Dummy scoring, will return empty predictions")
        parent = super(DummyObjectDetectionScoringEngine, self)
        return parent.score(model, device, data_loader, deephub_logger, score_explainer)

    def predict_batch(self, model, images):
        """
        Build one empty output (no label, no box, no score) per image, without using the model.

        :param model: unused
        :param images: images of the batch; only their number matters
        :return: list with, for each image, a dict of empty "labels", "boxes" and "scores" tensors
        :rtype: list
        """
        empty_outputs = []
        for _ in images:
            # return empty predictions
            empty_outputs.append({"labels": torch.empty(0, dtype=torch.int64),
                                  "boxes": torch.empty((0, 4), dtype=torch.float32),
                                  "scores": torch.empty(0, dtype=torch.float32)})
        return empty_outputs
