import dataiku
import pandas as pd
import numpy as np
from flask import request
import json

from rfm_segmentation.dku_utils import get_current_project_and_variables
from rfm_segmentation.webapp import (aggregate_rfm_dataframes_by_segments, aggregate_rfm_dataframes_by_monetary_value_and_segment,
                                     root_treemap_data, convert_list_values_to_int, bold_string_for_html,
                                     match_segment_labels_and_values, TREEMAP_ROOT_COLOR, MONETARY_VALUE_COLORS)


# Segment label -> display color lookup, reduced to ONE row per segment label.
# Deduplicating on the label alone (rather than on the label/color pair) keeps
# every left merge on "segment_label" many-to-one: a label listed with two
# different colors would otherwise duplicate rows of the left-hand frame and
# silently inflate any totals computed from it.
segments_identification_df = dataiku.Dataset("rf_segments_identication_synced").get_dataframe()
segments_identification_df = (
    segments_identification_df[["segment_label", "segment_color"]]
    .drop_duplicates(subset="segment_label")
)

# Reference-period customer segments, enriched with each segment's display color.
customer_rfm_segments_df = dataiku.Dataset("customer_rfm_segments").get_dataframe()
customer_rfm_segments_df = customer_rfm_segments_df.merge(segments_identification_df, how="left", on="segment_label")

# Application settings are read from the "standard" scope of the project variables.
project, variables = get_current_project_and_variables()
app_variables = variables["standard"]
propagate_rfm = app_variables["propagate_rfm_app"]

# Datasource identifier (the "rfm_datasource" value sent by the front end)
# -> customer segments dataframe. The reference period is always available.
rfm_datasource_data = {"reference_period_cutomers": customer_rfm_segments_df}

# When the app is configured to propagate RFM segments, two extra datasources
# are exposed; both are enriched with the segment display colors.
if propagate_rfm:
    most_recent_dates_cutomers_df = dataiku.Dataset("last_customer_rfm_segments_recent_dates").get_dataframe()
    inactive_customers_df = dataiku.Dataset("inactive_customers").get_dataframe()

    most_recent_dates_cutomers_df = most_recent_dates_cutomers_df.merge(
        segments_identification_df, how="left", on="segment_label"
    )
    inactive_customers_df = inactive_customers_df.merge(
        segments_identification_df, how="left", on="segment_label"
    )

    rfm_datasource_data["most_recent_dates_cutomers"] = most_recent_dates_cutomers_df
    rfm_datasource_data["inactive_customers"] = inactive_customers_df

# Aggregates are computed once, when the module loads, for every datasource;
# the request handler then simply looks them up by datasource identifier.
rfm_datasource_by_segment_data = {}
rfm_datasource_by_monetary_value_and_segment_data = {}
rfm_datasource_total_customers = {}
for datasource_name, datasource_df in rfm_datasource_data.items():
    rfm_datasource_by_segment_data[datasource_name] = aggregate_rfm_dataframes_by_segments(datasource_df)
    rfm_datasource_by_monetary_value_and_segment_data[datasource_name] = (
        aggregate_rfm_dataframes_by_monetary_value_and_segment(datasource_df)
    )
    # Plain int so the value stays JSON-serializable.
    rfm_datasource_total_customers[datasource_name] = int(np.sum(datasource_df["count"]))

def _treemap_payload(labels, parents, values, colors):
    """Pack parallel treemap arrays into the dict shape returned to the front end."""
    return {"labels": labels, "parents": parents, "values": values, "colors": colors}


def _single_level_treemap(title, segments_df):
    """Build a treemap made of a root node and the segments of `segments_df`.

    Args:
        title: Plain-text title; it is bolded and suffixed with the customer
            total to form the root label.
        segments_df: Aggregated frame with "segment_label", "count" and
            "segment_color" columns.

    Returns:
        A treemap payload dict (labels / parents / values / colors) built by
        `root_treemap_data`.
    """
    segment_labels = [bold_string_for_html(segment_label) for segment_label in segments_df["segment_label"]]
    segment_values = list(segments_df["count"])
    segment_labels = match_segment_labels_and_values(segment_labels, segment_values)
    segment_colors = list(segments_df["segment_color"])
    # Cast so the root label always renders as an integer count, whatever the
    # dtype of the "count" column.
    n_customers = int(np.sum(segment_values))
    treemap_root_label = "{} : {} customers".format(bold_string_for_html(title), n_customers)
    treemap_labels, treemap_parents, treemap_values, treemap_colors = root_treemap_data(
        treemap_root_label, TREEMAP_ROOT_COLOR, segment_labels, segment_values, segment_colors)
    return _treemap_payload(treemap_labels, treemap_parents, treemap_values, treemap_colors)


def _monetary_value_relative_importance_treemap(by_monetary_value_and_segment_df, available_monetary_values,
                                                total_customers):
    """Build ONE two-level treemap: root -> monetary value branch -> segment.

    Args:
        by_monetary_value_and_segment_df: Aggregated frame with "monetary_value",
            "segment_label", "count" and "segment_color" columns.
        available_monetary_values: Monetary values to include, in display order;
            each is used as a key of MONETARY_VALUE_COLORS.
        total_customers: Customer total shown in, and used as the value of, the root.

    Returns:
        A single treemap payload dict covering every requested monetary value.
    """
    treemap_root_label = "{} : {} customers".format(
        bold_string_for_html("Segments by monetary value relative importance"), total_customers)
    treemap_labels = [treemap_root_label]
    treemap_parents = [""]
    treemap_values = [total_customers]
    treemap_colors = [TREEMAP_ROOT_COLOR]

    for monetary_value in available_monetary_values:
        filtered_df = by_monetary_value_and_segment_df[
            by_monetary_value_and_segment_df["monetary_value"] == monetary_value]
        # The "mv <value> / " prefix presumably keeps leaf labels distinct across
        # branches (the same segment appears under every monetary value).
        segment_labels = ["mv {} / {}".format(monetary_value, bold_string_for_html(segment_label))
                          for segment_label in filtered_df["segment_label"]]
        segment_values = list(filtered_df["count"])
        segment_labels = match_segment_labels_and_values(segment_labels, segment_values)
        segment_colors = list(filtered_df["segment_color"])
        n_customers = int(np.sum(segment_values))
        branch_label = "{} : {}".format(bold_string_for_html("Monetary value {}".format(monetary_value)),
                                        n_customers)

        # Branch node (child of the root) followed by its segment leaves; the
        # four lists must stay index-aligned.
        treemap_labels += [branch_label] + segment_labels
        treemap_parents += [treemap_root_label] + [branch_label] * len(segment_labels)
        treemap_values += [n_customers] + segment_values
        treemap_colors += [MONETARY_VALUE_COLORS[monetary_value]] + segment_colors

    # Emit the treemap only once every branch has been added, converting the
    # values a single time: emitting it inside the loop produced one partial
    # copy per monetary value whose values were out of sync with the shared,
    # in-place-grown labels/parents lists.
    return _treemap_payload(treemap_labels, treemap_parents, convert_list_values_to_int(treemap_values),
                            treemap_colors)


def _split_by_monetary_value_treemaps(by_monetary_value_and_segment_df, available_monetary_values):
    """Build one single-level segment treemap per requested monetary value, in order."""
    treemaps = []
    for monetary_value in available_monetary_values:
        filtered_df = by_monetary_value_and_segment_df[
            by_monetary_value_and_segment_df["monetary_value"] == monetary_value]
        treemaps.append(_single_level_treemap("Segments on monetary value {}".format(monetary_value), filtered_df))
    return treemaps


@app.route("/send_app_features/", methods=["POST"])
def send_app_features():
    """Return the treemap data matching the front-end exploration settings.

    The JSON body must contain "exploration_strategy" ("whole_data",
    "monetary_value_relative_importance" or "split_by_monetary_value"),
    "available_monetary_values" and "rfm_datasource" (a key of the
    pre-computed per-datasource dictionaries; an unknown key raises KeyError).

    Returns:
        A JSON string ``{"propagate_rfm": <bool>, "data": [treemap, ...]}``.
        An unrecognised strategy yields an empty "data" list.
    """
    front_parameters = request.get_json(force=True)
    print("Front parameters : {}".format(front_parameters))
    exploration_strategy = front_parameters["exploration_strategy"]
    available_monetary_values = front_parameters["available_monetary_values"]
    rfm_datasource = front_parameters["rfm_datasource"]
    by_segment_df = rfm_datasource_by_segment_data[rfm_datasource]
    by_monetary_value_and_segment_df = rfm_datasource_by_monetary_value_and_segment_data[rfm_datasource]
    total_customers = rfm_datasource_total_customers[rfm_datasource]

    if exploration_strategy == "whole_data":
        treemaps_data = [_single_level_treemap("Segments on whole data", by_segment_df)]
    elif exploration_strategy == "monetary_value_relative_importance":
        treemaps_data = [_monetary_value_relative_importance_treemap(
            by_monetary_value_and_segment_df, available_monetary_values, total_customers)]
    elif exploration_strategy == "split_by_monetary_value":
        treemaps_data = _split_by_monetary_value_treemaps(by_monetary_value_and_segment_df,
                                                          available_monetary_values)
    else:
        treemaps_data = []

    print("Front parameters loaded !")
    response = {
        "propagate_rfm": propagate_rfm,
        "data": treemaps_data}
    return json.dumps(response)