import numpy as np
from abc import ABC, abstractmethod
class MultivariateModel(ABC):
    """Abstract base class for multivariate survival models.

    Concrete subclasses (the method docstrings below mention ``coxph`` and
    ``rsf``) implement fitting and the three prediction flavours. This base
    class provides the ``predict`` dispatch on ``prediction_type`` and helpers
    to select columns of the feature matrix by label.
    """

    # Not referenced inside this class; presumably the label used for the
    # duration column by subclasses/callers -- TODO confirm.
    DURATION_COLUMN_NAME = "DURATION_COLUMN"

    def __init__(self, training_dataset, event_indicator_column, prediction_type, survival_quantile, time_for_proba, column_labels):
        """Store the configuration; no fitting or validation happens here.

        Args:
            training_dataset: stored as-is; not used by this base class.
            event_indicator_column: label of the event indicator column,
                looked up in ``column_labels`` by
                ``check_event_indicator_values``.
            prediction_type: one of ``"predict_expected_time"``,
                ``"predict_time_at_proba"`` or ``"predict_proba_at_time"``
                (see ``predict``).
            survival_quantile: stored as-is; presumably the probability used
                by ``get_times_at_probability`` -- TODO confirm in subclasses.
            time_for_proba: stored as-is; presumably the time used by
                ``get_probabilities_at_time`` -- TODO confirm in subclasses.
            column_labels: labels of the columns of ``X``, in column order.
        """
        self.training_dataset = training_dataset
        self.event_indicator_column = event_indicator_column
        self.prediction_type = prediction_type
        self.survival_quantile = survival_quantile
        self.time_for_proba = time_for_proba
        # Placeholder until a model is fitted (never assigned in this class).
        self.fitted_model = None
        self.column_labels = column_labels
        # Filled in by ``set_max_duration``.
        self.max_duration = None

    @abstractmethod
    def get_probabilities_at_time(self):
        """
        Predict the probability of survival at given time
        """
        pass

    @abstractmethod
    def get_times_at_probability(self):
        """
        Predict the first time at which the survival function goes below given probability
        """
        pass

    @abstractmethod
    def get_expected_time(self):
        """
        Predict the average expected time of survival
        """
        pass

    def check_event_indicator_values(self, X):
        """
        Check that event indicator values are all 0s and 1s.

        Raises:
            ValueError: if the event indicator column of ``X`` contains any
                value other than 0 or 1, or if that column is not listed in
                ``column_labels`` (raised by ``get_columns``).
        """
        event_indicator_values, _ = self.get_columns(X, [self.event_indicator_column])
        is_zero_or_one = (event_indicator_values == 0) | (event_indicator_values == 1)
        if not np.all(is_zero_or_one):
            raise ValueError("Event indicator column has values different from 0 and 1. Please check that 'No rescaling' is selected for that column in the feature handling tab")

    @abstractmethod
    def set_prediction_object(self, X):
        """
        - for coxph, set the prepared covariates used for prediction
        - for rsf, set the predicted survival functions
        """
        pass

    @abstractmethod
    def process_predictions(self, predictions):
        """
        replace np.inf by max value of time in dataset
        """
        pass

    def predict(self, X):
        """Return predictions for ``X`` according to ``prediction_type``.

        The prediction object is prepared first, then the getter matching
        ``prediction_type`` is called and its result is passed through
        ``process_predictions``.

        Raises:
            ValueError: if ``prediction_type`` is not one of the three
                supported values.
        """
        # Prepared before dispatching: the getters take no arguments and rely
        # on state set by the subclass here.
        self.set_prediction_object(X)
        if self.prediction_type == "predict_expected_time":
            predictions = self.get_expected_time()
        elif self.prediction_type == "predict_time_at_proba":
            predictions = self.get_times_at_probability()
        elif self.prediction_type == "predict_proba_at_time":
            predictions = self.get_probabilities_at_time()
        else:
            raise ValueError("predict type not found")
        return self.process_predictions(predictions)

    @abstractmethod
    def fit(self, X, y):
        """
        fit model
        """
        pass

    def set_max_duration(self, y):
        """
        Store the maximum value of ``y`` in ``max_duration``.

        ``y`` is presumably the duration column of the dataset -- TODO confirm
        against callers.
        """
        self.max_duration = np.amax(y)

    def set_column_labels(self, column_labels):
        """Replace the stored column labels."""
        # in order to preserve the attribute `column_labels` when cloning
        # the estimator, it is also accepted by `__init__` and set there
        self.column_labels = column_labels

    def get_columns(self, X, important_columns):
        """
        Return the values and positions of the columns selected by name.

        Args:
            X: 2-D array supporting ``X[:, indices]`` indexing, whose columns
                are ordered like ``column_labels``.
            important_columns: column names to extract; ``None`` is treated
                as an empty selection.

        Returns:
            A ``(column_values, column_indices)`` tuple. For an empty
            selection both are empty lists; otherwise ``column_indices`` is
            the list of positions of the requested names (in request order)
            and ``column_values`` is ``X[:, column_indices]``.

        Raises:
            ValueError: if a requested name is not in ``column_labels``.
        """
        if important_columns is None or len(important_columns) == 0:
            return [], []

        # Snapshot as a list: label containers such as NumPy arrays support
        # `in` but have no `.index`, which previously made the lookup fail
        # after validation had succeeded.
        labels = list(self.column_labels)
        for important_column in important_columns:
            if important_column not in labels:
                raise ValueError(
                    f'The column name provided: [{important_column}], is not present in the list of columns from the dataset. Please check that the column is selected in feature handling.')
        column_indices = [labels.index(important_column) for important_column in important_columns]
        column_values = X[:, column_indices]

        return column_values, column_indices