from sklearn.base import BaseEstimator
import pandas as pd


class QuantileClusteringModel(BaseEstimator):
    """Assign each value of one column to a quantile-based cluster.

    ``fit`` learns a sorted list of ``pd.Interval`` bins from the target
    column using ``pd.qcut``; ``predict`` maps each value to the position of
    the bin that contains it.

    Parameters
    ----------
    target_column : str, default="value"
        Label of the column to cluster. It is looked up in ``column_labels``
        to find the positional index of the column inside ``X``.
    column_labels : list of str, optional
        Labels of the columns of ``X``, in positional order. Must be set
        (here or through ``set_column_labels``) before ``fit``/``predict``.
    n_clusters : int, default=5
        Number of clusters (bins) to produce.
    extend_low : bool, default=False
        Only used when the first binning attempt fails. If true, the
        strictly positive values are split into ``n_clusters`` bins and the
        lowest bin is stretched down to 0. If false, they are split into
        ``n_clusters - 1`` bins and a dedicated ``[0, 0]`` bin is prepended.

    Attributes
    ----------
    quantiles : list of pd.Interval or None
        Learned bins sorted by increasing bounds; ``None`` until ``fit`` has
        been called.
    """

    def __init__(self, target_column="value", column_labels=None, n_clusters=5, extend_low=False):
        self.target_column = target_column
        self.n_clusters = n_clusters
        self.column_labels = column_labels
        self.quantiles = None
        self.extend_low = extend_low

    def get_relevant_columns(self, X, columns_name):
        """Return the column of ``X`` whose label is ``columns_name``.

        ``X`` must support two-dimensional slicing (``X[:, i]``). A
        ``ValueError`` is raised by ``list.index`` when the label is not
        present in ``column_labels``.
        """
        # Retrieve the index of the important column
        column_index = self.column_labels.index(columns_name)
        # Retrieve the corresponding data column
        return X[:, column_index]

    @staticmethod
    def _quantile_intervals(values, n_bins):
        """Return the ``n_bins`` quantile intervals of ``values``, sorted."""
        # `.cat.categories` lists every bin in increasing order, including
        # bins that received no value. Using `.unique()` instead would drop
        # empty bins (leaving gaps between intervals) and would yield NaN
        # when `values` contains missing data.
        return pd.qcut(values, n_bins).cat.categories.tolist()

    def fit(self, X):
        """Learn the quantile intervals of the target column.

        The target column is first split into ``n_clusters`` quantile bins.
        If ``pd.qcut`` raises ``ValueError`` (typically because many equal
        values produce duplicate bin edges), only the strictly positive
        values are binned and the result is adjusted according to
        ``extend_low``. In every case the first interval ends up starting at
        0 with a closed left bound.

        Parameters
        ----------
        X : array-like supporting ``X[:, i]``
            Input data containing the target column.

        Returns
        -------
        self : QuantileClusteringModel
            The fitted estimator.
        """
        target = pd.Series(self.get_relevant_columns(X, self.target_column))
        try:
            quantiles = self._quantile_intervals(target, self.n_clusters)
        except ValueError:
            positive = target[target > 0]
            if self.extend_low:
                quantiles = self._quantile_intervals(positive, self.n_clusters)
            else:
                quantiles = self._quantile_intervals(positive, self.n_clusters - 1)
                # The lowest positive bin starts just above 0 ...
                quantiles[0] = pd.Interval(
                    left=0, right=quantiles[0].right, closed="right"
                )
                # ... and a dedicated bin holds the exact zeros.
                quantiles.insert(0, pd.Interval(left=0.0, right=0, closed="both"))

        # Make the first interval start at 0 (inclusive). This leaves the
        # dedicated [0, 0] bin inserted above unchanged.
        quantiles[0] = pd.Interval(
            left=0.0, right=quantiles[0].right, closed="both"
        )
        self.quantiles = quantiles
        return self

    def set_column_labels(self, column_labels):
        """Set the labels used to locate columns of ``X`` by name."""
        # `column_labels` is also declared as a keyword argument of
        # `__init__` so that the attribute is preserved when the estimator
        # is cloned.
        print(f"set_column_labels: {column_labels}")
        self.column_labels = column_labels

    def get_quantile_number(self, value):
        """Return the index of the learned interval that contains ``value``.

        Values below the lower bound of the first interval are assigned to
        cluster 0. Any other value that is in no interval (in particular a
        value above the last interval) is assigned to the last cluster, so
        the last interval effectively has an infinite right limit.
        """
        # The fall-through below is meant for values above the learned
        # range; without this guard, values below the range would end up in
        # the highest cluster as well.
        if value < self.quantiles[0].left:
            return 0
        for cluster_number, interval in enumerate(self.quantiles):
            if value in interval:
                return cluster_number
        # Clusters are numbered from 0, so the last one is len - 1.
        return len(self.quantiles) - 1

    def predict(self, X):
        """Return the cluster index of every row of ``X`` as a list of ints."""
        target = pd.Series(self.get_relevant_columns(X, self.target_column))

        return target.apply(self.get_quantile_number).tolist()

    def fit_predict(self, X):
        """Fit on ``X`` and return the cluster index of each of its rows."""
        self.fit(X)
        return self.predict(X)
