from sklearn.base import BaseEstimator
from lab.multivariate_model import MultivariateModel
import xgboost as xgb
import numpy as np
import pandas as pd

class ClassXGBoostSurvival(BaseEstimator, MultivariateModel):
    """Survival regressor using XGBoost's accelerated-failure-time objective.

    The event indicator is supplied as one of the columns of ``X`` (named by
    ``event_indicator_column``). During ``fit`` it is used to build the
    interval-censored labels required by ``survival:aft``; it is then dropped
    from the feature matrix, both in ``fit`` and in ``predict``.

    ``predict`` returns the booster's predicted time, capped by
    ``process_predictions`` at ``self.max_duration``. ``max_duration`` is
    presumably set by ``MultivariateModel.set_max_duration`` (called in
    ``fit``); confirm against that base class.
    """

    def __init__(self,
                 training_dataset,
                 event_indicator_column,
                 n_estimators=100,
                 learning_rate=0.05,
                 max_depth=6,
                 aft_loss_distribution="normal",
                 aft_loss_distribution_scale=1.2,
                 advanced_parameters=False,
                 reg_lambda=1.0,
                 reg_alpha=0.0,
                 colsample_bytree=1.0,
                 column_labels=None,
                 **kwargs):
        """Store hyper-parameters verbatim.

        No logic runs here (scikit-learn convention), so that ``get_params``,
        ``set_params`` and ``clone`` work. Extra ``**kwargs`` are accepted for
        caller compatibility but are not used by this model.
        """
        self.training_dataset = training_dataset
        self.event_indicator_column = event_indicator_column
        self.column_labels = column_labels
        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.max_depth = max_depth
        self.aft_loss_distribution = aft_loss_distribution
        self.aft_loss_distribution_scale = aft_loss_distribution_scale
        self.advanced_parameters = advanced_parameters
        self.reg_lambda = reg_lambda
        self.reg_alpha = reg_alpha
        self.colsample_bytree = colsample_bytree

    def _prep_features(self, X_df, categories=None):
        """Return a copy of ``X_df`` containing only numeric / category dtypes.

        * Datetime columns (tz-naive or tz-aware, any resolution) become whole
          seconds since the Unix epoch, floored. ``NaT`` becomes ``NaN`` so
          that XGBoost treats it as missing.
        * ``object`` columns become ``category``.
        * If ``categories`` (a mapping of column -> categories) is given, those
          columns are re-encoded against exactly those categories. The integer
          codes XGBoost reads then match the ones used at training time.
          Values absent from the mapping's categories become missing.
        """
        Xc = X_df.copy()

        for col in Xc.select_dtypes(include=["datetime", "datetimetz"]).columns:
            values = Xc[col]
            if values.dt.tz is not None:
                epoch = pd.Timestamp("1970-01-01", tz="UTC")
            else:
                epoch = pd.Timestamp("1970-01-01")
            # Floor-dividing the offset by one second is independent of the
            # column's storage unit and maps NaT to NaN instead of raising.
            Xc[col] = (values - epoch) // pd.Timedelta(seconds=1)

        for col in Xc.select_dtypes(include=["object"]).columns:
            Xc[col] = Xc[col].astype("category")

        if categories:
            for col, cats in categories.items():
                if col in Xc.columns:
                    # Recode by value onto the training categories.
                    Xc[col] = Xc[col].astype(pd.CategoricalDtype(categories=cats))

        return Xc

    def fit(self, X, y):
        """Train the AFT booster.

        Parameters
        ----------
        X : array-like or DataFrame
            Features *including* the event-indicator column, labelled with
            ``column_labels``.
        y : array-like
            Observed durations (time of event or of censoring).

        Returns
        -------
        self
        """
        self.set_max_duration(y)
        X_df = pd.DataFrame(X, columns=self.column_labels)
        evt = np.asarray(X_df[self.event_indicator_column])

        # Interval-censored labels for survival:aft. An observed event
        # (indicator == 1) is the interval [y, y]. Any other row is treated
        # as right-censored: [y, +inf). Flattening guards against a
        # column-vector y broadcasting against the 1-D event mask.
        y_lower = np.asarray(y, dtype=float).reshape(-1)
        y_upper = np.where(evt == 1, y_lower, np.inf)

        X_df = X_df.drop(columns=self.event_indicator_column)
        X_df = self._prep_features(X_df)

        # XGBoost consumes pandas category *codes*. Remember the training
        # categories so that predict encodes new data identically.
        self.categories_ = {
            col: X_df[col].cat.categories
            for col in X_df.select_dtypes(include=["category"]).columns
        }

        dtrain = xgb.DMatrix(
            X_df,
            label_lower_bound=y_lower,
            label_upper_bound=y_upper,
            enable_categorical=True
        )

        params = {
            "objective": "survival:aft",
            "eval_metric": "aft-nloglik",
            "aft_loss_distribution": self.aft_loss_distribution,
            "aft_loss_distribution_scale": self.aft_loss_distribution_scale,
            "learning_rate": self.learning_rate,
            "max_depth": self.max_depth,
            "lambda": self.reg_lambda,
            "alpha": self.reg_alpha,
            "colsample_bytree": self.colsample_bytree,
            "tree_method": "hist",
            "verbosity": 0,
        }

        self.fitted_model = xgb.train(params, dtrain, num_boost_round=self.n_estimators)
        self.feature_order_ = X_df.columns.tolist()
        return self

    def predict(self, X):
        """Predict times for ``X``, capped by ``process_predictions``.

        ``X`` must have the same layout as in ``fit``, including the
        event-indicator column, which is dropped before prediction.
        """
        X_df = pd.DataFrame(X, columns=self.column_labels)
        X_df = X_df.drop(columns=self.event_indicator_column)
        # getattr keeps models pickled before categories_ existed usable.
        X_df = self._prep_features(X_df, categories=getattr(self, "categories_", None))
        # Reorder columns to match training.
        X_df = X_df[self.feature_order_]
        dmat = xgb.DMatrix(X_df, enable_categorical=True)
        predictions = self.fitted_model.predict(dmat)
        return self.process_predictions(predictions)

    def process_predictions(self, predictions):
        """Replace any prediction above ``self.max_duration`` with ``self.max_duration``."""
        return np.where(predictions > self.max_duration, self.max_duration, predictions)

    def get_expected_time(self):
        """Not supported for XGBoost AFT; always raises ``NotImplementedError``."""
        raise NotImplementedError("Expected time prediction is not implemented for XGBoost AFT.")

    def get_times_at_probability(self):
        """Not supported for XGBoost AFT; always raises ``NotImplementedError``."""
        raise NotImplementedError("Time at probability prediction is not implemented for XGBoost AFT.")

    def get_probabilities_at_time(self):
        """Not supported for XGBoost AFT; always raises ``NotImplementedError``."""
        raise NotImplementedError("Probability at time prediction is not implemented for XGBoost AFT.")

    def set_prediction_object(self, X):
        """No-op: there is no survival-function 'prediction object' here."""
        # XGBoost AFT directly predicts a time, so this concept of a 'prediction object'
        # like a survival function doesn't directly apply in the same way.
        pass
