# -------------------------------------------------------------------------------- NOTEBOOK-CELL: MARKDOWN
# # Covariates analysis

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
from dataiku import insights
import pandas as pd, numpy as np
from dataiku import pandasutils as pdu
import matplotlib.pyplot as plt
import pandas as pd, numpy as np
from dataiku import pandasutils as pdu
from sklearn.model_selection import train_test_split
from sklearn.metrics import brier_score_loss
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from lifelines.utils import concordance_index
from lifelines import CoxPHFitter
import warnings
import plotly
import plotly.graph_objects as go
import plotly.subplots as sp

# Silence every warning category for the rest of the notebook.
warnings.filterwarnings("ignore")

# Load the saved model and pull the fitted survival model out of its predictor.
# presumably a lifelines CoxPHFitter, given the import above — TODO confirm
model_1 = dataiku.Model("FjRNGzZC")
pred_1 = model_1.get_predictor()
cph = pred_1._clf.fitted_model  # NOTE(review): reaches through the private `_clf` attribute
# NOTE(review): `df_covariates` is reassigned from `cph.summary` further down
# before anything reads this frame.
df_covariates = pd.DataFrame(cph.hazard_ratios_).reset_index()

# Read the prepared dataset into a DataFrame
df = dataiku.Dataset("periods_with_covariates_prepared").get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def interpret_p_value(p_value):
    """Translate a p-value into a significance label.

    The bounds are exclusive and checked from strictest to loosest:
    below 0.01, below 0.05, below 0.3. Any value that is not below one of
    them is labelled "Not significant".
    """
    significance_bands = (
        (0.01, "Highly significant"),
        (0.05, "Significant"),
        (0.3, "Moderately significant"),
    )
    for upper_bound, label in significance_bands:
        if p_value < upper_bound:
            return label
    return "Not significant"

# Keep the per-covariate columns of the model summary and derive three
# reporting columns from them in a single chained expression.
df_covariates = (
    cph.summary.reset_index()[["covariate", "coef", "exp(coef)", "p"]]
    .assign(
        # Human-readable label for the p-value
        statistical_validity=lambda frame: frame['p'].apply(interpret_p_value),
        # Hazard ratio rounded to two decimals
        risk_multiplier=lambda frame: frame['exp(coef)'].round(2),
        # True when the rounded hazard ratio is below 1
        reduces_risk=lambda frame: frame['exp(coef)'].round(2) < 1,
    )
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Map every model covariate back to the dataset column it was derived from.
# Covariates named 'dummy:<column>:<modality>' are treated as one-hot encoded
# categorical columns; every other covariate is treated as numerical.
from collections import defaultdict

modalities = defaultdict(list)  # categorical column -> modalities that have a covariate
numericals = set()              # names of the non-dummy covariates
categoricals = set()            # names of the columns behind dummy covariates
unique_orig_cols = []           # source column of each covariate (deduplicated later)
modalities_col = []             # display label of each covariate

for col in cph.params_.index.to_list():
    if col.startswith('dummy:'):
        # maxsplit=2 keeps a modality that itself contains ':' in one piece;
        # an unbounded split would silently truncate it at its first ':'.
        _, source_column, modality = col.split(':', 2)
        categoricals.add(source_column)
        modalities[source_column].append(modality)
        unique_orig_cols.append(source_column)
        modalities_col.append(modality)
    else:
        numericals.add(col)
        unique_orig_cols.append(col)
        modalities_col.append(col)

def get_dropped_modality(variable):
    """Return the first value of ``df[variable]`` that has no covariate in the model.

    Reads the module-level ``df`` and ``modalities``. Values are compared by
    their string form because the modalities recorded in ``modalities`` are
    substrings of covariate names, while the dataset column may hold
    non-string values.

    Args:
        variable: Name of a column of ``df``.

    Returns:
        The first distinct value (in ``unique()`` order) whose string form is
        not a recorded modality, or ``None`` when every value is recorded.
    """
    # .get() avoids inserting an empty entry into the defaultdict for a
    # variable that has no recorded modalities.
    encoded_modalities = {str(name) for name in modalities.get(variable, ())}
    for modality in df[variable].unique():
        if str(modality) not in encoded_modalities:
            return modality
    return None

num_categorical = len(categoricals)
num_numerical = len(numericals)

# Both lists were built by walking cph.params_.index, so this positional
# assignment relies on df_covariates still being in that same row order
# (assumed to hold for cph.summary — TODO confirm).
df_covariates['original_column'] = unique_orig_cols
# Deduplicate while keeping first-seen order. A set() of strings iterates in
# an order that changes between interpreter runs, which would shuffle the
# subplots from one run to the next.
unique_orig_cols = list(dict.fromkeys(unique_orig_cols))
# Replace the raw covariate name with its display label (the modality for
# dummy covariates, the column name otherwise).
df_covariates['covariate'] = modalities_col

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Custom function to transform the risk multipliers
def transform_multiplier(risk_multiplier):
    """Return the base-10 log magnitude of a risk multiplier, signed by direction.

    Multipliers of 1 or more map to ``log10(multiplier)``. Multipliers below 1
    map to minus the log10 of their reciprocal, so they come out negative.
    """
    if risk_multiplier >= 1:
        return np.log10(risk_multiplier)

    reciprocal = 1 / risk_multiplier
    return -np.log10(reciprocal)

# Apply the transformation to the data
# NOTE(review): 'risk_multiplier' is the hazard ratio rounded to two decimals,
# so a ratio that rounds to 0.0 reaches the `1 / risk_multiplier` division in
# transform_multiplier — confirm such small ratios cannot occur.
df_covariates['transformed_multiplier'] = df_covariates['risk_multiplier'].apply(transform_multiplier)
# Reorder the rows by increasing absolute transformed value.
df_covariates = df_covariates.reindex(df_covariates['transformed_multiplier'].abs().sort_values(ascending=True).index)
# Swap exact zeros for a tiny positive value
# (presumably so those bars still have a visible extent — TODO confirm).
df_covariates['transformed_multiplier'] = df_covariates['transformed_multiplier'].replace(0, 0.0001)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Numerical covariates, when present, share one extra subplot
additional_subplot = int(num_numerical > 0)

# Two subplots per row: integer ceiling of (number of subplots / 2)
num_of_rows = (num_categorical + additional_subplot + 1) // 2

# To make sure that we keep a reasonable spacing between subplots
def calculate_vertical_spacing(num_of_rows):
    """Return the vertical gap to leave between subplot rows.

    The gap follows the straight line through (3 rows, 0.05) and
    (15 rows, 0.01), and the result is clamped to the interval [0, 1].
    """
    few_rows, gap_for_few = 3, 0.05
    many_rows, gap_for_many = 15, 0.01

    gradient = (gap_for_many - gap_for_few) / (many_rows - few_rows)
    offset = gap_for_few - gradient * few_rows

    unclamped_gap = gradient * num_of_rows + offset
    return max(0, min(1, unclamped_gap))

# Build the two-column subplot grid, one title per original column, with the
# row gap computed from the number of rows
fig = sp.make_subplots(
    rows=num_of_rows,
    cols=2,
    subplot_titles=unique_orig_cols,
    vertical_spacing=calculate_vertical_spacing(num_of_rows),
)

# Function to create the plot for each original column
def create_plot_per_orig_col(df_filtered, orig_col, row, col):
    """Draw one horizontal bar chart of risk multipliers into the shared ``fig``.

    Args:
        df_filtered: Rows of the covariates frame to plot. Reads the columns
            'statistical_validity', 'covariate', 'transformed_multiplier',
            'risk_multiplier' and, for the numerical panel, 'original_column'.
        orig_col: Name of the source column, or the sentinel
            'All Numerical Variables' for the pooled numerical panel.
        row: 1-based subplot row in the two-column grid.
        col: 1-based subplot column (1 or 2).

    Side effects:
        Adds a bar trace to the module-level ``fig``, rewrites the title
        annotation of the targeted subplot and configures both of its axes.
        The x-axis range is taken from the module-level ``df_covariates`` so
        that every panel shares the same scale.
    """
    color_mapping = {
        "Highly significant": "#3985ac",
        "Significant": "#88b5cd",
        "Moderately significant": "#bbd3dc",
        "Not significant": "#bbbbbb",
    }
    colors = df_filtered['statistical_validity'].map(color_mapping)

    # Y-axis labels: covariate names cut to 15 characters plus "..."
    truncated_covariates = [
        name[:15] + '...' if len(name) > 15 else name
        for name in df_filtered['covariate']
    ]
    bar_positions = list(range(len(truncated_covariates)))

    # Position of this subplot's title within fig.layout.annotations
    title_index = (row - 1) * 2 + col - 1

    if orig_col == 'All Numerical Variables':
        title_text = 'Relative Risk Multipliers for <b>Numerical Columns</b>'
        fig.layout.annotations[title_index]['text'] = title_text
        hovertemplate = '<b>Original Column:</b> %{customdata[0]}<br><b>Covariate:</b> %{customdata[1]}<br><b>Risk Multiplier:</b> %{text}<br><b>Statistical Validity:</b> %{customdata[2]}<extra></extra>'

        fig.add_trace(
            go.Bar(
                y=bar_positions,
                x=df_filtered['transformed_multiplier'],
                orientation='h',
                marker_color=colors,
                hovertemplate=hovertemplate,
                text=df_filtered['risk_multiplier'].values,
                customdata=list(zip(
                    df_filtered['original_column'],
                    df_filtered['covariate'],
                    df_filtered['statistical_validity'],
                )),
                showlegend=False,
            ),
            row=row,
            col=col,
        )
    else:
        ref_value = get_dropped_modality(orig_col)
        title_text = 'Relative Risk Multipliers for <b>{}</b> (Baseline value: {})'.format(orig_col, ref_value)
        fig.layout.annotations[title_index]['text'] = title_text
        # str(): the baseline may be a non-string dataset value or None, and
        # plain '+' concatenation would raise TypeError for those.
        hovertemplate = (
            ' <b>Covariate:</b> %{hovertext}  <br> <b>Risk Multiplier:</b> %{text}<br><b> Statistical Validity:</b> %{customdata}</br><b> Intepretation : </b> '
            + orig_col
            + ' = <b>%{hovertext}</b> multiplies the risk by <b>%{text}</b> (compared to '
            + str(ref_value)
            + ')<extra></extra>'
        )
        # Derive the hover labels from the rows actually being plotted rather
        # than from a loop variable that happens to exist at module level.
        hover_labels = df_filtered['covariate'].str.replace(f"{orig_col}_", "", regex=False)

        fig.add_trace(
            go.Bar(
                y=bar_positions,
                x=df_filtered['transformed_multiplier'],
                orientation='h',
                marker_color=colors,
                hovertemplate=hovertemplate,
                hovertext=hover_labels.values,
                text=df_filtered['risk_multiplier'].values,
                customdata=df_filtered['statistical_validity'].values,
                showlegend=False,
            ),
            row=row,
            col=col,
        )

    # Ticks sit at log10 positions but are labelled with the plain multipliers.
    fig.update_xaxes(
        range=[
            min(0, 1.1 * df_covariates['transformed_multiplier'].min()),
            max(0, 1.1 * df_covariates['transformed_multiplier'].max()),
        ],
        tickvals=np.log10([0.01, 0.1, 0.5, 1, 2, 10, 100]),
        ticktext=[0.01, 0.1, 0.5, 1, 2, 10, 100],
        row=row,
        col=col,
    )

    # Display at most 10 labels on y-axis
    max_labels = 10
    y_ticks = bar_positions
    if len(y_ticks) > max_labels:
        # Ceiling division: a floored step stays at 1 for 11-19 bars, so all
        # of them would be labelled.
        step_size = -(-len(y_ticks) // max_labels)
        y_ticks = y_ticks[::step_size]

    fig.update_yaxes(
        tickvals=y_ticks,
        ticktext=[truncated_covariates[i] for i in y_ticks],
        row=row,
        col=col,
    )

# Create the plots for each unique original column
numerical_data = []  # frames of the columns that go into the pooled panel
row, col = 1, 1      # next free subplot position (1-based)
for orig_col in unique_orig_cols:
    df_filtered = df_covariates[df_covariates['original_column'] == orig_col]
    # regex=False: the column name is a literal prefix. Left to the default,
    # pandas versions that treat the pattern as a regular expression would
    # misread names containing characters such as '(' or '.'.
    shortened_covariates = df_filtered['covariate'].str.replace(f"{orig_col}_", "", regex=False)
    truncated_covariates = [cov if len(cov) <= 15 else cov[:15]+'...' for cov in shortened_covariates]

    # Columns with at most one covariate do not get a panel of their own;
    # they are collected and drawn together after the loop.
    if len(truncated_covariates) <= 1:
        numerical_data.append(df_filtered)
        continue

    create_plot_per_orig_col(df_filtered, orig_col, row, col)
    # Advance left-to-right, then wrap to the start of the next row.
    if col % 2 == 0:
        row += 1
        col = 1
    else:
        col += 1

if numerical_data:
    # Pooled panel, drawn in the next free slot and ordered by transformed value.
    df_numerical = pd.concat(numerical_data)
    df_numerical = df_numerical.sort_values(by='transformed_multiplier', ascending=True)
    create_plot_per_orig_col(df_numerical, 'All Numerical Variables', row, col)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Size the figure from the row count, centre and style the title, and set the
# plot and paper background colours in a single layout update
fig.update_layout(
    height=400 * num_of_rows,
    showlegend=True,
    title_text="Relative Risk Multipliers",
    title_x=0.5,
    title_font=dict(size=24, family="Arial, bold"),
    plot_bgcolor="#ecf0f1",
    paper_bgcolor="white",
)

# Render the figure as an HTML <div> and save it through dataiku.insights
graph_html = "<meta charset='UTF-8'>" + plotly.offline.plot(fig, output_type="div")
insights.save_data(
    id="Relative_Risks_Dashboard",
    payload=graph_html,
    content_type="text/html",
    label="Relative_Risks_Dashboard",
)

# Flag the covariates whose p-value is below 0.3
df_covariates["Considered"] = df_covariates["p"] < 0.3

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Model-level values attached to the covariates frame as extra columns,
# in this order
_model_level_values = {
    'docstring': cph.__doc__,  # docstring of the fitted model object
    'summary': cph.summary.to_string(),  # whole summary table rendered as one string
    # First value of the last row of the baseline survival table
    'baseline_survival': cph.baseline_survival_.iloc[-1].tolist()[0],
    # First value of the last row of the baseline cumulative hazard table
    'baseline_cumulative_hazard': cph.baseline_cumulative_hazard_.iloc[-1].tolist()[0],
    'concordance_index': cph.concordance_index_,
    'AIC': cph.AIC_partial_,
    'penalizer': cph.penalizer,
}
for _column_name, _column_value in _model_level_values.items():
    df_covariates[_column_name] = _column_value

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write the covariate-level regression results, keeping only the rows whose
# p-value is below 0.3
covariates_information = dataiku.Dataset("covariates_information")
covariates_information.write_with_schema(
    df_covariates.loc[df_covariates['p'] < 0.3]
)