import dataiku
from dataiku.customrecipe import get_recipe_config, get_output_names_for_role
from recipes.io_utils import get_input_output
from recipes.recipe_config_loading import check_and_get_groupby_columns, check_duration_column_parameter, check_event_indicator_column_parameter
from recipes.survival_analysis_statistics import SurvivalAnalysisStatistics

# --- Setup
# get_input_output() hands back an (input, output) pair; only the input handle
# is used below, the output dataset is resolved again by role name when writing.
input_dataset, output_dataset = get_input_output()
recipe_config = get_recipe_config()

# Column names taken from the input schema; every parameter check below
# receives this list together with the recipe configuration.
input_dataset_columns = [field["name"] for field in input_dataset.read_schema()]

# Checks run in this order: duration column, event indicator column, then the
# group-by columns (the last check also returns the columns to group on).
check_duration_column_parameter(recipe_config, input_dataset_columns)
check_event_indicator_column_parameter(recipe_config, input_dataset_columns)
groupby_columns = check_and_get_groupby_columns(recipe_config, input_dataset_columns)

duration_column = recipe_config.get('duration_column')
event_indicator_column = recipe_config.get('event_indicator_column')

# --- Run
# Load the input as a dataframe and let SurvivalAnalysisStatistics build the
# output dataframe from it.
input_df = input_dataset.get_dataframe()
output_df = SurvivalAnalysisStatistics().get_output_df(
    input_df,
    duration_column,
    event_indicator_column,
    groupby_columns,
)

# --- Write output
# The first dataset bound to the 'output_dataset' role receives the result.
output_dataset_name = get_output_names_for_role('output_dataset')[0]
dataiku.Dataset(output_dataset_name).write_with_schema(output_df)