import logging

from torch import nn

from dataiku.doctor.deephub.data_augmentation.image_transformer import build_image_classification_transforms_lists
from dataiku.doctor.deephub.deephub_model import ComputerVisionDeepHubModel
from dataiku.doctor.deephub.utils.constants import ImageClassificationModels

logger = logging.getLogger(__name__)


class ImageClassificationDeepHubModel(ComputerVisionDeepHubModel):
    """ Deep hub image classification model built on torchvision EfficientNet backbones (B0, B4, B7).

    The backbone's parameters are frozen, except for the last ``retrainedLayers`` feature blocks when training from
    pretrained weights. The classification head is replaced by a new linear layer with one output per target class.
    """
    TYPE = "DEEP_HUB_IMAGE_CLASSIFICATION"

    def __init__(self, target_remapping, modeling_params):
        """
        :param target_remapping: target classes mapping; its length is used as the number of output classes
        :param modeling_params: modeling parameters dict; its "pretrainedModel" key selects the backbone
        """
        super(ImageClassificationDeepHubModel, self).__init__(target_remapping, modeling_params)

    @property
    def model_name(self):
        """ Resolve the "pretrainedModel" modeling param to the matching ImageClassificationModels member.

        :rtype: ImageClassificationModels
        :raises RuntimeError: if the param is not the name of an ImageClassificationModels member
        """
        pretrained_model = self.modeling_params["pretrainedModel"]
        if pretrained_model not in [e.name for e in ImageClassificationModels]:
            raise RuntimeError("Unsupported image classification model {}".format(pretrained_model))

        return ImageClassificationModels[pretrained_model]

    def get_model(self, pretrained):
        """ Build the torchvision EfficientNet for the configured model, with a new head and frozen layers.

        :param pretrained: if True, torchvision pretrained weights are loaded and the last ``retrainedLayers``
                           feature blocks are unfrozen (training). If False, no feature block is unfrozen (scoring).
        :return: the torch model. Its classifier head always holds freshly (randomly) initialized weights.
        :raises RuntimeError: if the configured model is not supported
        :raises ValueError: if more layers are requested for retraining than the model has trainable feature blocks
        """
        # Local import is made to ensure OD and image classification can have different version of torchvision library
        from torchvision.models import efficientnet_b0, efficientnet_b4, efficientnet_b7

        num_classes = len(self.target_remapping)
        # Resolve (and validate) the model name once rather than re-evaluating the property for every comparison
        model_name = self.model_name

        logger.info("Getting model {}".format(model_name))
        model_funcs = {
            ImageClassificationModels.EFFICIENTNET_B0: efficientnet_b0,
            ImageClassificationModels.EFFICIENTNET_B4: efficientnet_b4,
            ImageClassificationModels.EFFICIENTNET_B7: efficientnet_b7,
        }
        model_func = model_funcs.get(model_name)
        if model_func is None:
            # Reachable if ImageClassificationModels declares a member that has no builder here
            raise RuntimeError("Unknown pretrained model for image classification: '{}'"
                               .format(self.modeling_params["pretrainedModel"]))

        model = model_func(pretrained=pretrained)

        # Freeze all the layers' params by default
        for param in model.parameters():
            param.requires_grad = False

        # Replace the last fully connected layer (classifier), because it probably does not have the same number of
        # classes as the pre-trained model (which has the 1000 ImageNet classes).
        # New parameters have `.requires_grad=True` by default and, as all the other layers are frozen, only the new
        # layer's parameters (+ those explicitly unfrozen below) will be updated during training.

        # Note: we use Pytorch default way of initializing weights & bias as it seems that EfficientNet way of
        # initializing degrades perf highly as of today.
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        # For scoring (not pretrained) we don't want to unfreeze any layer
        layers_to_retrain = self.get_number_of_retrained_layers() if pretrained else 0

        trainable_layers = len(model.features) - 1  # 1st stem block is not trainable
        if layers_to_retrain > trainable_layers:
            raise ValueError("too many layers to retrain, only {} existing, got {}".format(trainable_layers,
                                                                                           layers_to_retrain))
        if layers_to_retrain >= 1:
            for param in model.features[-layers_to_retrain:].parameters():
                param.requires_grad = True  # unfreeze all the parameters of these pretrained layers

        # Batchnorm layers should always stay frozen
        # (See Keras recommendations for fine-tuning: https://keras.io/guides/transfer_learning/)
        # NOTE(review): requires_grad=False only freezes the BatchNorm affine weight/bias. The running mean/var are
        # buffers that PyTorch still updates whenever the layer is in train mode. Confirm the training loop keeps
        # these layers in eval mode if fully frozen BatchNorm is intended.
        for layer in model.modules():
            if isinstance(layer, nn.modules.batchnorm.BatchNorm2d):
                for param in layer.parameters():
                    param.requires_grad = False

        # Note: at this point the model contains pretrained weights or random ones (depending on the pretrained
        # boolean) from the initial training (from torch), except for the head, which always contains randomly
        # initialized weights.
        return model

    @property
    def model_input_size(self):
        """ Return the size of the input image for the current model.

        Note: most image classification models take 224x224 images, but EfficientNet B1-B7 are exceptions.
        """
        model_name = self.model_name
        if model_name == ImageClassificationModels.EFFICIENTNET_B4:
            return {"width": 380, "height": 380}
        if model_name == ImageClassificationModels.EFFICIENTNET_B7:
            return {"width": 600, "height": 600}
        return {"width": 224, "height": 224}

    @property
    def model_normalisation(self):
        """ Return the pixel normalisation (per-channel mean/std) of the training set used during pretraining.

        Most well-known image classification models are pretrained on ImageNet, whose statistics are returned here.
        There could be exceptions.
        """
        return {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}

    def build_image_transforms_lists(self, augmentation_params=None):
        """ Build the image transforms lists for this model's input size and normalisation.

        :param augmentation_params: optional data augmentation parameters, forwarded as-is
        """
        return build_image_classification_transforms_lists(self.model_input_size, self.model_normalisation,
                                                           augmentation_params)

    def get_resolved_params(self):
        """ Return the resolved modeling params: the number of retrained layers. """
        return {"retrainedLayers": self.get_number_of_retrained_layers()}

    def get_last_layer_module(self, nn_model):
        """ Return the last feature block of ``nn_model`` (the backbone output, before the classifier head). """
        return nn_model.features[-1]

    def get_classifier_parameters(self, nn_model):
        """ Return the first parameter of the classifier head (``classifier[1]``) as a CPU numpy array.

        For the nn.Linear head built in get_model, this first parameter is its weight matrix.
        """
        return list(nn_model.classifier[1].parameters())[0].detach().cpu().numpy()


class DummyImageClassificationDeepHubModel(ImageClassificationDeepHubModel):
    """ Placeholder image classification model that never builds a torch network.

    Every model-related hook is overridden to return a neutral value: None, or 0 retrained layers.
    """
    DUMMY = True

    @property
    def model_name(self):
        """ Fixed display name.

        Unlike the parent class, this is a plain string, not an ImageClassificationModels member.
        """
        dummy_name = "Dummy Classification Model"
        return dummy_name

    def get_model(self, pretrained):
        """ Log the request and return None: no model is built, whatever the value of ``pretrained``. """
        logger.info("Getting dummy model")
        return None

    def get_number_of_retrained_layers(self):
        """ The dummy model never retrains any layer. """
        retrained = 0
        return retrained

    def get_last_layer_module(self, nn_model):
        """ No last layer exists for the dummy model; ``nn_model`` is ignored. """
        return None

    def get_classifier_parameters(self, nn_model):
        """ No classifier parameters exist for the dummy model; ``nn_model`` is ignored. """
        return None