import logging

import numpy as np
import pandas as pd

from dataiku.base.utils import safe_unicode_str
from dataiku.doctor.deephub.deephub_evaluation import ImageClassificationPerformanceResults
from dataiku.doctor.prediction.classification_scoring import MulticlassModelScorer
from dataiku.doctor.prediction.classification_scoring import format_all_conditional_proba_densities
from dataiku.doctor.prediction.classification_scoring import format_all_proba_densities
from dataiku.doctor.utils import remove_all_nan

logger = logging.getLogger(__name__)


class ImageClassificationPerformanceComputer(object):
    """
    Turn the raw outputs of an Image Classification model (probabilities & predicted category ids)
    into a per-image prediction dataframe and a set of multiclass performance metrics.
    """

    def __init__(self, target_remapping, origin_index, images_ids, targets, probas, predictions):
        """
        Store the model outputs and the ground truth needed to compute multiclass performances.

        :param target_remapping: Mapping between categories labels & ids. Extracted from the guess sample
        :type target_remapping: TargetRemapping

        :param origin_index: images indexes in the original dataframe (including invalid rows that were possibly
                             dropped by the dataset filter)
        :type origin_index: pd.Index

        :param images_ids: images indexes in the original dataframe (only valid rows remaining),
                           array of shape (nb_valid_images)
        :type images_ids: np.ndarray

        :param targets: ground truth category for each image, array of shape (nb_valid_images).
                        Wrapped into a pd.Series when stored on the instance.
        :type targets: np.ndarray

        :param probas: predicted probabilities of each target_remapping's category,
                       array of shape (nb_valid_images, nb_categories)
        :type probas: np.ndarray

        :param predictions: predicted category_id for each image, array of shape (nb_valid_images)
        :type predictions: np.ndarray
        """
        # Category mapping & index of the full (unfiltered) input:
        self.target_remapping = target_remapping
        self.origin_index = origin_index

        # Per-valid-image data:
        self.images_ids = images_ids
        self.probas = probas
        self.predictions = predictions
        # kept as a Series so that value_counts() is available in compute_performance
        self.targets = pd.Series(targets)

    def build_predicted_df(self):
        """
        Build the per-image prediction dataframe.

        The result has one "prediction" column (category obtained through target_remapping.get_category)
        followed by one "proba_<category>" column per category. It is indexed by image id and then
        re-indexed on origin_index, so labels of origin_index that have no matching image id end up
        as rows filled with NaN.

        :return: the predictions & probabilities, aligned on origin_index
        :rtype: pd.DataFrame
        """
        proba_columns = []
        for category in self.target_remapping.list_categories():
            proba_columns.append("proba_{}".format(safe_unicode_str(category)))

        # Two separate frames are concatenated (rather than stacking ndarrays) so that each
        # column keeps its own dtype:
        ids_and_predictions_df = pd.DataFrame({"image_id": self.images_ids, "prediction": self.predictions})
        probas_df = pd.DataFrame(data=self.probas, columns=proba_columns)
        predicted_df = pd.concat([ids_and_predictions_df, probas_df], axis=1)

        # Swap the category indices for the category names:
        predicted_df["prediction"] = predicted_df["prediction"].apply(self.target_remapping.get_category)

        return predicted_df.set_index("image_id").reindex(self.origin_index)

    def compute_performance(self):
        """
        Compute the multiclass metrics, curves, densities and confusion matrix of the model.

        When there is no target at all, an empty result built from the list of categories is returned.

        :return: the performance metrics together with the prediction dataframe, the raw predictions
                 and the prediction statistics
        :rtype: ImageClassificationPerformanceResults
        """
        remapping = self.target_remapping

        # Guard clause: nothing to score
        if not len(self.targets):
            logger.info("No object in data, cannot compute performance")
            return ImageClassificationPerformanceResults.empty(remapping.list_categories())

        target_map = remapping.get_category_to_index_map()

        performance_dict = MulticlassModelScorer.compute_multiclass_metrics(
            self.targets, self.predictions, target_map, self.probas
        )

        # One-vs-all ROC:
        roc_auc, roc_curves = MulticlassModelScorer.get_roc_metrics_and_curves(
            self.targets, self.probas, target_map
        )
        performance_dict["oneVsAllRocAUC"] = roc_auc
        performance_dict["oneVsAllRocCurves"] = roc_curves

        # One-vs-all precision/recall:
        average_precision, pr_curves = MulticlassModelScorer.get_pr_metrics_and_curves(
            self.targets, self.probas, target_map
        )
        performance_dict["oneVsAllAveragePrecision"] = average_precision
        performance_dict["oneVsAllPrCurves"] = pr_curves

        performance_dict["densityData"] = format_all_conditional_proba_densities(
            remapping.list_categories(), target_map, self.probas, self.targets
        )

        performance_dict["confusion_matrix"] = MulticlassModelScorer.get_multiclass_confusion_matrix(
            self.targets, self.predictions, remapping.get_index_to_category_map(), None
        )

        performance_dict = remove_all_nan(performance_dict)

        probability_densities = format_all_proba_densities(self.probas, target_map=target_map)
        # NOTE(review): the "predictions" statistic is computed from the ground truth targets, not from
        # self.predictions - confirm against the consumer of these statistics that this is intended.
        targets_counts = self.targets.value_counts().to_dict()
        prediction_statistics = {
            "probabilityDensities": probability_densities,
            "predictions": targets_counts
        }

        predicted_df = self.build_predicted_df()
        return ImageClassificationPerformanceResults(performance_dict, predicted_df, self.predictions,
                                                     prediction_statistics)
