"""FeaturesDistributionHandler and features_distribution compute functions

Abstract: These classes handle the task of computing, saving and loading the
    distributions. At the end of the training, the FeaturesDistributionHandler
    will save all computations in a json file so that they can be retrieved
    faster afterward.

Note: For numerical features, the distributions can be computed using a
    histogram (i.e. constant range size for all bins) or using quantiles (i.e.
    bins delimited by the deciles, so each one holds a similar share of the data).
"""

import logging
import numpy as np
from abc import ABCMeta
from abc import abstractmethod

import pandas as pd
from scipy.stats.mstats_basic import mquantiles
from six import add_metaclass

from dataiku.core import doctor_constants
from dataiku.doctor.prediction.common import weighted_quantiles
from dataiku.doctor.utils.split import input_columns

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default cap on the number of distinct values kept for a non-numeric feature
# (default value of CategoricalFeatureDistributionComputer's `max_nb_modalities`).
MAX_NB_MODALITIES_FOR_CATEGORICAL = 1000
# Passed as `num` to np.linspace when building the histogram scale, i.e. this is
# the number of bin edges (one more than the number of bins).
NB_BINS_FOR_NUMERIC_HISTOGRAMS = 50
NB_TOP_VALUES = 50  # for numeric columns, save the distribution of the top 50 values


class FeaturesDistributionHandler(object):
    """Compute, save and load the distributions of the input features.

    The distributions are persisted in the model folder, in the
    FEATURES_DISTRIBUTION_FILENAME JSON file, under the "featuresDistribution" key.

    Format of the features_distribution dict:
    {
        [numerical_feature]: {
            "type": doctor_constants.NUMERIC,
            "nbDistinct": int,
            "count": float or int,
            "missingCount": float or int,
            "histogram": {
                "scale": np.ndarray,
                "distribution": np.ndarray
            },
            "quantiles": {
                "scale": np.ndarray,
                "distribution": np.ndarray
            },
            "topValues": [{
                "value": float or int,
                "count": float or int
            }]
        },
        [non-numeric_feature]: {
            "type": doctor_constants.CATEGORY | doctor_constants.TEXT | ...,
            "nbDistinct": int,
            "count": float or int,
            "missingCount": float or int,
            "values": {
                "scale": np.ndarray,
                "distribution": np.ndarray
            }
        },
        ...
    }
    """

    FEATURES_DISTRIBUTION_FILENAME = "features_distribution.json"

    def __init__(self, model_folder_context=None):
        """
        :param model_folder_context: folder in which the distributions are saved / loaded.
            When None, nothing can be saved nor loaded.
        :type model_folder_context: dataiku.base.folder_context.FolderContext
        """
        self.model_folder_context = model_folder_context

    def compute_all(self, df, per_feature, save, sample_weight=None):
        """Compute the distributions of all the input columns
        :param pd.DataFrame df: dataset
        :param dict per_feature: info relative to the features of the dataset
        :param save: whether to save the features distribution
        :type save: bool
        :param np.ndarray or None sample_weight: sample weight values (same number of rows as dataset)
        :return: features distribution
        :rtype: {[feature_name]: NumericFeatureDistributionComputer or CategoricalFeatureDistributionComputer}
        """
        logger.info("Computing features distribution")
        # Remove NaNs from sample_weight. The None case is handled explicitly so that
        # "no weights" is never turned into a 0-d object array by numpy, which would
        # defeat the `sample_weight is None` checks done by the computers.
        if sample_weight is not None:
            sample_weight = np.nan_to_num(sample_weight)

        features_distribution = {}
        for name in input_columns(per_feature):
            feature_type = per_feature[name]["type"]
            feature_distribution_computer = FeatureDistributionComputer.build_from_feature_type(feature_type)
            feature_distribution_computer.compute(df[name], sample_weight)
            features_distribution[name] = feature_distribution_computer

        if save:
            self._save(features_distribution)
        return features_distribution

    def has_saved_features_distribution(self):
        """Tell if the distributions have already been computed and saved
        :return: True if the JSON file is accessible, False if not
        :rtype: bool
        """
        if self.model_folder_context is None:
            return False
        return self.model_folder_context.isfile(self.FEATURES_DISTRIBUTION_FILENAME)

    def load(self):
        """Retrieve the distributions from the JSON file
        :return: features distribution
        :rtype: {[feature_name]: NumericFeatureDistributionComputer or CategoricalFeatureDistributionComputer}
        :raises RuntimeError: if there is no folder context or no saved file
        """
        if not self.has_saved_features_distribution():
            raise RuntimeError("Could not load the features distribution")
        logger.info("Retrieving features distribution from file")
        saved_content = self.model_folder_context.read_json(self.FEATURES_DISTRIBUTION_FILENAME)
        return {
            elem_name: FeatureDistributionComputer.build_from_dict(elem)
            for elem_name, elem in saved_content["featuresDistribution"].items()
        }

    def _save(self, features_distribution):
        """Save the distributions in the JSON file (no-op, with a warning, when there is no folder context)
        :param features_distribution: features distribution
        :type features_distribution: {[feature_name]: NumericFeatureDistributionComputer or CategoricalFeatureDistributionComputer}
        """
        if self.model_folder_context is None:
            logger.warning("Could not find folder to save features distribution")
            return
        logger.info("Saving features distribution")
        features_distribution_dict = {
            elem_name: elem.to_dict() for elem_name, elem in features_distribution.items()
        }
        self.model_folder_context.write_json(self.FEATURES_DISTRIBUTION_FILENAME,
                                             {"featuresDistribution": features_distribution_dict})


@add_metaclass(ABCMeta)
class FeatureDistributionComputer(object):
    """Base class of the per-feature distribution computers.

    It holds the statistics shared by all feature types (number of distinct values,
    count and missing count, both possibly weighted) and the helpers used by the
    subclasses to count values and to append the NaN modality to a distribution.
    """

    @abstractmethod
    def __init__(self, feature_type=None):
        """
        :param feature_type: type of the feature, serialized under the "type" key
        """
        self.feature_type = feature_type
        self.nb_distinct = None
        self.count = None
        self.missing_count = None
        # Warning: the three following attributes are not serialized. Do not use them in getters.
        self.series_without_nans = None
        self.sample_weight_aligned_on_series = None  # doesn't contain NaNs and same length than self.series_without_nans
        self.uniques = None

    @abstractmethod
    def compute(self, series, sample_weight=None):
        """Compute the statistics shared by all feature types
        :param pd.Series series: may contain NaNs
        :param np.ndarray sample_weight: must not contain any NaNs
        """
        not_nan_mask = series.notna()  # computed once, used for both the values and the weights
        self.count = series.shape[0] if sample_weight is None else sample_weight.sum()
        self.missing_count = self._get_weighted_nan_count(series, sample_weight)
        self.series_without_nans = series[not_nan_mask]
        self.sample_weight_aligned_on_series = None if sample_weight is None else sample_weight[not_nan_mask]
        self.uniques = pd.unique(self.series_without_nans)
        # NaN counts as one extra distinct value when at least one (weighted) value is missing
        self.nb_distinct = len(self.uniques) + (1 if self.missing_count > 0 else 0)

    @abstractmethod
    def to_dict(self):
        """
        :return: the serializable statistics shared by all feature types
        :rtype: dict
        """
        return {
            "type": self.feature_type,
            "nbDistinct": self.nb_distinct,
            "count": self.count,
            "missingCount": self.missing_count
        }

    @abstractmethod
    def from_dict(self, d):
        """Restore the shared statistics from a dict produced by to_dict.
        The "type" entry is not read here: the feature type is set by the constructor.
        :param dict d: serialized feature distribution
        """
        self.nb_distinct = d["nbDistinct"]
        self.count = d["count"]
        self.missing_count = d["missingCount"]

    @staticmethod
    def build_from_feature_type(feature_type):
        """
        :param feature_type: type of the feature
        :return: a numeric computer for doctor_constants.NUMERIC, a categorical one for any other type
        :rtype: NumericFeatureDistributionComputer or CategoricalFeatureDistributionComputer
        """
        if feature_type == doctor_constants.NUMERIC:
            return NumericFeatureDistributionComputer()
        return CategoricalFeatureDistributionComputer(feature_type)

    @staticmethod
    def build_from_dict(d):
        """
        :param dict d: serialized feature distribution, as produced by to_dict
        :return: the computer matching d["type"], filled with the content of d
        :rtype: NumericFeatureDistributionComputer or CategoricalFeatureDistributionComputer
        """
        feature_distribution_computer = FeatureDistributionComputer.build_from_feature_type(d["type"])
        feature_distribution_computer.from_dict(d)
        return feature_distribution_computer

    @staticmethod
    def add_nan_to_np_array(array):
        """
        :param np.ndarray array: values
        :return: a new array with the same values followed by NaN
        :rtype: np.ndarray
        """
        # Need to go through Series because numpy array dtype might not accept NaN (e.g. if it's <U2)
        return pd.concat([pd.Series(array), pd.Series(np.nan)]).values

    def add_nan_to_distribution(self, scale, distribution):
        """Append the NaN modality to a distribution computed on non-missing values
        :param np.ndarray scale: scale of the distribution (without NaN)
        :param np.ndarray distribution: proportions of the non-missing data for each scale element
        :return: (scale followed by NaN, distribution renormalized with the missing values as last element)
        :rtype: (np.ndarray, np.ndarray)
        """
        new_scale = self.add_nan_to_np_array(scale)

        # Convert the proportions back to (weighted) counts before appending the missing count
        new_distribution = distribution * (self.count - self.missing_count)
        # Force a float array: the in-place division below is not allowed on integer arrays,
        # which can happen if all the values at hand are integers.
        new_distribution = np.append(new_distribution, self.missing_count).astype(float)
        new_distribution /= new_distribution.sum()

        return new_scale, new_distribution

    @staticmethod
    def _get_weighted_nan_count(column, sample_weight):
        """Compute weighted number of nans for a given feature
        :param pd.Series column: column of the dataframe
        :param np.ndarray or None sample_weight: sample weight values, cannot contain NaNs
        :return: weighted number of NaNs (each row weighs 1 when sample_weight is None)
        :rtype: float
        """
        if sample_weight is None:
            sample_weight = np.ones(column.shape)
        return np.sum(sample_weight[pd.isna(column.values)])

    @staticmethod
    def _get_weighted_counts(series_without_nans, sample_weight_aligned_on_series, max_nb_modalities):
        """Compute weighted counts for a given feature
        :param pd.Series series_without_nans: series representing a column without nan values
        :param np.ndarray or None sample_weight_aligned_on_series: sample weight values, cannot contain NaNs
        :param int or None max_nb_modalities: limits the size of the scale
        :return: (scale, weighted number of samples for each value), by decreasing weighted count
        :rtype: (np.ndarray, np.ndarray)
        """
        if sample_weight_aligned_on_series is None:
            sample_weight_aligned_on_series = np.ones(series_without_nans.shape)

        values_and_weights = pd.DataFrame({"data": series_without_nans,
                                           "weight": sample_weight_aligned_on_series})
        weighted_count_series = values_and_weights.groupby("data")["weight"].sum()

        # keep the most frequent and sort
        if max_nb_modalities is not None:
            weighted_count_series = weighted_count_series.nlargest(max_nb_modalities)
        else:
            weighted_count_series = weighted_count_series.sort_values(ascending=False)

        scale = weighted_count_series.index.values
        weighted_counts = weighted_count_series.values.astype(float)
        return scale, weighted_counts


class NumericFeatureDistributionComputer(FeatureDistributionComputer):
    """Distribution of a numeric feature.

    On top of the shared statistics, it stores three views of the non-missing values:
    a histogram with evenly spaced bins, a quantile-based distribution and the
    (weighted) counts of the most frequent values.
    """

    def __init__(self):
        super(NumericFeatureDistributionComputer, self).__init__(feature_type=doctor_constants.NUMERIC)
        self.histogram_scale = None
        self.histogram_distribution = None

        self.quantiles_scale = None
        self.quantiles_distribution = None

        self.top_values_scale = None
        self.top_values_counts = None

    def compute(self, series, sample_weight=None):
        """Compute the histogram, the quantiles and the top values of the feature
        :param pd.Series series: may contain NaNs
        :param np.ndarray sample_weight: must not contain any NaNs
        """
        super(NumericFeatureDistributionComputer, self).compute(series, sample_weight)
        values = self.series_without_nans
        weights = self.sample_weight_aligned_on_series
        self.histogram_scale, self.histogram_distribution = self.compute_histogram(values, weights)
        self.quantiles_scale, self.quantiles_distribution = self.compute_quantiles(values, weights)
        self.top_values_scale, self.top_values_counts = self._get_weighted_counts(values, weights, NB_TOP_VALUES)

    def get_quantiles_with_nans(self):
        """Retrieve quantile-based scale and the corresponding distribution.
            Scale is:
                - (quantile 0% + quantile 10%) / 2
                - (quantile 10% + quantile 20%) / 2
                - ...
                - (quantile 90% + quantile 100%) / 2
                - NaN
            Distribution is
                - (# of elements between quantile 0% and quantile 10%) / count
                - (# of elements between quantile 10% and quantile 20%) / count
                - ...
                - (# of elements between quantile 90% and quantile 100%) / count
                - missingCount / count
            The NaN entries are only present when there are missing values.
        :return: (scale including NaN, proportion of data in the corresponding group)
        :rtype: (np.ndarray, np.ndarray)
        """
        if self.missing_count == 0:
            return self.quantiles_scale, self.quantiles_distribution
        else:
            return self.add_nan_to_distribution(self.quantiles_scale, self.quantiles_distribution)

    def get_top_values_with_nans(self):
        """Retrieve most common values and their number of occurrences, NaN being
        treated as a regular value. At most NB_TOP_VALUES values are returned.
        :return: (top values, corresponding counts)
        :rtype: (np.ndarray, np.ndarray)
        """
        if self.missing_count == 0:
            return self.top_values_scale, self.top_values_counts

        # NaN can only be left out when the list is already full of more frequent values:
        # if there are fewer than NB_TOP_VALUES values, there is room for NaN whatever its count.
        top_values_is_full = len(self.top_values_counts) >= NB_TOP_VALUES
        if top_values_is_full and self.missing_count < self.top_values_counts[-1]:
            # missing value is not in the top values
            return self.top_values_scale, self.top_values_counts

        new_top_values = self.add_nan_to_np_array(self.top_values_scale)
        new_top_values_count = np.append(self.top_values_counts, self.missing_count)
        # sort again by decreasing count and drop the values that no longer fit
        new_order = np.argsort(new_top_values_count)[::-1][:NB_TOP_VALUES]
        return new_top_values[new_order], new_top_values_count[new_order]

    def get_histograms(self):
        """Retrieve the scale and distributions as a histogram with evenly spaced bins
        :return: (scale, proportion of data in the corresponding bin)
        :rtype: (np.ndarray, np.ndarray)
        """
        return self.histogram_scale, self.histogram_distribution

    def to_dict(self):
        """
        :return: the shared statistics plus the "histogram", "quantiles" and "topValues" entries
        :rtype: dict
        """
        base_dict = super(NumericFeatureDistributionComputer, self).to_dict()
        base_dict["histogram"] = {
            "scale": self.histogram_scale,
            "distribution": self.histogram_distribution
        }
        base_dict["quantiles"] = {
            "scale": self.quantiles_scale,
            "distribution": self.quantiles_distribution
        }
        base_dict["topValues"] = [
            {"value": value, "count": count}
            for value, count in zip(self.top_values_scale, self.top_values_counts)
        ]
        return base_dict

    def from_dict(self, d):
        """Restore the computer from a dict produced by to_dict
        :param dict d: serialized numeric feature distribution
        """
        super(NumericFeatureDistributionComputer, self).from_dict(d)

        self.histogram_scale = np.array(d["histogram"]["scale"])
        self.histogram_distribution = np.array(d["histogram"]["distribution"])

        self.quantiles_scale = np.array(d["quantiles"]["scale"])
        self.quantiles_distribution = np.array(d["quantiles"]["distribution"])

        self.top_values_scale = np.array([x["value"] for x in d["topValues"]])
        self.top_values_counts = np.array([x["count"] for x in d["topValues"]])

    @staticmethod
    def compute_quantiles(series_without_nans, sample_weight=None):
        """Compute quantiles and corresponding distribution for a given numerical feature
        :param pd.Series series_without_nans: series representing a column without nan values
        :param np.ndarray or None sample_weight: sample weight values, aligned on series_without_nans
        :return: (centers of the quantile bins, proportion of data in the corresponding bin)
        :rtype: (np.ndarray, np.ndarray)
        """
        values = series_without_nans.values
        bin_edges = NumericFeatureDistributionComputer._get_binned_not_nan_values_quantiles(values, sample_weight)
        # center the bins
        scale = np.asarray([low + (high - low) / 2 for low, high in zip(bin_edges[:-1], bin_edges[1:])])

        distribution, _ = np.histogram(values, bins=bin_edges, density=False, weights=sample_weight)

        distribution = distribution.astype(float) / np.sum(distribution)
        return scale, distribution

    @staticmethod
    def compute_histogram(column, sample_weight_aligned_on_series=None):
        """Compute scale and histogram distribution for a given numerical feature
        :param pd.Series column: column of the dataframe
        :param np.ndarray or None sample_weight_aligned_on_series: sample weight values, aligned on column
        :return: (bin edges, proportion of data in the corresponding bin). The scale holds
            NB_BINS_FOR_NUMERIC_HISTOGRAMS evenly spaced edges, so the distribution has one element less.
        :rtype: (np.ndarray, np.ndarray)
        """
        bin_edges = np.linspace(column.min(), column.max(), num=NB_BINS_FOR_NUMERIC_HISTOGRAMS)

        distribution, _ = np.histogram(column.values, bins=bin_edges, density=False, weights=sample_weight_aligned_on_series)
        distribution = distribution.astype(float) / np.sum(distribution)

        return bin_edges, distribution

    @staticmethod
    def _get_binned_not_nan_values_quantiles(values, sample_weight):
        """Compute the deciles (from 0% to 100%) of a numeric feature (which does not contain nan values)
        :param np.ndarray values: the non nan values to compute bins on
        :param np.ndarray or None sample_weight: sample weight values, aligned on values
        :return: list of bins
        :rtype: np.ndarray
        """
        quantiles_to_compute = np.arange(0.0, 1.1, 0.1)
        if sample_weight is None:
            # np.sort shouldn't be necessary but works around a microbug leading to non-monotonous quantiles.
            # quantiles could include [..., a, b, a, ...] with b < a at the 15 or 16th decimal place,
            # and bins must increase monotonically
            return np.sort(mquantiles(values, prob=quantiles_to_compute))
        else:
            # values and weights are passed sorted by increasing value
            sort_index = values.argsort()
            return weighted_quantiles(values[sort_index], sample_weight[sort_index], quantiles_to_compute)


class CategoricalFeatureDistributionComputer(FeatureDistributionComputer):
    """Distribution of a non-numeric (categorical/vector/text) feature.

    The distribution is the proportion of each of the most frequent modalities
    among the non-missing values.
    """

    def __init__(self, feature_type=None, max_nb_modalities=MAX_NB_MODALITIES_FOR_CATEGORICAL):
        """
        :param CATEGORY or TEXT etc. feature_type: the type of the feature
        :param int max_nb_modalities: maximum number of distinct values to store
        """
        super(CategoricalFeatureDistributionComputer, self).__init__(feature_type)
        self.scale = None
        self.distribution = None
        # Warning: the following attribute is not serialized. Do not use it in getters.
        self.max_nb_modalities = max_nb_modalities

    def compute(self, series, sample_weight=None):
        """Compute the shared statistics, then the modalities and their proportions
        :param pd.Series series: may contain NaNs
        :param np.ndarray sample_weight: must not contain any NaNs
        """
        super(CategoricalFeatureDistributionComputer, self).compute(series, sample_weight)
        modalities, proportions = self._compute_values_for_categorical_col()
        self.scale = modalities
        self.distribution = proportions

    def to_dict(self):
        """
        :return: the shared statistics plus the "values" entry
        :rtype: dict
        """
        serialized = super(CategoricalFeatureDistributionComputer, self).to_dict()
        serialized["values"] = dict(scale=self.scale, distribution=self.distribution)
        return serialized

    def from_dict(self, d):
        """Restore the computer from a dict produced by to_dict
        :param dict d: serialized categorical feature distribution
        """
        super(CategoricalFeatureDistributionComputer, self).from_dict(d)
        saved_values = d["values"]
        self.scale = np.array(saved_values["scale"])
        self.distribution = np.array(saved_values["distribution"])

    def get_values_with_nans(self):
        """Returns scale and distribution for a given categorical/vector/text feature.
        The NaN modality is only appended when there are missing values.
        :return: (scale including NaN, proportion of data in the corresponding group)
        :rtype: (np.ndarray, np.ndarray)
        """
        if self.missing_count != 0:
            return self.add_nan_to_distribution(self.scale, self.distribution)
        return self.scale, self.distribution

    def _compute_values_for_categorical_col(self):
        """Compute scale and distribution for a given categorical/vector/text feature.
        :return: (scale, proportion of data in the corresponding group)
        :rtype: (np.ndarray, np.ndarray)
        """
        modalities, totals = self._get_weighted_counts(
            self.series_without_nans,
            self.sample_weight_aligned_on_series,
            max_nb_modalities=self.max_nb_modalities,
        )
        # proportions are relative to the kept modalities only
        return modalities, totals / totals.sum()
