# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import pandas as pd, numpy as np
from dataiku import pandasutils as pdu

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
from ast import literal_eval

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
from rfm_segmentation.dku_utils import (get_current_project_and_variables,
                                        read_pickle_from_dss_folder,
                                        get_managed_folder_id_with_folder_name,
                                        update_one_schema_column)
from rfm_segmentation.dates_handling import from_datetime_to_dss_string_date
from rfm_segmentation.rfm_packages import (compute_segmentation_quantiles_boundaries, score_rfm_with_quantiles,
                                           train_kmeans_models_for_rfm, score_rfm_with_k_means,
                                           enrich_dataframe_with_rfm_global_scores_and_segments)
from rfm_segmentation.config.flow.constants import (COLUMN_FOR_RECENCY_COMPUTATION,
                                                    COLUMNS_FOR_ARRAYS_AVERAGE_COMPUTATION,
                                                    COLUMNS_FOR_ARRAYS_SUM_COMPUTATION)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
project, variables = get_current_project_and_variables()
# Scoring technique chosen in the project's "standard" variables; the scoring cell below
# handles "quantiles" and "kmeans_clustering".
standard_project_variables = variables["standard"]
scores_computation_technique = standard_project_variables["scores_computation_technique_app"]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Read recipe inputs: the customer RFM inputs and the (recency, frequency) segments table.
propagation_customer_rfm_inputs = dataiku.Dataset("propagation_customer_rfm_inputs")
propagation_customer_rfm_inputs_df = propagation_customer_rfm_inputs.get_dataframe()

# NOTE: the dataset name is spelled "identication"; it must match the flow's dataset name exactly.
segments_identification = dataiku.Dataset("rf_segments_identication_synced")
segments_identification_df = segments_identification.get_dataframe()

# The RFM scoring artefacts are pickles stored in the "rfm_scoring_information" managed folder.
rfm_scoring_information_folder_id = get_managed_folder_id_with_folder_name(project, "rfm_scoring_information")


def _load_rfm_scoring_pickle(pickle_file_name):
    """Return the object unpickled from `pickle_file_name` in the RFM scoring information folder."""
    return read_pickle_from_dss_folder(pickle_file_name, rfm_scoring_information_folder_id)


n_segments_per_axis = _load_rfm_scoring_pickle("n_segments_per_axis.p")
rfm_original_columns = _load_rfm_scoring_pickle("rfm_original_columns.p")
reverse_scores_in_rfm_columns = _load_rfm_scoring_pickle("reverse_scores_in_rfm_columns.p")
original_columns_to_rfm_labels_mapping = _load_rfm_scoring_pickle("original_columns_to_rfm_labels_mapping.p")
rfm_columns = _load_rfm_scoring_pickle("rfm_columns.p")

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Replace each "*_concat" column with two aggregate columns (sum and mean of its values).
# Each cell is parsed with literal_eval -- presumably a stringified list of numbers; confirm
# against the upstream recipe. The raw column is dropped once its aggregates are stored.
for array_column in ["basket_total_amount_concat", "purchased_items_concat"]:
    sum_column = COLUMNS_FOR_ARRAYS_SUM_COMPUTATION[array_column]
    average_column = COLUMNS_FOR_ARRAYS_AVERAGE_COMPUTATION[array_column]

    parsed_arrays = [literal_eval(raw_cell) for raw_cell in propagation_customer_rfm_inputs_df[array_column]]

    propagation_customer_rfm_inputs_df[sum_column] = [np.sum(cell_values) for cell_values in parsed_arrays]
    propagation_customer_rfm_inputs_df[average_column] = [np.mean(cell_values) for cell_values in parsed_arrays]
    propagation_customer_rfm_inputs_df.drop(array_column, axis=1, inplace=True)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Collect each original RFM axis column as a plain Python list, keyed by its column name;
# this is the input of the scoring functions in the next cell.
axes_data = {
    axis_label: list(propagation_customer_rfm_inputs_df[axis_label])
    for axis_label in rfm_original_columns
}

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Score every RFM axis with the technique selected in the project variables, reusing the
# artefacts pickled in the RFM scoring information folder (axes_quantiles.p or
# axes_kmeans_clustering_models.p).
if scores_computation_technique == "quantiles":
    axes_quantiles = read_pickle_from_dss_folder("axes_quantiles.p", rfm_scoring_information_folder_id)
    axes_rfm_scores = score_rfm_with_quantiles(rfm_original_columns, axes_data, axes_quantiles, n_segments_per_axis, reverse_scores_in_rfm_columns)

elif scores_computation_technique == "kmeans_clustering":
    axes_kmeans_clustering_models = read_pickle_from_dss_folder("axes_kmeans_clustering_models.p", rfm_scoring_information_folder_id)
    axes_rfm_scores = score_rfm_with_k_means(rfm_original_columns, axes_data, axes_kmeans_clustering_models, reverse_scores_in_rfm_columns)

else:
    # Fail fast: without this branch `axes_rfm_scores` would be unbound and the script would
    # crash later with an unrelated-looking NameError.
    raise ValueError(
        "Unsupported scores computation technique {!r}: expected 'quantiles' or 'kmeans_clustering'.".format(
            scores_computation_technique))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Store each axis' scores in the dataframe under the RFM label that the axis column maps to.
# The global-score and segment cells below read these label columns (e.g. "recency", "frequency").
for axis_label in rfm_original_columns:
    rfm_label = original_columns_to_rfm_labels_mapping[axis_label]
    propagation_customer_rfm_inputs_df[rfm_label] = axes_rfm_scores[axis_label]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
print("Computing the RFM global scores and segments ...")
# Enrich twice: once over the three classic RFM axes ("rfm"), once with "density" added ("rfmd").
global_scores_specifications = [
    ("rfm", ["recency", "frequency", "monetary_value"]),
    ("rfmd", ["recency", "frequency", "monetary_value", "density"]),
]
for global_score_label, global_score_axes in global_scores_specifications:
    propagation_customer_rfm_inputs_df = enrich_dataframe_with_rfm_global_scores_and_segments(
        propagation_customer_rfm_inputs_df, n_segments_per_axis, global_score_axes, global_score_label)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Formatting the date columns in the DSS non-ambiguous date format
# (DATE_COLUMNS is reused later to cast these columns to "date" in the output schema):
DATE_COLUMNS = ["rfm_reference_month_start", "rfm_reference_date", "first_transaction_date", "last_transaction_date"]
for date_column in DATE_COLUMNS:
    propagation_customer_rfm_inputs_df[date_column] = [
        from_datetime_to_dss_string_date(date_value)
        for date_value in propagation_customer_rfm_inputs_df[date_column]
    ]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Attach the segment identification to each customer through its (recency, frequency) score pair.
# The left merge keeps every customer, even one whose score pair has no matching segment row.
# validate="many_to_one" raises a pandas MergeError if the segments table repeats a
# (recency, frequency) pair, which would otherwise silently duplicate customer rows.
propagation_customer_rfm_inputs_df = propagation_customer_rfm_inputs_df.merge(
    segments_identification_df, how="left", on=["recency", "frequency"], validate="many_to_one")
# "segment_color" comes from the segments table and is not kept in this output.
propagation_customer_rfm_inputs_df.drop("segment_color", axis=1, inplace=True)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write recipe outputs. The schema is not inferred from the dataframe (infer_schema=False);
# the column types are adjusted explicitly in the next cell.
OUTPUT_DATASET_NAME = "propagation_customer_rfm_segments"
propagation_customer_rfm_segments = dataiku.Dataset(OUTPUT_DATASET_NAME)
propagation_customer_rfm_segments.write_dataframe(propagation_customer_rfm_inputs_df, infer_schema=False, dropAndCreate=True)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Cast "customer_id" to "string", then every date column to "date", in the output dataset schema:
schema_column_casts = [("customer_id", "string")] + [(date_column_name, "date") for date_column_name in DATE_COLUMNS]
for cast_column_name, cast_column_type in schema_column_casts:
    update_one_schema_column(project, OUTPUT_DATASET_NAME, cast_column_name, cast_column_type)