import logging
from copy import deepcopy

import numpy as np
import pandas as pd
from scipy.stats import entropy


from dataiku.core import doctor_constants
from dataiku.base.utils import safe_unicode_str
from dataiku.core import dkujson
from dataiku.doctor.prediction.background_rows_handler import BackgroundRowsHandler
from dataiku.doctor.prediction.column_importance_handler import ColumnImportanceHandler
from dataiku.doctor.posttraining.features_distribution import FeaturesDistributionHandler
from dataiku.doctor.posttraining.features_distribution import NumericFeatureDistributionComputer
from dataiku.doctor.prediction.common import PredictionAlgorithmSparseSupport
from dataiku.doctor.prediction.explanations.engine import PeakCellCountBatchingStrategy
from dataiku.doctor.prediction.explanations.engine import BatchingExplainingEngine
from dataiku.doctor.prediction.explanations.engine import FixedSizeBatchingStrategy
from dataiku.doctor.prediction.explanations.ice import ICEExplainingEngine
from dataiku.doctor.prediction.explanations.shapley import DefaultShapleyFrankensteinScoresBuilder
from dataiku.doctor.prediction.explanations.shapley import PreprocessShapleyFrankensteinScoresBuilder
from dataiku.doctor.prediction.explanations.shapley import ShapleyExplainingEngine
from dataiku.doctor.prediction.explanations.engine import ScoreComputer
from dataiku.doctor.prediction.explanations.shapley import ShapleyExplanationsExtractor
from dataiku.doctor.prediction.explanations.shapley import ShapleyIndividualExplanationsExtractor
from dataiku.doctor.prediction.explanations.shapley import ShapleyGlobalExplanationsExtractor
from dataiku.doctor.prediction.explanations.shapley import MulticlassShapleyGlobalExplanationsExtractor
from dataiku.doctor.prediction.explanations.score_to_explain import ScoreToExplain
from dataiku.doctor.prediction.explanations.score_to_explain import MulticlassScoreToExplain
from dataiku.doctor.prediction.explanations.score_to_explain import OneDimensionScoreToExplain
from dataiku.doctor.prediction.scorable_model import ScorableModel
from dataiku.doctor.utils import normalize_dataframe, dku_nonan
from dataiku.doctor.utils.metrics import log_odds
from dataiku.doctor.utils.split import df_from_split_desc_no_normalization
from dataiku.doctor.utils.split import input_columns

RANDOM_SEED = 1337
DEFAULT_SHAPLEY_BACKGROUND_SIZE = 100
DEFAULT_SUB_CHUNK_SIZE = 10000
DEFAULT_NB_EXPLANATIONS = 3
MIN_NB_DISTINCT_FOR_QUANTILES = 10
MIN_NB_MODALITIES_TO_KEEP = 10
MAX_NB_MODALITIES_TO_KEEP = 25

# Choose the number of modalities to keep depending on whether the
# sum of bars of most frequent modalities is greater than this number.
HISTOGRAM_SIGNIFICANCE_THRESHOLD = .9

logger = logging.getLogger(__name__)


class ExplanationMethod:
    """Identifiers of the explanation methods accepted by `IndividualExplainer.explain`.

    The members are plain strings, compared with `==` against the `method` argument.
    """
    # Selects the ICE explaining engine
    ICE = "ICE"
    # Selects the Shapley explaining engine
    SHAPLEY = "SHAPLEY"


class IndividualExplainer:
    """ Computes prediction & per-row explanations of those predictions
    For that there are two methods:
    - ICE: it takes the difference between the prediction of one given example x and (an approximation of)
    the expectation of the predictions of the examples x created by replacing one specific feature in x by all
    its potential values, against the marginal distribution of this feature
    - Shapley values: it computes an estimation of the average impact on the prediction of switching a feature value
    from the value it takes in a random sample (background rows) to the value it takes in the sample to be explained
    while a random number of feature values have already been switched in the same way.

    Those methods use specific names like:
    - frankenstein: modified version of the original observations.
    - scores: the prediction in regression and the log-odds of the probability(ies) in classification
    - modalities: unique values of a column for a categorical/text feature or bins for numerical one
    - background rows: sample of the train set used to get some feature values
    to tweak the observations and build the frankenstein
    """

    def __init__(self, predictor, model_folder_context, preprocessing_folder_context, train_split_desc, split_folder_context, per_feature, is_ensemble,
                 prediction_type, sample_weight_col=None):
        """
        Store the model description and create the handlers backing the explanations.
        Nothing is loaded or computed here: see `make_ready`.

        :param predictor: predictor of the model to explain
        :param model_folder_context: folder context handed to the column importance, background rows and
                                     features distribution handlers
        :param preprocessing_folder_context: folder context handed to the column importance handler
        :param train_split_desc: description of the train split, must hold a "params" entry
        :type train_split_desc: dict
        :param split_folder_context: folder context used when reading the train set
        :param per_feature: per-feature settings, indexed by feature name
        :type per_feature: dict
        :param is_ensemble: whether the model is an ensemble model
        :type is_ensemble: bool
        :param prediction_type: prediction type of the model (e.g. "MULTICLASS")
        :type prediction_type: str
        :param sample_weight_col: [Optional] name of the sample weight column
        :type sample_weight_col: str or None
        """
        # Model & task description
        self._predictor = predictor
        self._prediction_type = prediction_type
        self._is_ensemble = is_ensemble
        self._per_feature = per_feature
        self.sample_weight_col = sample_weight_col

        # Train split description
        self._train_split_desc = train_split_desc
        self._split_folder_context = split_folder_context
        self._is_kfolding = train_split_desc["params"].get("kfold", False)

        # Train data, only read when needed
        self.trainset = None
        self.not_normalized_trainset = None
        self.trainset_prediction_information = None

        # Backing data of the explanations, filled by `make_ready`
        self.background_rows = None
        self.features_distribution = None
        self.column_importance = None
        self.column_importance_compute_has_failed = False

        # Handlers in charge of loading / computing / persisting the backing data
        self.column_importance_handler = ColumnImportanceHandler(model_folder_context, preprocessing_folder_context)
        self.background_rows_handler = BackgroundRowsHandler(model_folder_context,
                                                             self._train_split_desc,
                                                             self._prediction_type,
                                                             self._per_feature)
        self.distribution_computer = FeaturesDistributionHandler(model_folder_context)

        self._input_columns = input_columns(per_feature)

    @staticmethod
    def _build_individual_explanations_score_computer_composer(class_idx):
        """
        Build a composer that wraps a score computer into the one used for individual explanations.

        Only the multiclass case is affected: any score that is not a `MulticlassScoreToExplain` is returned
        exactly as outputted by the wrapped `score_computer`.

        For a multiclass score:
         * if `class_idx` is defined, only the score of that class is kept (no realignment needed)
         * else, if `other_score_to_align_with` is given, the score is realigned with it
         * else, the score is returned untouched

        :param str or None class_idx: [Optional] When computing individual explanations for multiclass for a specific
                                      class, this corresponds to the index of said class
        :rtype: ScoreComputer
        """
        def score_computer_composer(score_computer):
            def _individual_explanations_score_computer(data, other_score_to_align_with=None,
                                                        matching_indices_in_other=None):
                computed_score = score_computer(data)

                # Regression / binary: nothing to adapt
                if not isinstance(computed_score, MulticlassScoreToExplain):
                    return computed_score

                # One-vs-all on a given class: only that column of the per-class scores matters
                if class_idx is not None:
                    return OneDimensionScoreToExplain(computed_score.per_class_score[:, class_idx])

                # No reference to realign with
                if other_score_to_align_with is None:
                    return computed_score

                assert isinstance(other_score_to_align_with, MulticlassScoreToExplain)
                return MulticlassScoreToExplain.build_from_other_score_to_explain(computed_score.per_class_score,
                                                                                  other_score_to_align_with,
                                                                                  matching_indices_in_other)
            return _individual_explanations_score_computer
        return score_computer_composer

    def _build_simple_individual_explanations_score_computer(self, class_idx):
        return self._build_individual_explanations_score_computer_composer(class_idx)(
            lambda data, _1=None, _2=None: self._get_prediction_information(data).score_to_explain)

    def _supports_shapley_from_preprocess(self, features_to_column_indices):
        """
        Tell whether the Shapley values can be computed from preprocessed data.
        The first failing requirement is logged and the remaining ones are not evaluated.

        :param features_to_column_indices: mapping from features to column indices, or None when unavailable
        :rtype: bool
        """
        unsupported_reason = None
        if self._is_ensemble:
            unsupported_reason = "Ensemble model does not support Shapley from preprocess"
        elif features_to_column_indices is None:
            unsupported_reason = ("Features to column indices is not supported. "
                                  "Shapley from preprocessed is not supported")
        elif not (hasattr(self._predictor, "_model") and isinstance(self._predictor._model, ScorableModel)):
            unsupported_reason = "Model is not of the right format,Shapley from preprocessed is not supported"
        elif self._predictor._model.requires_unprocessed_df_for_prediction():
            unsupported_reason = "Model requires input df for prediction, Shapley from preprocessed is not supported"
        elif self._may_use_sparse_matrix():
            unsupported_reason = "Model may use sparse matrices, Shapley from preprocessed is not supported"

        if unsupported_reason is None:
            return True
        logger.info(unsupported_reason)
        return False

    def _may_use_sparse_matrix(self):
        """
        Tell whether the algorithm both supports CSR matrices and should be allowed to use sparse matrices,
        according to the modeling params of the predictor.
        """
        modeling_params = self._predictor.params.modeling_params
        algorithm_sparse_support = PredictionAlgorithmSparseSupport(modeling_params)
        return (algorithm_sparse_support.supports_csr()
                and algorithm_sparse_support.should_allow_sparse_matrices())

    def _get_shapley_explaining_engine(self, shapley_background_size, columns_to_explain,
                                       explanations_extractor, score_computer_composer=None):
        """
        Build the Shapley engine, working on preprocessed data when the model supports it and on
        raw (prior preprocessing) data otherwise.

        :type shapley_background_size: int
        :type columns_to_explain: list[str]
        :type explanations_extractor: ShapleyExplanationsExtractor
        :param score_computer_composer: [Optional] wrapper on top of score computer that outputs another score computer
        :rtype: ShapleyExplainingEngine
        """
        background = self._get_background_rows(shapley_background_size)
        column_indices_by_feature = self._get_features_to_column_indices_mapping_or_none(background.columns)
        from_preprocessed = self._supports_shapley_from_preprocess(column_indices_by_feature)

        # By default, the raw score computers below do not realign anything: the two extra arguments are ignored
        if from_preprocessed:
            logger.info("Using Shapley from preprocessed engine, all mentions of cells *include* the preprocessing")

            def raw_score_computer(data, other_score_to_align_with=None, matching_indices_in_other=None):
                return self._get_score_from_preprocessed(data)
        else:
            logger.info("Using default Shapley Engine, all mentions of cells are *prior* preprocessing")

            def raw_score_computer(data, other_score_to_align_with=None, matching_indices_in_other=None):
                return self._get_prediction_information(data).score_to_explain

        if score_computer_composer is None:
            score_computer = raw_score_computer
        else:
            score_computer = score_computer_composer(raw_score_computer)

        if from_preprocessed:
            def preprocess(df):
                return self._predictor.preprocessing.preprocess(df)[0]

            scores_builder = PreprocessShapleyFrankensteinScoresBuilder(
                background, columns_to_explain, preprocess, score_computer, column_indices_by_feature)
        else:
            scores_builder = DefaultShapleyFrankensteinScoresBuilder(background, columns_to_explain, score_computer)

        return ShapleyExplainingEngine(scores_builder, explanations_extractor)

    def _get_individual_shapley_explaining_engine(self, columns_to_explain, shapley_background_size, class_idx):
        """
        Build the Shapley engine used for per-row explanations.

        :type columns_to_explain: list[str]
        :type shapley_background_size: int
        :param class_idx: [Optional] index of the class to explain, in multiclass
        :rtype: ShapleyExplainingEngine
        """
        composer = self._build_individual_explanations_score_computer_composer(class_idx)
        extractor = ShapleyIndividualExplanationsExtractor()
        return self._get_shapley_explaining_engine(shapley_background_size, columns_to_explain,
                                                   extractor, composer)

    def _get_global_shapley_explaining_engine(self, columns_to_explain, shapley_background_size):
        """
        Build the Shapley engine used for global explanations. Multiclass models get a dedicated
        extractor, built from the classes of the predictor.

        :type columns_to_explain: list[str]
        :type shapley_background_size: int
        :rtype: ShapleyExplainingEngine
        """
        if self._prediction_type == "MULTICLASS":
            extractor = MulticlassShapleyGlobalExplanationsExtractor(self._predictor.classes)
        else:
            extractor = ShapleyGlobalExplanationsExtractor()
        return self._get_shapley_explaining_engine(shapley_background_size, columns_to_explain, extractor)

    def _get_ice_explaining_engine(self, columns_to_explain, observations_df, class_idx):
        """
        Build the ICE engine from the loaded features distribution and the dtypes of the rows to explain.

        :type columns_to_explain: list[str]
        :type observations_df: pd.DataFrame
        :param class_idx: [Optional] index of the class to explain, in multiclass
        :rtype: ICEExplainingEngine
        """
        observations_dtypes = observations_df.dtypes
        score_computer = self._build_simple_individual_explanations_score_computer(class_idx)
        return ICEExplainingEngine(self.features_distribution, columns_to_explain, observations_dtypes,
                                   score_computer)

    def _get_background_rows(self, max_background_size):
        """
        :type max_background_size: int
        :rtype: pd.DataFrame
        """
        nb_rows = self.background_rows.shape[0]
        if self.background_rows.shape[0] < max_background_size:
            background_size = nb_rows
            logger.info("Not enough rows, lowering Monte Carlo steps to {}".format(nb_rows))
        else:
            background_size = max_background_size
        return self.background_rows.head(background_size)

    def explain(self, observations_df, nb_explanations, method, for_class=None, debug_mode=False, progress=None,
                sub_chunk_size=None, shapley_background_size=DEFAULT_SHAPLEY_BACKGROUND_SIZE):
        """ Compute the explanations for each row in observations.
            Note: this method *expects* that no rows from `observations_df` will be dropped during preprocessing
            :param observations_df: rows to explain
            :type observations_df: pd.DataFrame
            :param nb_explanations: Number of explanations the user wants,
                more will be computed to be sure we don't miss too much
            :type nb_explanations: int
            :param method: Method to compute the explanation
            :type method: ExplanationMethod.ICE or ExplanationMethod.SHAPLEY
            :param for_class: in multiclass, a class can be provided to compute in one vs all mode
            :type for_class: str
            :param debug_mode: If False, silence pre-processing pipeline logs
            :type debug_mode: bool
            :param progress: Object to refresh the progress bar in the UI. Should be None if explanations are not done through the UI
            :type progress: None or PercentageProgress
            :param sub_chunk_size: size of the chunks into which the dataset to explain is divided (to prevent OOM errors)
            :type sub_chunk_size: int
            :param shapley_background_size: Size of the background to use with Shapley method
            :type shapley_background_size: int

            :return: the explanations with the same shape as the observations
            :rtype pd.DataFrame
            :raises Exception: if the explainer is not ready (see `make_ready`)
            :raises ValueError: if `method` is not a known explanation method
        """
        if not self.is_ready():
            raise Exception("Explainer isn't ready")

        # Nothing to explain: answer with an empty frame on the input columns, keeping the given index
        if observations_df.shape[0] == 0:
            logger.info("Explaining empty dataframe, returning empty explanations")
            return pd.DataFrame(columns=self._input_columns, index=observations_df.index)

        observations_df = observations_df[self._input_columns]
        columns_to_compute = self._get_most_important_columns(nb_explanations)

        if for_class is None:
            class_idx = None
        else:
            class_idx = self._predictor.classes.index(for_class)

        batching_strategy = FixedSizeBatchingStrategy(sub_chunk_size)

        if method == ExplanationMethod.SHAPLEY:
            method_engine = self._get_individual_shapley_explaining_engine(columns_to_compute,
                                                                           shapley_background_size, class_idx)
        elif method == ExplanationMethod.ICE:
            method_engine = self._get_ice_explaining_engine(columns_to_compute, observations_df, class_idx)
        else:
            raise ValueError("Unknown method to explain prediction '{}'".format(method))

        batching_engine = BatchingExplainingEngine(method_engine, batching_strategy, progress, debug_mode)
        return batching_engine.explain(observations_df).explanations_df

    def explain_global(self, observations_df, nb_explanations, progress=None, shapley_background_size=DEFAULT_SHAPLEY_BACKGROUND_SIZE):
        """ Compute the explanations for each row in observations & aggregate them to get global explanations.
            Note: this method *expects* that no rows from `observations_df` will be dropped during preprocessing
            :param observations_df: rows to explain
            :type observations_df: pd.DataFrame
            :param nb_explanations: Number of explanations the user wants, more will be computed to be sure we don't miss too much
            :type nb_explanations: int
            :param progress: Object to refresh the progress bar in the UI. Should be None if explanations are not done through the UI
            :type progress: None or PercentageProgress
            :param shapley_background_size: Size of the background to use with Shapley method
            :type shapley_background_size: int

            :return: the output of `to_dicts()` on the explaining result
                     (described by the original documentation as absolute explanations & explanations - TODO confirm)
            :raises Exception: if the explainer is not ready (see `make_ready`)
        """
        if not self.is_ready():
            raise Exception("Explainer isn't ready")

        # Restrict to the input columns once: the resulting frame is the one handed to the engine
        observations_df = observations_df[self._input_columns]
        columns_to_compute = self._get_most_important_columns(nb_explanations)
        explaining_engine = self._get_global_shapley_explaining_engine(columns_to_compute, shapley_background_size)
        # Batches are sized from the engine's estimated peak number of cells generated per explained row
        batching_strategy = PeakCellCountBatchingStrategy(
            15, 300000, explaining_engine.get_estimated_peak_number_cells_generated_per_row_explained())
        # Last argument is the debug mode, always off for global explanations
        engine = BatchingExplainingEngine(explaining_engine, batching_strategy, progress, False)
        explaining_result = engine.explain(observations_df)
        return explaining_result.to_dicts()

    def compute_stats_on_shapley_and_feature_values(self, feature_name, raw_importances_df):
        feat_values = raw_importances_df[feature_name]
        shap_values = raw_importances_df['shapley_' + feature_name]
        pos = shap_values >= 0
        neg = shap_values < 0
        if self._per_feature[feature_name]["type"] == "NUMERIC":
            # numeric
            feat_values = pd.to_numeric(feat_values)
            feat_values_pos = feat_values[pos]
            feat_values_neg = feat_values[neg]
            mom_pos = feat_values_pos.mean()
            mom_neg = feat_values_neg.mean()
            corr = pd.concat([shap_values, feat_values], axis=1).corr(method='kendall').loc[feature_name, 'shapley_' + feature_name]
            info_gain = np.nan
        else:  # ["CATEGORY", "TEXT", "VECTOR"]:
            pos_mode = feat_values[pos].mode()
            neg_mode = feat_values[neg].mode()
            mom_pos = pos_mode.iloc[0] if len(pos_mode) else np.nan
            mom_neg = neg_mode.iloc[0] if len(neg_mode) else np.nan
            corr = np.nan
            # information gain
            # children nodes are split by category
            # classes are positive/negative shap values
            ent_parent = entropy((shap_values > 0).value_counts(normalize=True, sort=False))
            ent_children = 0
            for k in feat_values.unique():
                ent_children += entropy((shap_values[feat_values == k] > 0).value_counts(normalize=True, sort=False))
            ent_children /= len(feat_values.unique())

            info_gain = ent_parent - ent_children
        stats = {
            "feature_name": feature_name,
            "sh_mean": shap_values.abs().mean(),
            "sh_mean_pos": shap_values[pos].mean(),
            "sh_mean_neg": shap_values[neg].mean(),
            "mom_pos": mom_pos,
            "mom_neg": mom_neg,
            "corr": corr,
            "info_gain": info_gain
        }

        return stats

    def get_model_per_class_facts(self, raw_explanations, observations, corr_thr=0.7, info_gain_thr=0.5):
        per_class_facts = {}

        for klass, features in raw_explanations.items():
            # We iterate over all classes. We try to get 1 fact / (feature, class)
            explanations_and_observations = deepcopy(raw_explanations[klass])
            # Rename the shapley values dict to properly construct the dataframe
            for feature in features:
                explanations_and_observations['shapley_{}'.format(feature)] = explanations_and_observations.pop(feature)
                explanations_and_observations[feature] = observations[feature]
            df = pd.DataFrame.from_dict(explanations_and_observations)
            per_class_facts[klass] = {}
            list_stats = [self.compute_stats_on_shapley_and_feature_values(feature, df) for feature, feature_shapley_values in features.items()]
            stats_df = pd.DataFrame.from_dict(list_stats).sort_values('sh_mean', ascending=False).reset_index(drop=True)
            for (idx, row) in stats_df.iterrows():  # Rows are sorting according to the ranking
                # Now we are trying to create a fact for every feature in the class
                feature_name = row["feature_name"]
                fact = {
                    "impact": row.sh_mean,
                    "rank": idx + 1,
                    "feature": feature_name,
                }
                if self._per_feature[feature_name]["type"] == "NUMERIC":
                    # numeric feature
                    if np.isnan(row['corr']) or np.abs(row['corr']) < corr_thr:
                        continue
                    fact["correlation"] = row["corr"]
                    fact["type"] = "NUMERICAL"
                else:
                    # categorical feature
                    if np.isnan(row.info_gain) or row.info_gain < info_gain_thr:
                        continue
                    fact["positiveCategoricalValue"] = row.mom_pos
                    fact["negativeCategoricalValue"] = row.mom_neg
                    fact["infoGain"] = row.info_gain
                    fact["type"] = "CATEGORICAL"
                for k in fact.keys():
                    fact[k] = dku_nonan(fact[k])
                per_class_facts[klass][feature_name] = fact
        return per_class_facts

    def format_explanations(self, explanations_df, nb_explanations, with_json=False):
        """
        Format explanations, keeping only the top `nb_explanations` per row
        (the ones with the largest absolute value)

        Example:

          for the following explanation dataframe:

               Embarked     Sex      Age     Fare
            0  -0.14513 -0.5972  0.15309 -0.36280
            1  -0.14513 -0.5972  0.15309 -0.36280
            2   0.44099 -0.5972 -0.08990  0.29136
            3  -0.14513  1.0996  0.09234 -0.35590
            4  -0.14513 -0.5972 -0.04434  0.10560

          with 2 explanations, will yield:

           if with_json is True, the pd.Series:

                0         {"Fare": -0.3628, "Sex": -0.5972}
                1         {"Fare": -0.3628, "Sex": -0.5972}
                2     {"Embarked": 0.44099, "Sex": -0.5972}
                3          {"Fare": -0.3559, "Sex": 1.0996}
                4    {"Embarked": -0.14513, "Sex": -0.5972}

           if with_json is False, the pd.DataFrame:

                   Embarked     Sex  Age    Fare
                0       NaN -0.5972  NaN -0.3628
                1       NaN -0.5972  NaN -0.3628
                2   0.44099 -0.5972  NaN     NaN
                3       NaN  1.0996  NaN -0.3559
                4  -0.14513 -0.5972  NaN     NaN

        :param explanations_df: dataframe of explanations (obtained via self.explain(...)), left untouched
        :type explanations_df: pd.DataFrame
        :param nb_explanations: number of explanations to output per row. Capped to the number of columns;
                                0 (or less) keeps no explanation at all
        :type nb_explanations: int
        :param with_json: whether to output explanations as json or as a dataframe
        :type with_json: bool
        :return: pd.Series | pd.DataFrame
        """

        logger.info("Formatting most important explanations")
        values = explanations_df.values
        nb_rows, nb_columns = values.shape
        nb_explanations = max(0, min(nb_explanations, nb_columns))
        # Split point expressed from the left: a negative offset would turn into "-0" (i.e. "everything")
        # when no explanation is kept
        nb_dropped = nb_columns - nb_explanations

        if nb_explanations > 0:
            # Per row, the `nb_explanations` largest absolute values end up at positions `nb_dropped` and after
            partitioned_indices = np.argpartition(np.abs(values), nb_dropped, axis=1)
        else:
            # Nothing to keep (this also covers a frame without columns): every column is dropped
            partitioned_indices = np.tile(np.arange(nb_columns), (nb_rows, 1))
        top_explanations_indices = partitioned_indices[:, nb_dropped:]
        bottom_explanations_indices = partitioned_indices[:, :nb_dropped]

        if with_json:
            top_explanations_values = np.take_along_axis(values, top_explanations_indices, axis=1)
            # Object dtype so that column names & values can sit in the same array
            column_names = np.asarray(explanations_df.columns, dtype="object")
            top_explanations_cols = column_names[top_explanations_indices]
            # One (nb_explanations, 2) array of (column, value) pairs per row
            pairs_per_row = np.dstack((top_explanations_cols, top_explanations_values))
            return pd.Series([dkujson.dumps(dict(pairs), ensure_ascii=False) for pairs in pairs_per_row])

        # Blank the other explanations through a boolean mask rather than by writing into `.values`:
        # that array is not guaranteed to be a writable view on the dataframe
        dropped_mask = np.zeros(values.shape, dtype=bool)
        np.put_along_axis(dropped_mask, bottom_explanations_indices, True, axis=1)
        return explanations_df.mask(dropped_mask)

    def _get_most_important_columns(self, n_explanations):
        """
        :param n_explanations: number of columns to explain
        :return: The columns to explain
        """
        max_n_col = n_explanations * 5
        if max_n_col > len(self._input_columns) or self.column_importance is None:
            return self._input_columns
        else:
            # Use mergesort to get deterministic results in case of ties
            self.column_importance = self.column_importance.sort_values(by="importances", ascending=False, kind='mergesort')
            most_important_columns = self.column_importance["columns"].values[:max_n_col]
            logger.info("To reduce computation time, computing explanations "
                        "for the {} most important columns: {}".format(max_n_col, most_important_columns))
            return most_important_columns

    def _get_not_normalized_train_set(self):
        """
        :rtype: pd.DataFrame
        """
        if self.not_normalized_trainset is None:
            self.not_normalized_trainset = df_from_split_desc_no_normalization(
                self._train_split_desc,
                "full" if self._is_kfolding else "train",
                self._split_folder_context,
                self._per_feature,
                self._prediction_type
            )
        return self.not_normalized_trainset

    def _get_train_set(self):
        """
        :rtype: pd.DataFrame
        """
        if self.trainset is None:
            self.trainset = normalize_dataframe(self._get_not_normalized_train_set().copy(), self._per_feature)
        return self.trainset

    def _needs_to_read_train_set(self):
        """
        Whether we will need to use data from the original SM train dataset, i.e. whether any of the column
        importance, background rows or features distribution has not been saved yet.
        Based on the code paths in self._load_or_compute_column_importance, self._load_or_draw_background_rows and self._load_or_compute_quantiles
        :rtype: bool
        """
        everything_is_saved = (self.column_importance_handler.has_saved_column_importance()
                               and self.background_rows_handler.has_saved_background_rows()
                               and self.distribution_computer.has_saved_features_distribution())
        if not everything_is_saved:
            logger.info("Loading train dataset to allow for computation of explanations.")
            return True

        logger.info("Skipping load of train dataset for computation of explanations, required data already present.")
        return False

    def make_ready(self, trainset_override=None, save=False):
        """
        Load / compute everything necessary for explanations: background rows, column importance, feature histograms
        A failure while computing the feature histograms is logged and not propagated.

        :param trainset_override: (optional) Dataset to use to compute those. It will init the explainer with the given dataset and not read possibly available data on disk.
        It will also prevent the explainer from persisting the data for this dataset.
        :type trainset_override: pd.DataFrame
        :param save: Persist the computed data on disk
        :type save: bool
        :raises Exception: if `trainset_override` holds fewer rows than the minimal background size
        """
        logger.info("Preparing explainer components: column importance, background rows and feature histograms")
        has_override = trainset_override is not None

        if has_override:
            if trainset_override.shape[0] < BackgroundRowsHandler.MIN_BACKGROUND_SIZE:
                raise Exception("Not enough rows in input dataset to calculate explanations")
            # Work on a private copy, and score yet another copy of it
            trainset_override = trainset_override.copy()
            prediction_information = self._get_prediction_information(trainset_override.copy())
        elif self._needs_to_read_train_set():
            prediction_information = self._get_or_compute_trainset_prediction_information()
        else:
            # Everything is already saved: nothing needs to be scored
            prediction_information = None

        # With an override, the column importance is always recomputed
        self._load_or_compute_column_importance(has_override, prediction_information, save=save)
        self._load_or_draw_background_rows(prediction_information, trainset_override=trainset_override,
                                           save=save)
        try:
            self._load_or_compute_quantiles(prediction_information, trainset_override, save=save)
        except Exception:
            logger.exception("Could not compute distribution histograms")

    def is_ready(self):
        """
        :return: Whether the explainer has its backing data loaded to perform explanations
        """
        return self.background_rows is not None and self.features_distribution is not None

    def _load_or_compute_column_importance(self, force_compute_importance, df_prediction_information, save=False):
        """ Retrieve the model's column importance or build one.
        A failed computation is logged and remembered, so that it is not attempted again on this explainer.

        :param df_prediction_information: prediction information of the data
        :type df_prediction_information: PredictionInformation
        :param force_compute_importance: Whether to force the compute of column importance, no matter if saved results
                                         are available
        :type force_compute_importance: bool
        :param save: whether to persist the column importance on disk
        :type save: bool
        """
        handler = self.column_importance_handler

        if not force_compute_importance and handler.has_saved_column_importance():
            logger.info("Fetching column importance from model")
            self.column_importance = handler.get_column_importance()
            # For model trained before 8.0.3, column importance can contain non-input columns (ch54832)
            is_input_column = self.column_importance["columns"].isin(self._input_columns)
            self.column_importance = self.column_importance[is_input_column]
            return

        if self.column_importance_compute_has_failed:
            return

        score_to_explain = df_prediction_information.score_to_explain
        if isinstance(score_to_explain, OneDimensionScoreToExplain):
            scores_a = score_to_explain.score
        else:
            scores_a = score_to_explain.per_class_score

        try:
            # Raised inside the `try` on purpose: ensembles go through the same failure handling
            if self._is_ensemble:
                raise ValueError("Column importance incompatible with ensembling models")
            self.column_importance = handler.compute_column_importance(
                list(self._input_columns),
                df_prediction_information.features,
                df_prediction_information.transformed_a,
                scores_a,
                self._predictor.preprocessing.pipeline.generated_features_mapping,
                save=save)
        except Exception as e:
            self.column_importance_compute_has_failed = True
            logger.exception("Could not optimize the number of columns to explain: {}".format(e))

    def _load_or_draw_background_rows(self, df_prediction_information, trainset_override=None, save=False):
        """ Populate `self.background_rows`, from the saved rows when possible, otherwise by drawing them.

        Saved rows are only used when no `trainset_override` is given. When rows must be drawn and fewer than
        `BackgroundRowsHandler.MIN_BACKGROUND_SIZE` rows are available, the method logs it and returns without
        touching `self.background_rows`. In every other case the result is restricted to the input columns.

        :param df_prediction_information: prediction information of the data
        :type df_prediction_information: PredictionInformation
        :param trainset_override: data from which the rows must be drawn. Pass only if not the train time trainset
        :type trainset_override: pd.DataFrame
        :param save: whether to persist the background rows on disk
        :type save: bool
        """
        draw_from_train_time_trainset = trainset_override is None
        if draw_from_train_time_trainset and self.background_rows_handler.has_saved_background_rows():
            self.background_rows = self.background_rows_handler.retrieve_background_rows()
        else:
            if draw_from_train_time_trainset:
                # retrieving the not-normalized data because might be serialized
                trainset_df = self._get_not_normalized_train_set()
            else:
                trainset_df = trainset_override

            # Only keeping not dropped rows, to be aligned with the scored data
            trainset_df = trainset_df.loc[df_prediction_information.index]

            if trainset_df.shape[0] < BackgroundRowsHandler.MIN_BACKGROUND_SIZE:
                logger.info(
                    "Dataset too small to draw background rows. Shapley feature importance will be unavailable.")
                return

            score_to_explain = df_prediction_information.score_to_explain
            if isinstance(score_to_explain, MulticlassScoreToExplain):
                score_for_background = score_to_explain.per_class_score
            else:
                score_for_background = score_to_explain.score
            self.background_rows = self.background_rows_handler.draw_background_rows(
                trainset_df, score_for_background, save=save)
            if draw_from_train_time_trainset:
                # Rows drawn from the not-normalized train set still have to be normalized
                self.background_rows = normalize_dataframe(self.background_rows, self._per_feature)
        self.background_rows = self.background_rows[self._input_columns]

    def _load_or_compute_quantiles(self, df_prediction_information, trainset_override=None, save=False):
        """ Load the features distribution, or compute it if needed, and hand it to `_set_features_distribution`.

        The saved distribution is only used when no `trainset_override` is given. Otherwise the distribution is
        computed on the rows of the train set (or of the override) selected by the prediction information index,
        weighted by the sample weight column when one is configured and present in the data.

        :param df_prediction_information: prediction information of the data
        :type df_prediction_information: PredictionInformation
        :param trainset_override: data from which the histograms must be computed. Pass it only if not the train
                                  time train set
        :type trainset_override: pd.DataFrame
        :param save: whether to persist the features distribution on disk
        :type save: bool
        """
        use_train_time_trainset = trainset_override is None
        if use_train_time_trainset and self.distribution_computer.has_saved_features_distribution():
            self._set_features_distribution(self.distribution_computer.load())
            return

        trainset_df = self._get_train_set() if use_train_time_trainset else trainset_override

        # Make sure to only keep not dropped rows
        trainset_df = trainset_df.loc[df_prediction_information.index]

        # Sample weights are optional: without a usable weight column, None is passed on
        sample_weights = None
        weight_col = self.sample_weight_col
        if weight_col is not None:
            if weight_col in trainset_df:
                sample_weights = trainset_df[weight_col]
            else:
                logger.warning(u"Sample weight column ('{}') missing, ignoring sample weights for "
                               u"distribution computation".format(safe_unicode_str(self.sample_weight_col)))

        computed_distribution = self.distribution_computer.compute_all(
            trainset_df, self._per_feature, save=save, sample_weight=sample_weights)
        self._set_features_distribution(computed_distribution)

    def _set_features_distribution(self, features_distribution):
        """ Rebuild `self.features_distribution` as a `{feature: {"scale": ..., "distribution": ...}}` dict.

        * Numeric features with more than MIN_NB_DISTINCT_FOR_QUANTILES distinct values use their quantiles.
        * Other numeric features use their top values, the counts being divided by their sum.
        * Non-numeric features only keep their most frequent modalities: MIN_NB_MODALITIES_TO_KEEP of them when
          those reach HISTOGRAM_SIGNIFICANCE_THRESHOLD together, MAX_NB_MODALITIES_TO_KEEP otherwise. The kept
          frequencies are divided by their sum.

        :param features_distribution: per feature distribution computers
        :type features_distribution: dict
        """
        self.features_distribution = {}
        for feature, feature_distribution in features_distribution.items():
            if not isinstance(feature_distribution, NumericFeatureDistributionComputer):
                scale, distribution = feature_distribution.get_values_with_nans()
                entry = {
                    "scale": scale,
                    "distribution": distribution
                }
                self.features_distribution[feature] = entry
                # Modalities sorted by decreasing frequency
                by_decreasing_frequency = np.argsort(-distribution)
                top_modalities_weight = distribution[by_decreasing_frequency[:MIN_NB_MODALITIES_TO_KEEP]].sum()
                if top_modalities_weight >= HISTOGRAM_SIGNIFICANCE_THRESHOLD:
                    nb_modalities_to_keep = MIN_NB_MODALITIES_TO_KEEP
                else:
                    nb_modalities_to_keep = MAX_NB_MODALITIES_TO_KEEP
                kept_indices = by_decreasing_frequency[:nb_modalities_to_keep]
                # Divide by the kept total so that the kept frequencies are relative to each other
                entry["distribution"] = distribution[kept_indices] / distribution[kept_indices].sum()
                entry["scale"] = entry["scale"][kept_indices]
                continue

            if feature_distribution.nb_distinct > MIN_NB_DISTINCT_FOR_QUANTILES:
                scale, distribution = feature_distribution.get_quantiles_with_nans()
            else:
                scale, counts = feature_distribution.get_top_values_with_nans()
                distribution = counts / np.sum(counts)
            self.features_distribution[feature] = {
                "scale": scale,
                "distribution": distribution
            }

    def _get_or_compute_trainset_prediction_information(self):
        """ Return the prediction information of the train set, computing and caching it on first use.

        :return: the content of `self.trainset_prediction_information`
        """
        cached = self.trainset_prediction_information
        if cached is not None:
            return cached
        # The train set is copied before being handed to the prediction code
        trainset_copy = self._get_train_set().copy()
        self.trainset_prediction_information = self._get_prediction_information(trainset_copy)
        return self.trainset_prediction_information

    def _get_score_from_preprocessed(self, preprocessed_a):
        """ Predict from already preprocessed data and turn the prediction into a score to explain.

        Warning: this method makes a *lot* of assumptions on the underlying data & model, be very careful when
        using it!

        :type preprocessed_a: np.ndarray
        :rtype: ScoreToExplain
        """
        nb_rows = preprocessed_a.shape[0]
        # No need for real index here, will not be used
        placeholder_index = np.arange(nb_rows)
        predicted_df = self._predictor._predict_preprocessed(
            preprocessed_a, None, placeholder_index, True, True)
        score_to_explain = self._get_score_from_pred_df(predicted_df)
        return score_to_explain

    def _get_prediction_information(self, observations_df):
        """ Predict the given rows and bundle the outcome in a PredictionInformation.

        :param observations_df: the rows to predict
        :type observations_df: pd.DataFrame
        :return: the index of `observations_df`, the score to explain, the preprocessed data and the predictor
                 features, wrapped together
        :rtype: PredictionInformation
        """
        prediction_outcome = self._predict_and_get_transformed_df(observations_df)
        transformed_a, predicted_df = prediction_outcome
        score_to_explain = self._get_score_from_pred_df(predicted_df)
        # NOTE(review): the index is only read once the prediction is done - presumably preprocessing may alter
        # the frame in place; confirm before reordering these statements
        return self.PredictionInformation(
            observations_df.index,
            score_to_explain,
            transformed_a,
            self._predictor.features)

    def _get_score_from_pred_df(self, pred_df):
        """ Build the score to explain out of a prediction dataframe.

        * regression: the content of the "prediction" column
        * binary classification: the log odds of the second probability column
        * multiclass: the log odds of every probability column

        Probabilities are clipped to [0.01, 0.99] when computing the log odds.

        :type pred_df: pd.DataFrame
        :rtype: ScoreToExplain
        :raises ValueError: if the prediction type is none of the above
        """
        prediction_type = self._prediction_type
        if prediction_type == doctor_constants.REGRESSION:
            return OneDimensionScoreToExplain(pred_df["prediction"].values)

        if prediction_type not in {doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS}:
            raise ValueError("Unknown prediction type: {}".format(self._prediction_type))

        clip_min, clip_max = 0.01, 0.99
        probas_a = pred_df[self._predictor.get_proba_columns()].values
        if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
            # Only the second probability column is explained
            binary_score = log_odds(probas_a[:, 1], clip_min=clip_min, clip_max=clip_max)
            return OneDimensionScoreToExplain(binary_score)

        per_class_score = log_odds(probas_a, clip_min=clip_min, clip_max=clip_max)
        return MulticlassScoreToExplain.build_from_best_per_class_score(per_class_score)

    def _predict_and_get_transformed_df(self, observations):
        """ Predict the observations and return the preprocessed data along with the prediction dataframe.

        For ensemble models, the predictor is asked for the prediction dataframe directly and no preprocessed
        data is returned. Otherwise the observations are preprocessed first, then predicted from the
        preprocessed data.

        :param observations: the rows to predict
        :return: a `(transformed_a, predicted_df)` tuple, `transformed_a` being None for ensemble models
        :rtype: tuple
        :raises DroppedBatchException: if preprocessing reports that no row is left
        """
        if self._is_ensemble:
            return None, self._predictor.get_prediction_dataframe(observations, True, True, False, False)

        preprocessed = self._predictor.preprocessing.preprocess(observations, with_unprocessed=True)
        transformed_a, input_index, empty, unprocessed_df = preprocessed
        if empty:
            raise DroppedBatchException("Whole batch has been dropped by preprocessing")
        predicted_df = self._predictor._predict_preprocessed(
            transformed_a, unprocessed_df, input_index, True, True)
        return transformed_a, predicted_df

    def _get_features_to_column_indices_mapping_or_none(self, all_columns):
        """ Map every preprocessed feature to the position, in `all_columns`, of the column it originates from.

        None is returned when such a mapping cannot be built: for ensemble models, when a feature does not
        originate from exactly one column, or when looking up the origin column of a feature fails.

        :type all_columns: list[str]
        :return: one column index per feature, in the order of the predictor features, or None
        :rtype: np.ndarray or None
        """
        if self._is_ensemble:  # Not supported
            logger.info("Ensemble models do not support features to column indices")
            return None

        generated_features_mapping = self._predictor.preprocessing.pipeline.generated_features_mapping
        features = self._predictor.features
        column_positions = {column: position for (position, column) in enumerate(all_columns)}
        column_index_per_feature = np.zeros(len(features), dtype=int)
        for (feature_index, feature) in enumerate(features):
            try:
                origin_columns = generated_features_mapping.get_origin_columns_from_feature(feature, all_columns)
                if len(origin_columns) != 1:  # No 1-N mapping, not supported
                    logger.info(u"Feature {} comes from multiple columns, no "
                                u"support for feature to column indices".format(safe_unicode_str(feature)))
                    return None
                column_index_per_feature[feature_index] = column_positions[origin_columns[0]]
            except Exception:  # Not supported
                logger.info(u"Feature {} does not support feature to column indices".format(safe_unicode_str(feature)))
                return None
        return column_index_per_feature

    def _get_preprocessed_features(self):
        """ Return the features of the predictor, or None for ensemble models. """
        return None if self._is_ensemble else self._predictor.features

    class PredictionInformation:
        """ Plain holder for everything known about the prediction of a set of rows:
            * index: the pandas index of the predicted rows
            * score_to_explain:
                * for regression, the prediction itself
                * for classification, the log odds of the probabilities
            * transformed_a: the preprocessed dataset (None if ensembling)
            * features: the preprocessed features (None if ensembling)
        """
        def __init__(self, index, score_to_explain, transformed_a, features):
            """ Store the given values as attributes of the same name.

            :param index: index of the data (as in pandas index)
            :type score_to_explain: ScoreToExplain
            :type transformed_a: np.ndarray or scipy.sparse.csr_matrix or None
            :type features: list[str] or None
            """
            self.index = index
            self.score_to_explain = score_to_explain
            self.transformed_a = transformed_a
            self.features = features


class DroppedBatchException(RuntimeError):
    """ Raised when preprocessing has dropped every row of a batch. """
