import pandas as pd
from lifelines import KaplanMeierFitter, WeibullFitter, NelsonAalenFitter, PiecewiseExponentialFitter
from bisect import bisect_left
from recipes.base_recipe import BaseRecipe

class UnivariateModelParams():
    """User-selected settings for a univariate survival model.

    The constructor stores the values as given; call ``check()`` to validate
    them.

    Args:
        model_type: name of the estimator to fit. ``UnivariateModel.get_model``
            recognises "kaplan_meier", "weibull", "nelson_aalen" and
            "piecewise_exponential".
        function_to_output: name of the fitted function to tabulate.
            ``UnivariateModel`` recognises "survival_function",
            "cumulative_density", "density", "hazard" and "cumulative_hazard".
        confidence_interval_percentage: confidence level expressed as a
            percentage (95 means a 95% interval).
        breakpoints: breakpoints handed to the piecewise exponential fitter;
            not read for the other model types.
    """

    def __init__(self,
                 model_type="kaplan_meier",
                 function_to_output="survival_function",
                 confidence_interval_percentage=95,
                 breakpoints=None):

        self.model_type = model_type
        self.function_to_output = function_to_output
        self.confidence_interval_percentage = confidence_interval_percentage
        self.breakpoints = breakpoints

    def check(self):
        """Validate the stored parameters.

        Raises:
            ValueError: if the confidence interval percentage is not a number
                within [0, 100] (NaN included).
        """
        percentage = self.confidence_interval_percentage
        # Written as a positive range test so that NaN, for which every
        # comparison is False, is rejected instead of slipping through.
        if not (0 <= percentage <= 100):
            raise ValueError("Confidence interval level percentage must be between 0 and 100")
        
        
class UnivariateModel(BaseRecipe):
    """Fit a univariate lifelines model, optionally per group, and tabulate it.

    The estimator and the function to tabulate are selected by the
    ``UnivariateModelParams`` instance given at construction time.
    """

    # function_to_output -> (fitted-model attribute holding the function,
    #                        fitted-model attribute holding its confidence intervals)
    _FUNCTION_ATTRIBUTES = {
        "survival_function": ("survival_function_", "confidence_interval_survival_function_"),
        "cumulative_density": ("cumulative_density_", "confidence_interval_cumulative_density_"),
        "density": ("density_", "confidence_interval_density_"),
        "hazard": ("hazard_", "confidence_interval_hazard_"),
        "cumulative_hazard": ("cumulative_hazard_", "confidence_interval_cumulative_hazard_"),
    }

    def __init__(self, params=None):
        """Store and validate the parameters.

        Args:
            params: a ``UnivariateModelParams`` instance.

        Raises:
            ValueError: if ``params`` is None or fails ``params.check()``.
        """
        if params is None:
            raise ValueError('UnivariateModelParams not specified.')
        self.params = params
        self.params.check()
        # Most recently fitted lifelines model; overwritten by every get_group_df call.
        self.fitted_model = None
        self.confidence_interval_alpha = self.get_confidence_interval_alpha()

    def get_confidence_interval_alpha(self):
        """Return the alpha matching the confidence percentage (95 gives roughly 0.05)."""
        return 1 - self.params.confidence_interval_percentage / 100

    def get_function_and_confidence_intervals(self):
        """Return the selected function and its confidence intervals, rounded.

        Reads the attributes of ``self.fitted_model`` that correspond to
        ``params.function_to_output``.

        Returns:
            A ``(function, confidence_intervals)`` pair of DataFrames, rounded
            to 7 decimals for "density" and 3 decimals otherwise.

        Raises:
            ValueError: if ``params.function_to_output`` is not a known name.
        """
        selected = self.params.function_to_output
        if selected not in self._FUNCTION_ATTRIBUTES:
            raise ValueError("The selected function to output could not be found")

        function_attribute, intervals_attribute = self._FUNCTION_ATTRIBUTES[selected]
        function = getattr(self.fitted_model, function_attribute)
        confidence_intervals = getattr(self.fitted_model, intervals_attribute)

        if selected == "cumulative_density":
            # Per the original author's note, these bounds are derived as
            # 1 - (survival bounds), which leaves the lower/upper labels
            # swapped, so the labels are swapped back. This is done on a copy
            # so the frame held by the fitted model is not relabelled in place
            # (a second call would otherwise swap the labels again).
            # NOTE(review): confirm the swap is still needed for the installed
            # lifelines version and for parametric fitters, which may already
            # return correctly ordered bounds.
            confidence_intervals = confidence_intervals.copy()
            confidence_intervals.columns = list(reversed(confidence_intervals.columns))

        # densities can be very small, so they keep more decimals
        number_decimals = 7 if selected == "density" else 3
        return function.round(number_decimals), confidence_intervals.round(number_decimals)

    def get_model(self):
        """Return the unfitted lifelines model object selected by the user.

        Raises:
            ValueError: if ``params.model_type`` is not a known name, or if the
                piecewise exponential model is selected without breakpoints.
        """
        alpha = self.confidence_interval_alpha
        model_type = self.params.model_type

        if model_type == "kaplan_meier":
            return KaplanMeierFitter(alpha=alpha)
        if model_type == "weibull":
            return WeibullFitter(alpha=alpha)
        if model_type == "nelson_aalen":
            return NelsonAalenFitter(alpha=alpha, nelson_aalen_smoothing=False)
        if model_type == "piecewise_exponential":
            if self.params.breakpoints is None:
                # fail here with a clear message rather than inside the fitter
                raise ValueError("Breakpoints must be specified for the piecewise exponential model")
            return PiecewiseExponentialFitter(self.params.breakpoints, alpha=alpha)
        raise ValueError("Model type could not be found")

    def get_group_df(self, group, group_label, duration_column, event_indicator_column):
        """Fit the univariate model on one group and tabulate the result.

        The fitted model is also stored in ``self.fitted_model``.

        Args:
            group: DataFrame holding the rows of the group.
            group_label: value written to the ``BaseRecipe.GROUP_BY_COLUMN_NAME`` column.
            duration_column: name of the column holding the durations.
            event_indicator_column: name of the column holding the event indicator.

        Returns:
            A DataFrame with the timeline in a column named "index", the
            selected function, its confidence intervals and the group label.
            Weibull fits add "failure_type" and "Weibull Shape Parameter";
            piecewise exponential fits add "fitted parameter".
        """
        durations = group[duration_column]
        event_indicator = group[event_indicator_column]

        model = self.get_model()
        self.fitted_model = model.fit(
            durations=durations,
            event_observed=event_indicator,
            label=self.params.function_to_output,
        )

        function, confidence_intervals = self.get_function_and_confidence_intervals()

        # reset_index() only names the new column "index" when the index is
        # unnamed; clear the name so the code below (and get_output_df) can
        # rely on that column name.
        group_df = pd.concat([function, confidence_intervals], axis=1).rename_axis(None).reset_index()
        group_df[BaseRecipe.GROUP_BY_COLUMN_NAME] = group_label

        if self.params.model_type == "weibull":
            rho = round(self.fitted_model.rho_, 1)
            group_df["failure_type"] = self.get_weibull_failure_type(rho)
            group_df["Weibull Shape Parameter"] = rho

        elif self.params.model_type == "piecewise_exponential":
            # bisect requires sorted input; sort a copy so that breakpoints
            # supplied out of order still map each time to the right interval.
            breakpoints = sorted(self.params.breakpoints)
            fitted_model = self.fitted_model

            def fitted_parameter_at(time):
                # interval i of the timeline is described by the attribute lambda_i_
                return getattr(fitted_model, "lambda_%d_" % bisect_left(breakpoints, time))

            group_df["fitted parameter"] = group_df["index"].apply(fitted_parameter_at).round(3)

        return group_df

    def get_weibull_failure_type(self, rho):
        """Classify a Weibull shape parameter.

        Returns:
            "Early-life failures" if ``rho < 1``, "Random failures" if
            ``rho == 1``, "Wear-out failures" if ``rho > 1``, and None when no
            comparison holds (e.g. NaN).
        """
        if rho < 1:
            return "Early-life failures"
        if rho == 1:
            return "Random failures"
        if rho > 1:
            return "Wear-out failures"
        return None

    def get_output_df(self, df, duration_column, event_indicator_column, groupby_columns):
        """Fit the model on the whole dataset or on each group and stack the results.

        Args:
            df: input DataFrame.
            duration_column: name of the column holding the durations; also
                the name given to the timeline column of the output.
            event_indicator_column: name of the column holding the event indicator.
            groupby_columns: columns to group by; empty (or None) fits a single
                model on the whole dataset.

        Returns:
            A DataFrame stacking the per-group tables, timeline rounded to 3 decimals.

        Raises:
            ValueError: if no model could be fitted because every group has
                fewer than two rows.
        """
        self.check_duration_column(df, duration_column)
        self.check_event_indicator_column(df, event_indicator_column)
        # NOTE(review): the return value is discarded, so this only has an
        # effect if the helper modifies df in place - confirm against BaseRecipe.
        self.get_processed_data(df, event_indicator_column)

        fitted_models_on_groups_list = []

        # if no groupby columns are selected, the output is one curve for the whole dataset
        if groupby_columns is None or len(groupby_columns) == 0:
            fitted_models_on_groups_list.append(
                self.get_group_df(df, BaseRecipe.NO_GROUP_NAME, duration_column, event_indicator_column)
            )
        else:
            for groupby_values, group_df in df.groupby(groupby_columns, dropna=False):
                # if less than two rows in the group, we skip the estimation
                if len(group_df) < 2:
                    continue
                group_label = self.get_group_label(groupby_columns, groupby_values)
                fitted_models_on_groups_list.append(
                    self.get_group_df(group_df, group_label, duration_column, event_indicator_column)
                )

        if not fitted_models_on_groups_list:
            # pd.concat would raise an opaque "No objects to concatenate" ValueError
            raise ValueError("No model could be fitted: every group contains fewer than two rows")

        output_df = pd.concat(fitted_models_on_groups_list, axis=0, ignore_index=True)
        output_df = output_df.rename(columns={"index": duration_column})
        # round timeline (most useful for weibull: interpolation can create time values such as 20.33333333)
        output_df[duration_column] = output_df[duration_column].round(3)
        return output_df