import logging
from abc import abstractmethod, ABCMeta

import pandas as pd
import six
import torch

from dataiku.doctor.deephub.deephub_explaining import NoOpScoreExplainer
from dataiku.doctor.deephub.deephub_logger import DeephubLogger
from dataiku.doctor.deephub.deephub_model import DeepHubModelHandler
from dataiku.doctor.deephub.deephub_model import DeepHubModel
from dataiku.doctor.deephub.utils.deephub_registry import DeepHubRegistry

logger = logging.getLogger(__name__)


class DeepHubScoringHandler(object):
    """
    Responsible for the scoring of a Deephub model: receives a dataframe to score and returns the prediction
    dataframe.

    Note, as of now this handler only works for CPU and 1 GPU, we might add distributed context in the future.
    """

    def __init__(self, params):
        """
        Build the Deephub model, load it for scoring, and select the scoring engine matching the prediction type.

        :type params: dataiku.doctor.deephub.deephub_params.DeepHubScoringParams
        """
        self.model_folder_context = params.model_folder_context
        prediction_type = params.core_params["prediction_type"]
        self.base_model = DeepHubModel.build(prediction_type, params.target_remapping, params.modeling_params)
        self.core_params = params.core_params
        self.batch_size = params.batch_size

        self.deephub_logger = DeephubLogger("Scoring")
        # Holds the loaded model (`deephub_model`, `nn_model`) and the `device` it lives on, as used by `score`.
        self.model_handler = DeepHubModelHandler.build_for_scoring(self.model_folder_context, self.base_model)

        self.engine = DeepHubScoringEngine.build(prediction_type,
                                                 params.modeling_params,
                                                 params.file_path_column,
                                                 params.target_remapping,
                                                 params.scoring_params)

    def score(self, df, files_reader, with_explanations=False, n_explanations=1):
        """
        Score `df` with the loaded model and return the predictions produced by the scoring engine.

        :param df: dataframe to score
        :type df: pd.DataFrame
        :param files_reader: passed to the engine to build the torch dataset from `df`
        :type files_reader: dataiku.doctor.deephub.utils.file_utils.FilesReader
        :param with_explanations: forwarded to the engine's score explainer
        :type with_explanations: bool
        :param n_explanations: forwarded to the engine's score explainer
        :type n_explanations: int
        :return: the prediction dataframe returned by the engine
        :rtype: pd.DataFrame
        :raises Exception: if the dataset built from `df` is empty (all rows were dropped)
        """
        dataset = self.engine.build_dataset(df, files_reader, self.base_model)

        if len(dataset) == 0:
            raise Exception("All rows of dataset are dropped, cannot score it")

        # Keep the input order (shuffle=False) so predictions line up with the rows of `df`.
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                                  shuffle=False, collate_fn=dataset.data_loader_collate_fn)
        score_explainer = self.engine.build_score_explainer(self.model_handler.deephub_model,
                                                            self.model_handler.nn_model,
                                                            with_explanations, n_explanations)
        return self.engine.score(self.model_handler.nn_model, self.model_handler.device, data_loader,
                                 self.deephub_logger, score_explainer)

    def serialize_prediction_df(self, prediction_df):
        """
        Serialize inplace, if needed, the `prediction_df`, for instance when the prediction is a complex object
        (dictionary) that needs to be converted to a json prior to being saved. Delegates to the scoring engine.

        :param prediction_df: pd.DataFrame
        :return: None
        """
        self.engine.serialize_prediction_df(prediction_df)


@six.add_metaclass(ABCMeta)
class DeepHubScoringEngine(object):
    """
    Abstract base of the Deephub scoring engines.

    Concrete engines must define a ``TYPE`` class attribute (and may override ``DUMMY``), register themselves
    with :meth:`define`, and are then instantiated through :meth:`build`.
    """
    # Registry of concrete engines, filled by `define` and queried by `build`.
    REGISTRY = DeepHubRegistry()
    # Second registration key next to TYPE; `build` matches it against modeling_params["dummy"].
    DUMMY = False

    def __init__(self, file_path_col, target_remapping, scoring_params):
        """
        :param file_path_col: name of the column holding the file paths
        :param target_remapping: target remapping of the model
        :param scoring_params: engine-specific scoring params
        """
        self.file_path_col = file_path_col
        self.scoring_params = scoring_params
        self.target_remapping = target_remapping

    @staticmethod
    def define(scoring_class):
        """
        Register `scoring_class` under its (``TYPE``, ``DUMMY``) class attributes so that `build` can find it.
        """
        DeepHubScoringEngine.REGISTRY.register(scoring_class.TYPE, scoring_class.DUMMY, scoring_class)

    @staticmethod
    def build(prediction_type, modeling_params, file_path_col, target_remapping, scoring_params):
        """
        Instantiate the scoring engine registered for `prediction_type`.

        :param prediction_type: ``TYPE`` the engine was registered under
        :param modeling_params: modeling params; only the optional ``"dummy"`` key (default False) is read here
        :type modeling_params: dict
        :param file_path_col: forwarded to the engine constructor
        :param target_remapping: forwarded to the engine constructor
        :param scoring_params: forwarded to the engine constructor
        :rtype: DeepHubScoringEngine
        :raises ValueError: if no engine is registered for (`prediction_type`, dummy)
        """
        dummy = modeling_params.get("dummy", False)
        try:
            scoring_class = DeepHubScoringEngine.REGISTRY.get(prediction_type, dummy)
        except KeyError:
            # Turn the registry miss into a ValueError naming the (prediction_type, dummy) key that was not found.
            raise ValueError("Unknown scoring engine: {} (dummy={})".format(prediction_type, dummy))
        return scoring_class(file_path_col, target_remapping, scoring_params)

    @abstractmethod
    def build_dataset(self, df, files_reader, model):
        """
        Build the torch dataset to score from `df`.

        :type df: pd.DataFrame to be scored
        :type files_reader: dataiku.doctor.deephub.utils.file_utils.FilesReader
        :type model: DeepHubModel
        :rtype: dataiku.doctor.deephub.deephub_torch_datasets.DeepHubDataset
        """
        raise NotImplementedError()

    def build_score_explainer(self, deephub_model, nn_model, with_explanations, n_explanations):
        """
        Build the explainer used while scoring. The default implementation returns a `NoOpScoreExplainer`;
        engines supporting explanations override it.

        :type deephub_model: DeepHubModel (contains all the specificities of the model like its architecture and utilities to retrieve parts of the model)
        :type nn_model: torch.nn.Module (contains the current model, with the current weights loaded)
        :type with_explanations: bool
        :type n_explanations: int
        :rtype: dataiku.doctor.deephub.deephub_explaining.DeepHubScoreExplainer
        """
        return NoOpScoreExplainer(deephub_model, nn_model, with_explanations, n_explanations)

    @abstractmethod
    def score(self, model, device, data_loader, deephub_logger, score_explainer):
        """
        :param model: torch model
        :param device: device on which model is hosted
        :type data_loader: torch.utils.data.DataLoader
        :type deephub_logger: dataiku.doctor.deephub.deephub_logger.DeephubLogger
        :type score_explainer: dataiku.doctor.deephub.deephub_explaining.DeepHubScoreExplainer
        :return: scored data as a dataframe
        :rtype: pd.DataFrame
        """
        raise NotImplementedError()

    @staticmethod
    def serialize_prediction_df(prediction_df):
        """
        Serialize inplace, if needed, the `prediction_df`, for instance when the prediction is a complex object
        (dictionary) that needs to be converted to a json prior to being saved.

        No-op by default.

        :param prediction_df: pd.DataFrame
        :return: None
        """
        pass
