from sklearn.base import BaseEstimator
from pymc_marketing import clv
import pandas as pd
import arviz as az


class LifetimesCLVModel(BaseEstimator):
    """Scikit-learn style estimator predicting customer lifetime value (CLV).

    Wraps a ``pymc_marketing.clv.BetaGeoModel`` (transaction model) and a
    ``pymc_marketing.clv.GammaGammaModel`` (spend model). The input ``X`` is
    a 2-D positional array; ``column_labels`` gives the name of each column
    so the four required features can be located by name.

    Parameters
    ----------
    frequency_col, t_col, recency_col, monetary_col : str
        Names (as they appear in ``column_labels``) of the columns mapped to
        the ``frequency``, ``T``, ``recency`` and ``mean_transaction_value``
        model inputs respectively.
    column_labels : sequence of str or None
        Names of the columns of ``X``, in positional order. Must be set
        (here or via ``set_column_labels``) before ``fit``/``predict``.
    forward_window : int
        Prediction horizon passed as ``time`` with ``freq="M"`` (months).
    discount_rate : float
        Monthly discount rate (0.01 is roughly 12.7% annually).
    lifetime_activation : bool
        When False, ``fit`` trains nothing and ``predict`` returns zeros.
    """

    def __init__(
        self,
        frequency_col="frequency",
        t_col="customer_age",
        recency_col="recency",
        monetary_col="monetary_value",
        column_labels=None,
        forward_window=12,
        discount_rate=0.01,
        lifetime_activation=True,
    ):
        self.frequency_col = frequency_col
        self.t_col = t_col
        self.recency_col = recency_col
        self.monetary_col = monetary_col

        self.column_labels = column_labels
        self.forward_window = forward_window  # month
        self.discount_rate = discount_rate  # monthly discount rate ~ 12.7% annually
        self.lifetime_activation = lifetime_activation

        # Underlying models; populated by `fit` when `lifetime_activation`.
        self.bgf = None
        self.ggf = None

    def get_relevant_columns(self, X, columns_name):
        """Return the column of ``X`` whose label is ``columns_name``.

        Raises
        ------
        ValueError
            If ``column_labels`` has not been set, or if ``columns_name`` is
            not one of the labels.
        """
        if self.column_labels is None:
            raise ValueError(
                "column_labels is not set; pass it to the constructor or "
                "call set_column_labels() before fitting or predicting."
            )
        # `list(...)` lets any iterable of labels work, not only lists/tuples
        # (e.g. numpy arrays have no `.index` method).
        column_index = list(self.column_labels).index(columns_name)
        # Retrieve the corresponding data column
        return X[:, column_index]

    def build_input_df(self, X):
        """Build the float DataFrame consumed by the CLV models.

        The positional row number of ``X`` is exposed as ``customer_id``.
        """
        data = pd.DataFrame(
            {
                "frequency": self.get_relevant_columns(X, self.frequency_col),
                "T": self.get_relevant_columns(X, self.t_col),
                "recency": self.get_relevant_columns(X, self.recency_col),
                "mean_transaction_value": self.get_relevant_columns(
                    X, self.monetary_col
                ),
            },
            dtype="float",
        )
        data = data.reset_index()
        data = data.rename(columns={"index": "customer_id"})
        return data

    def fit(self, X, y=None):
        """Fit the transaction and spend models on ``X``.

        ``y`` is accepted for scikit-learn API compatibility and ignored.
        Nothing is trained when ``lifetime_activation`` is False.

        Returns
        -------
        self
            The estimator itself, as the scikit-learn API requires.
        """
        if self.lifetime_activation:
            data = self.build_input_df(X)
            # The spend model is fitted only on customers with at least one
            # recorded repeat transaction (frequency > 0).
            nonzero_data = data[data["frequency"] > 0]

            self.bgf = clv.BetaGeoModel(data)
            self.bgf.fit(fit_method="map")
            self.ggf = clv.GammaGammaModel(nonzero_data)
            self.ggf.fit()
        return self

    def set_column_labels(self, column_labels):
        """Set the names of the columns of ``X``, in positional order."""
        # in order to preserve the attribute `column_labels` when cloning
        # the estimator, we have declared it as a keyword argument in the
        # `__init__` and set it there
        print(f"set_column_labels: {column_labels}")
        self.column_labels = column_labels

    def predict(self, X):
        """Return the predicted CLV for each row of ``X`` as an array.

        When ``lifetime_activation`` is False every prediction is 0.

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If ``lifetime_activation`` is True and ``fit`` has not been
            called yet.
        """
        data = self.build_input_df(X)
        if self.lifetime_activation:
            if self.bgf is None or self.ggf is None:
                # NotFittedError subclasses both ValueError and
                # AttributeError, so handlers written for the previous
                # AttributeError on `None` keep working.
                from sklearn.exceptions import NotFittedError

                raise NotFittedError(
                    "This LifetimesCLVModel instance is not fitted yet. "
                    "Call 'fit' before using 'predict'."
                )
            expected_clv = self.ggf.expected_customer_lifetime_value(
                transaction_model=self.bgf,
                customer_id=data["customer_id"],
                mean_transaction_value=data["mean_transaction_value"],
                frequency=data["frequency"],
                recency=data["recency"],
                T=data["T"],
                time=self.forward_window,
                freq="M",
                discount_rate=self.discount_rate,
            )
            # Use the mean reported by the arviz summary as the point
            # prediction for each customer.
            stats = az.summary(expected_clv, kind="stats")
            data["predicted_clv"] = stats["mean"].values
        else:
            data["predicted_clv"] = 0
        return data["predicted_clv"].values

    def fit_predict(self, X, y=None):
        """Fit on ``X`` (``y`` is ignored) and return predictions for ``X``."""
        self.fit(X, y)
        return self.predict(X)
