import numpy as np
from numpy.core.numeric import asarray
from sklearn.ensemble import RandomTreesEmbedding


class PlausibilityScorer(object):
    """
    Score how similar records are to a training set, in the space of Random Trees Embeddings.

    Records are one-hot encoded by a RandomTreesEmbedding forest. The similarity of a record to the training set is
    the mean, over the (capped) training rows, of the dot product between the record's embedding and each training
    row's embedding. A record's plausibility is then read off the quantiles of the training rows' own similarities
    (see `compute_plausibility`).
    """
    MAX_NB_ROWS_TO_FIT = 1000  # Limit the training set of the plausibility model because fitting takes a lot of time
    PROBABILITY_STEP = 0.05  # Returned plausibilities will be a multiple of this number
    N_ESTIMATORS = 50  # To fit the RandomTreesEmbedding

    def __init__(self):
        self.X_train_sparse_embedding = None  # embedding of the first MAX_NB_ROWS_TO_FIT training rows
        self.probability_levels = None  # 0, PROBABILITY_STEP, 2 * PROBABILITY_STEP, ... (all < 1)
        self.x_train_correlation_quantiles = None  # sorted quantiles of the training similarities, one per level
        self.random_trees = None  # fitted RandomTreesEmbedding

    def fit(self, X_train):
        """
        Fit model for plausibility computation.

        Only the first MAX_NB_ROWS_TO_FIT rows of X_train are used, both to fit the forest and as the reference set
        that records are later compared against.
        :param X_train: training dataset - must be preprocessed (ie. sklearn-compatible)
        :return: self, so that calls can be chained
        """
        X_fit = X_train[:self.MAX_NB_ROWS_TO_FIT]
        self.random_trees = RandomTreesEmbedding(n_estimators=self.N_ESTIMATORS, random_state=42)
        # fit_transform yields the same embedding as fit() followed by transform() on the same rows,
        # but runs the forest over those rows only once.
        self.X_train_sparse_embedding = self.random_trees.fit_transform(X_fit)
        x_train_correlation = self._mean_similarity_to_train(self.X_train_sparse_embedding)

        self.probability_levels = np.arange(0.0, 1.0, self.PROBABILITY_STEP)

        # x_train_correlation is already a flat 1-D ndarray, which np.quantile handles on every numpy version
        self.x_train_correlation_quantiles = np.quantile(x_train_correlation, self.probability_levels)
        # np.digitize requires monotonic bins; sorting guarantees it regardless of how np.quantile interpolates
        self.x_train_correlation_quantiles.sort()
        return self

    def compute_plausibility(self, X):
        """
        Compute plausibility scores in terms of similarity with the training set, in the space of Random Trees Embeddings.

        Each record's similarity to the training set is compared with the sorted training quantiles
        q_0 <= ... <= q_{K-1} (K = len(probability_levels)). The record gets probability_levels[min(j, K - 1)], where
        j is the number of quantiles <= its similarity. Hence a similarity below the smallest training similarity
        scores 0.0, and every score is a multiple of PROBABILITY_STEP in [0., 1.).
        Usually, one can consider a record to be plausible if its plausibility is above 0.05.
        NOTE(review): the returned level is the upper edge of the quantile bin, so [q_{K-2}, q_{K-1}) and
        [q_{K-1}, +inf) both map to the top level - confirm this is the intended convention.
        :param X: records to score, in the same format as the data passed to `fit`.
        :return: 1-D numpy array with one plausibility per record of X (scores are NOT averaged across the batch).
        :raises NotFittedError: if `fit` has not been called yet.
        """
        if self.random_trees is None or self.X_train_sparse_embedding is None:
            # Imported lazily since it is only needed on this error path. NotFittedError subclasses AttributeError,
            # which is what an unfitted scorer raised before, so existing handlers keep working.
            from sklearn.exceptions import NotFittedError
            raise NotFittedError("PlausibilityScorer must be fitted before computing plausibility scores")

        X_sparse_embedding = self.random_trees.transform(X)
        x_correlation_to_train = self._mean_similarity_to_train(X_sparse_embedding)
        x_correlation_to_train_bins = np.digitize(x_correlation_to_train, bins=self.x_train_correlation_quantiles)

        # Similarities >= the highest training quantile (top tail, possibly outside the training distribution)
        # yield index K, one past the last level; clip them onto the top level.
        x_correlation_to_train_bins = np.clip(x_correlation_to_train_bins, 0, len(self.probability_levels) - 1)

        return self.probability_levels[x_correlation_to_train_bins]

    def _mean_similarity_to_train(self, embedding):
        """
        Mean dot product between each row of `embedding` and each row of the stored training embedding.

        Same as ``embedding.dot(X_train_sparse_embedding.T).mean(axis=1)`` up to float rounding, but uses
        sum_j <e, t_j> = <e, sum_j t_j> so the (n_records x n_train) product matrix is never built.
        :param embedding: dense or scipy-sparse 2-D embedding of the records to compare.
        :return: 1-D float ndarray with one value per row of `embedding`.
        """
        train_embedding = self.X_train_sparse_embedding
        # sparse .sum(axis=0) returns a (1, n) np.matrix while dense returns a 1-D array: normalize to 1-D
        train_column_sums = np.asarray(train_embedding.sum(axis=0)).ravel()
        return np.asarray(embedding.dot(train_column_sums)).ravel() / train_embedding.shape[0]
