# Core code imports
import dataiku
import pandas as pd
import numpy as np
import re

# Import dependencies
from bs_commons.dku_utils.pandas_utils.pandas_aggregations import pivot_pandas_column, compute_rownumber

# Constants
# Maximum rank (by descending transition frequency) of next segments kept
# for each (reference month, previous segment) group.
N_TOP_FREQUENT_TRANSITIONS_TO_RETRIEVE = 3

# Read recipe inputs: load the prepared monthly transitions dataset in full.
monthly_segment_transitions_and_information_prepared = dataiku.Dataset(
    "monthly_segment_transitions_and_information_prepared"
)
monthly_segment_transitions_and_information_prepared_df = (
    monthly_segment_transitions_and_information_prepared.get_dataframe()
)

# Select columns: keep only the fields used by the ranking and pivot steps.
# The selection is reassigned so it actually takes effect. `.copy()` yields an
# independent frame, so column assignments made on it later do not trigger
# pandas' SettingWithCopyWarning.
monthly_segment_transitions_and_information_prepared_df = monthly_segment_transitions_and_information_prepared_df[
    [
        "previous_rfm_reference_month_start",
        "previous_segment_label",
        "next_segment_label",
        "transition_frequency",
    ]
].copy()

# Compute row number: number the rows in "transition_frequency_rank", ordered by
# descending transition_frequency. Presumably numbering restarts within each
# (previous_rfm_reference_month_start, previous_segment_label) group; confirm
# against the bs_commons `compute_rownumber` helper.
top_next_transitions_df = compute_rownumber(monthly_segment_transitions_and_information_prepared_df,
                                            ["previous_rfm_reference_month_start", "previous_segment_label"],
                                            ["transition_frequency"],
                                            "transition_frequency_rank",
                                            order_columns_to_sort_descending=["transition_frequency"]
                                           )

# Keep only the N most frequent transitions. `.copy()` detaches the filtered
# frame from its parent, so the column assignment below is a plain write rather
# than a SettingWithCopyWarning-prone write on a slice.
top_next_transitions_df = top_next_transitions_df.loc[
    top_next_transitions_df["transition_frequency_rank"] <= N_TOP_FREQUENT_TRANSITIONS_TO_RETRIEVE
].copy()

# Cast the rank to string. Presumably this is done so the pivot yields column
# names such as "next_segment_label_1".
top_next_transitions_df["transition_frequency_rank"] = top_next_transitions_df["transition_frequency_rank"].astype(str)

# Columns identifying one (reference month, previous segment) group. A fresh
# list is built from this tuple for every helper call.
TRANSITION_GROUP_KEY = ("previous_rfm_reference_month_start", "previous_segment_label")

# Pivot column: spread the retained ranks into one next_segment_label column
# per rank (missing ranks are left empty).
top_next_transitions_df = pivot_pandas_column(
    top_next_transitions_df,
    column_to_pivot="transition_frequency_rank",
    rows_for_aggregation_key=list(TRANSITION_GROUP_KEY),
    aggregated_columns=["next_segment_label"],
    aggregation_function="max",
    missing_values_filling=None,
)

# Pivot column: one transition_frequency column per next segment label, with
# 0.0 where a group never transitions to that label.
monthly_most_likely_next_segments_df = pivot_pandas_column(
    monthly_segment_transitions_and_information_prepared_df,
    column_to_pivot="next_segment_label",
    rows_for_aggregation_key=list(TRANSITION_GROUP_KEY),
    aggregated_columns=["transition_frequency"],
    aggregation_function="max",
    missing_values_filling=0.0,
)

# Merge dataframes: attach the top-N next segments to every frequency row,
# keeping all rows of the frequency pivot (left join).
monthly_most_likely_next_segments_df = monthly_most_likely_next_segments_df.merge(
    top_next_transitions_df,
    how="left",
    on=list(TRANSITION_GROUP_KEY),
)

# Rename columns: turn each "transition_frequency_<label>" column (presumably
# produced by the frequency pivot) into "<label>_transition_freq".
# The filter is a prefix test, matching the prefix-only stripping. A column
# that merely contains the token elsewhere is therefore left untouched instead
# of receiving a stray suffix.
transition_frequency_prefix = "transition_frequency_"
column_renamings = {
    column_name: f"{column_name[len(transition_frequency_prefix):]}_transition_freq"
    for column_name in monthly_most_likely_next_segments_df.columns
    if column_name.startswith(transition_frequency_prefix)
}
monthly_most_likely_next_segments_df = monthly_most_likely_next_segments_df.rename(columns=column_renamings)

# Fill missing values: rank slots with no next segment get the
# "NO_TRANSITIONS" placeholder. If the pivot produced no column at all for a
# rank (no group reached it), that column is created filled with the
# placeholder. This keeps all N next_segment_label_<rank> columns in the output
# instead of raising KeyError.
for rank in range(1, N_TOP_FREQUENT_TRANSITIONS_TO_RETRIEVE + 1):
    next_segment_label_column_name = f"next_segment_label_{rank}"
    if next_segment_label_column_name in monthly_most_likely_next_segments_df.columns:
        # Assign the result back. `df[col].fillna(..., inplace=True)` is chained
        # assignment and does not update the frame under pandas Copy-on-Write.
        monthly_most_likely_next_segments_df[next_segment_label_column_name] = (
            monthly_most_likely_next_segments_df[next_segment_label_column_name].fillna("NO_TRANSITIONS")
        )
    else:
        monthly_most_likely_next_segments_df[next_segment_label_column_name] = "NO_TRANSITIONS"

# Write recipe outputs: persist the final dataframe with write_with_schema,
# so the output dataset's schema is set from the dataframe's columns.
monthly_most_likely_next_segments = dataiku.Dataset(
    "monthly_most_likely_next_segments"
)
monthly_most_likely_next_segments.write_with_schema(
    monthly_most_likely_next_segments_df
)
