# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import dataiku.insights
import pandas as pd, numpy as np

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
from sklearn.cluster import KMeans

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
from rfm_segmentation.dku_utils import (get_current_project_and_variables, get_managed_folder_info,
                                        get_managed_folder_id_with_folder_name, write_pickle_in_dss_folder,
                                        update_one_schema_column)

from rfm_segmentation.dates_handling import from_datetime_to_dss_string_date

from rfm_segmentation.rfm_packages import (load_rfm_parameters,compute_segmentation_quantiles_boundaries, score_rfm_with_quantiles,
                                           train_kmeans_models_for_rfm, score_rfm_with_k_means,
                                           enrich_dataframe_with_rfm_global_scores_and_segments,
                                           remove_dataframe_outliers_based_on_quantiles, generate_rfm_box_plots)

from rfm_segmentation.config.flow.constants import (COLUMN_FOR_RECENCY_COMPUTATION,
                                                    LOWER_OUTLIERS_QUANTILE_TRESHOLD,
                                                    HIGHER_OUTLIERS_QUANTILE_TRESHOLD)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Recipe inputs, part 1: the customer-level table that gets scored below.
customer_rfm_inputs = dataiku.Dataset("customer_rfm_inputs")
customer_rfm_inputs_df = customer_rfm_inputs.get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Recipe inputs, part 2: segment definitions, left-joined further down on the
# ("recency", "frequency") columns.
# NOTE(review): "identication" looks like a typo, but this literal must match
# the dataset name declared in the Flow -- rename both together, never just here.
segments_identification = dataiku.Dataset("rf_segments_identication_synced")
segments_identification_df = segments_identification.get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Handle on the current DSS project and its variables; only the "standard"
# variable set is used by this recipe.
project, variables = get_current_project_and_variables()
global_variables = variables["standard"]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Application settings driving the scoring (keys suffixed with "_app").
(scores_computation_technique,
 n_segments_per_axis,
 recency_policy,
 monetary_value_policy) = (global_variables["scores_computation_technique_app"],
                           global_variables["n_segments_per_axis_app"],
                           global_variables["recency_policy_app"],
                           global_variables["monetary_value_policy_app"])

# Resolve which input columns are scored, which ones have reversed scores,
# and how each input column maps to its RFM label.
(rfm_original_columns,
 reverse_scores_in_rfm_columns,
 original_columns_to_rfm_labels_mapping,
 rfm_columns) = load_rfm_parameters(COLUMN_FOR_RECENCY_COMPUTATION, recency_policy, monetary_value_policy)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Outlier filtering on the RFM source columns, bounded by the configured
# lower/higher quantile thresholds (presumably drops rows outside those
# bounds -- see remove_dataframe_outliers_based_on_quantiles).
customer_rfm_inputs_df = remove_dataframe_outliers_based_on_quantiles(
    customer_rfm_inputs_df,
    LOWER_OUTLIERS_QUANTILE_TRESHOLD,
    HIGHER_OUTLIERS_QUANTILE_TRESHOLD,
    rfm_original_columns,
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Plain Python lists of each axis' values, keyed by source column name.
axes_data = {axis_label: list(customer_rfm_inputs_df[axis_label]) for axis_label in rfm_original_columns}

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Score every RFM axis with the configured technique. Each branch defines
# `axes_rfm_scores` (axis label -> scores, assigned as dataframe columns in the
# next cell), plus the fitted artefacts (`axes_quantiles` or
# `axes_kmeans_clustering_models`) that are pickled further down.
if scores_computation_technique == "quantiles":
    print("Computing RFM axes scores with quantiles ...")
    segmentation_quantiles_boundaries = compute_segmentation_quantiles_boundaries(n_segments_per_axis)
    # Quantile cut points of each axis' values.
    axes_quantiles = {
        axis_label: np.quantile(axes_data[axis_label], q=segmentation_quantiles_boundaries)
        for axis_label in rfm_original_columns
    }
    axes_rfm_scores = score_rfm_with_quantiles(rfm_original_columns, axes_data, axes_quantiles,
                                               n_segments_per_axis, reverse_scores_in_rfm_columns)

elif scores_computation_technique == "kmeans_clustering":
    print("Computing RFM axes scores with KMeans clustering ...")
    axes_kmeans_clustering_models = train_kmeans_models_for_rfm(rfm_original_columns, axes_data, n_segments_per_axis)
    axes_rfm_scores = score_rfm_with_k_means(rfm_original_columns, axes_data, axes_kmeans_clustering_models,
                                             reverse_scores_in_rfm_columns)

else:
    # Fail fast: otherwise `axes_rfm_scores` stays undefined and the recipe
    # dies later with an unhelpful NameError.
    raise ValueError(
        "Unknown scores_computation_technique_app {!r}: expected 'quantiles' or 'kmeans_clustering'".format(
            scores_computation_technique))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Store each axis' scores as a new dataframe column named after the axis'
# RFM label (mapping returned by load_rfm_parameters).
for axis_label in rfm_original_columns:
    rfm_label = original_columns_to_rfm_labels_mapping[axis_label]
    customer_rfm_inputs_df[rfm_label] = axes_rfm_scores[axis_label]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Combined scores and segments are derived twice: from the three classic RFM
# axes (tag "rfm") and from those axes plus "density" (tag "rfmd").
print("Computing the RFM global scores and segments ...")
for global_score_axes, global_score_tag in (
        (["recency", "frequency", "monetary_value"], "rfm"),
        (["recency", "frequency", "monetary_value", "density"], "rfmd"),
):
    customer_rfm_inputs_df = enrich_dataframe_with_rfm_global_scores_and_segments(
        customer_rfm_inputs_df, n_segments_per_axis, global_score_axes, global_score_tag)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Pickle the scoring configuration and the fitted artefacts into the
# "rfm_scoring_information" managed folder (presumably consumed by a
# downstream scoring step -- confirm).
rfm_scoring_information_folder_id = get_managed_folder_id_with_folder_name(project, "rfm_scoring_information")

# (pickle name, object) pairs, in the order they are written.
scoring_pickles = [
    ("n_segments_per_axis", n_segments_per_axis),
    ("rfm_original_columns", rfm_original_columns),
    ("reverse_scores_in_rfm_columns", reverse_scores_in_rfm_columns),
    ("original_columns_to_rfm_labels_mapping", original_columns_to_rfm_labels_mapping),
    ("rfm_columns", rfm_columns),
]
# Only the artefact produced by the active technique exists, so only it is saved.
if scores_computation_technique == "quantiles":
    scoring_pickles.append(("axes_quantiles", axes_quantiles))
elif scores_computation_technique == "kmeans_clustering":
    scoring_pickles.append(("axes_kmeans_clustering_models", axes_kmeans_clustering_models))

for pickle_name, pickle_payload in scoring_pickles:
    write_pickle_in_dss_folder(pickle_payload, pickle_name, rfm_scoring_information_folder_id)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Convert the date columns to DSS's non-ambiguous date string format.
# The same column list is reused later to cast their schema type to "date".
DATE_COLUMNS = ["rfm_reference_month_start", "rfm_reference_date", "first_transaction_date", "last_transaction_date"]
for date_column in DATE_COLUMNS:
    customer_rfm_inputs_df[date_column] = [from_datetime_to_dss_string_date(raw_value)
                                           for raw_value in customer_rfm_inputs_df[date_column]]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Left join keeps every customer row, attaching the segment definition that
# matches its (recency, frequency) scores; the "segment_color" column from
# the segment table is not carried into the output.
customer_rfm_inputs_df = (
    customer_rfm_inputs_df
    .merge(segments_identification_df, how="left", on=["recency", "frequency"])
    .drop("segment_color", axis=1)
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write the recipe output. The schema is not inferred from the dataframe
# (infer_schema=False); selected column types are set explicitly below.
customer_rfm_segments = dataiku.Dataset("customer_rfm_segments")
customer_rfm_segments.write_dataframe(customer_rfm_inputs_df, infer_schema=False, dropAndCreate=True)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Schema type overrides, applied in order: customer_id as "string", then
# every formatted date column as "date".
schema_type_overrides = [("customer_id", "string")] + [(date_column, "date") for date_column in DATE_COLUMNS]
for overridden_column, overridden_type in schema_type_overrides:
    update_one_schema_column(project, "customer_rfm_segments", overridden_column, overridden_type)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Plot the RFM box plots, then save each figure as a static Plotly insight of
# the current project. The insight id is the dict key returned by
# generate_rfm_box_plots.
rfm_box_plots = generate_rfm_box_plots(rfm_original_columns, original_columns_to_rfm_labels_mapping,
                                       customer_rfm_inputs_df, n_segments_per_axis)

project_key = dataiku.get_custom_variables()["projectKey"]
for static_insight_id, plotly_figure in rfm_box_plots.items():
    dataiku.insights.save_plotly(id=static_insight_id, figure=plotly_figure, project_key=project_key)