from sklearn.base import BaseEstimator
from sksurv.ensemble import RandomSurvivalForest
from lab.multivariate_model import MultivariateModel
import numpy as np
from sksurv.util import Surv
from lab.step_function_utils import get_first_time_below_quantile

class ClassRandomSurvivalForest(BaseEstimator, MultivariateModel):
    """Survival model backed by ``sksurv.ensemble.RandomSurvivalForest``.

    The event indicator is expected to be one of the columns of ``X`` (named by
    ``event_indicator_column``). It is removed from the features before fitting
    and predicting. During fitting it is combined with ``y`` (time of event or
    time of censoring) into the structured target that sksurv expects.

    Parameters
    ----------
    training_dataset :
        Forwarded to ``MultivariateModel.__init__``.
    nb_trees : int
        Number of trees; passed as ``n_estimators`` to the sksurv forest.
    max_tree_depth : int or None
        Passed as ``max_depth`` to the sksurv forest.
    min_samples_split, min_samples_leaf, min_weight_fraction_leaf :
        Passed unchanged to the sksurv forest.
    feature_sampling_strategy : str
        Determines ``max_features``:
        - ``"all"`` -> ``None`` (all features);
        - ``"integer"`` -> ``integer_sampling``;
        - ``"fraction"`` -> ``fraction_sampling``;
        - any other value (e.g. ``"sqrt"``) is passed through unchanged.
    integer_sampling : int
        ``max_features`` used when ``feature_sampling_strategy == "integer"``.
    fraction_sampling : float
        ``max_features`` used when ``feature_sampling_strategy == "fraction"``.
    event_indicator_column :
        Identifier of the event-indicator column in ``X``, resolved through
        ``self.get_columns``.
    prediction_type, survival_quantile, time_for_proba, column_labels :
        Forwarded to ``MultivariateModel.__init__``. ``survival_quantile`` is
        read by ``get_times_at_probability``; ``time_for_proba`` is read by
        ``get_probabilities_at_time``.
    advanced_parameters : bool
        Stored as an attribute; not read anywhere in this class.
    min_samples_leaf : int
        See above; passed to the sksurv forest.
    random_state :
        Passed as ``random_state`` to the sksurv forest.
    """

    def __init__(self,
                 training_dataset, 
                 nb_trees=100, 
                 max_tree_depth=6,
                 min_samples_split=2, 
                 feature_sampling_strategy="sqrt",
                 integer_sampling=1, 
                 fraction_sampling=0.5, 
                 event_indicator_column=None, 
                 prediction_type="predict_time_at_proba",
                 survival_quantile=0.7, 
                 time_for_proba=None, 
                 advanced_parameters=False, 
                 min_samples_leaf=3, 
                 min_weight_fraction_leaf=0,
                 random_state=None,
                 column_labels=None):
        # scikit-learn's BaseEstimator.get_params/clone read each __init__
        # parameter back from an attribute of the same name, so every
        # hyper-parameter is stored verbatim without transformation.
        self.nb_trees = nb_trees
        self.max_tree_depth = max_tree_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.feature_sampling_strategy = feature_sampling_strategy
        self.integer_sampling = integer_sampling
        self.fraction_sampling = fraction_sampling
        self.advanced_parameters = advanced_parameters
        self.random_state = random_state
        self.survival_functions = None # list of step functions (defined as sksurv.functions.StepFunction)

        # NOTE(review): the remaining parameters are presumably stored under
        # the same names by MultivariateModel.__init__ — confirm, otherwise
        # get_params() will fail for them.
        MultivariateModel.__init__(self, training_dataset, event_indicator_column, prediction_type, survival_quantile, time_for_proba, column_labels)

    def get_expected_time(self):
        """Expected survival time is not supported by this model.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not support expected-time predictions"
        )

    def get_times_at_probability(self):
        """Return, for each predicted survival curve, the first time it drops
        below ``self.survival_quantile``.

        Uses the curves stored by ``set_prediction_object``, delegating to
        ``get_first_time_below_quantile``. Curves that never drop below the
        quantile presumably yield ``np.inf``, which ``process_predictions``
        replaces. Confirm this against that helper.

        Returns
        -------
        numpy.ndarray
            One time per stored survival curve.
        """
        return np.array([
            get_first_time_below_quantile(survival_function, self.survival_quantile)
            for survival_function in self.survival_functions
        ])

    def get_probabilities_at_time(self):
        """Return each predicted curve's survival probability at
        ``self.time_for_proba``.

        NOTE(review): sksurv step functions may raise ``ValueError`` when the
        requested time lies outside their domain. Confirm the expected range
        of ``time_for_proba``.

        Returns
        -------
        numpy.ndarray
            One probability per stored survival curve.
        """
        return np.array([
            survival_function(self.time_for_proba)
            for survival_function in self.survival_functions
        ])

    def set_prediction_object(self, X):
        """Predict and store survival functions for the rows of ``X``.

        The event-indicator column is dropped first, so the fitted forest sees
        the same feature layout as during ``fit``.
        """
        _, event_indicator_index = self.get_columns(X, [self.event_indicator_column])
        X = np.delete(X, event_indicator_index, axis=1)
        self.survival_functions = self.fitted_model.predict_survival_function(X)

    def process_predictions(self, predictions):
        """Replace infinite predicted times with ``self.max_duration``.

        The input is converted to an array first. Without this, a plain list
        compared against ``np.inf`` yields a single ``False`` and no value
        would be replaced.

        Returns
        -------
        numpy.ndarray
            Predictions with ``+inf`` entries replaced by ``self.max_duration``.
        """
        predictions = np.asarray(predictions)
        return np.where(predictions == np.inf, self.max_duration, predictions)

    def get_sksurv_transformed_X_y(self, X, y):
        """
        construct and return
        - transformed X array with deleted event_indicator column
        - transformed y as a structured array containing
            - binary event indicator as first field
            - time of event or time of censoring as second field.
        """
        event_indicator_values, event_indicator_index = self.get_columns(X, [self.event_indicator_column])
        X = np.delete(X, event_indicator_index, axis=1)

        # get_columns returns a 2-D block (presumably n_samples x 1). Flatten
        # it row-major, as the original nested-list comprehension did.
        events = np.asarray(event_indicator_values).ravel()
        # Accept any array-like target (list, Series, 1-D or column vector).
        times = np.asarray(y).ravel()
        return X, Surv.from_arrays(events, times)

    def get_processed_feature_sampling_strategy(self):
        """Translate ``feature_sampling_strategy`` into sksurv's ``max_features``.

        Returns
        -------
        None, int, float or str
            ``None`` for ``"all"``, ``integer_sampling`` for ``"integer"``,
            ``fraction_sampling`` for ``"fraction"``; otherwise the strategy
            value itself (e.g. ``"sqrt"``).
        """
        strategy = self.feature_sampling_strategy
        if strategy == 'all':
            return None
        if strategy == 'integer':
            return self.integer_sampling
        if strategy == 'fraction':
            return self.fraction_sampling
        return strategy

    def fit(self, X, y):
        """Fit the random survival forest.

        Parameters
        ----------
        X :
            Feature matrix that still contains the event-indicator column.
        y :
            Time of event or time of censoring for each row of ``X``.

        Returns
        -------
        self
            Following the scikit-learn estimator convention, so calls can be
            chained.
        """
        self.check_event_indicator_values(X)
        self.set_max_duration(y)

        model = RandomSurvivalForest(n_estimators=self.nb_trees,
                                     max_depth=self.max_tree_depth,
                                     min_samples_split=self.min_samples_split,
                                     min_samples_leaf=self.min_samples_leaf,
                                     min_weight_fraction_leaf=self.min_weight_fraction_leaf,
                                     max_features=self.get_processed_feature_sampling_strategy(),
                                     random_state=self.random_state)
        transfo_X, transfo_y = self.get_sksurv_transformed_X_y(X, y)
        model.fit(transfo_X, transfo_y)
        self.fitted_model = model
        return self

    
    
    
    

    
    