import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from plotly.graph_objs import Layout

from clv_forecast.utils.generic import reformat_group

# Default Plotly styling used when building the RFM box-plot figures.

# Axis titles: font size and colour (black).
PLOTLY_AXES_FONT_SIZE: int = 10
PLOTLY_AXES_LABEL_COLOR: str = "rgb(0, 0, 0)"

# Colour of the y-axis grid lines.
PLOTLY_Y_AXIS_GRID_COLOR: str = "grey"

# Inner plotting area and outer (paper) background, both white.
PLOTLY_PLOT_BACKGROUND_COLOR: str = "rgb(255, 255, 255)"
PLOTLY_PLOT_BORDERS_COLOR: str = "rgb(255, 255, 255)"

# Figure dimensions passed to the Plotly layout.
PLOTLY_FIGURE_HEIGHT: int = 400
PLOTLY_FIGURE_WIDTH: int = 800


def log_transform_for_min_max_scaling(list_or_array):
    """Shift values so the smallest becomes 1, then take the natural log.

    Each value ``v`` becomes ``log(1 + v - min(values))``. The minimum
    therefore maps to 0 and, for real non-NaN input, every result is
    non-negative. Large values are compressed, so a subsequent min-max
    scaling is less dominated by outliers.

    Returns a new list; the input is left untouched.
    """
    shift = min(list_or_array)
    transformed = []
    for item in list_or_array:
        transformed.append(np.log(1 + item - shift))
    return transformed


def min_max_scale_values(list_or_array, smooth_scaling_with_log):
    """Linearly rescale values into the [0, 1] range.

    Args:
        list_or_array: Iterable of real numbers (e.g. a list or 1-D array).
        smooth_scaling_with_log: If true, the values are first passed through
            ``log_transform_for_min_max_scaling`` so large values are
            compressed before scaling.

    Returns:
        A new list in which the minimum maps to 0.0 and the maximum to 1.0.
        An empty input yields ``[]``. If all values are equal there is no
        spread to scale by, so every output is 0.0 instead of dividing by zero.
    """
    values = list(list_or_array)
    if not values:
        # min()/max() would raise on an empty sequence.
        return []
    if smooth_scaling_with_log:
        values = log_transform_for_min_max_scaling(values)

    lowest = min(values)
    spread = max(values) - lowest
    if spread == 0:
        # Constant input (including a single element): avoid ZeroDivisionError.
        return [0.0] * len(values)
    return [(value - lowest) / spread for value in values]


def from_scaled_values_to_hex_colors(scaled_values, cmap_id, bool_keep_alpha_in_rgba):
    """Map already-scaled values to hex colour strings via a matplotlib colormap.

    Args:
        scaled_values: Sequence of floats, expected to be scaled to [0, 1]
            (e.g. by ``min_max_scale_values``).
        cmap_id: Name of a matplotlib colormap, see
            https://matplotlib.org/stable/gallery/color/colormap_reference.html
        bool_keep_alpha_in_rgba: Forwarded as ``keep_alpha`` to
            ``matplotlib.colors.to_hex``; when true the alpha channel is kept
            in the output strings.

    Returns:
        List of hex colour strings, one per input value, in input order.
    """
    colormap = plt.get_cmap(cmap_id)
    return [
        matplotlib.colors.to_hex(rgba, keep_alpha=bool_keep_alpha_in_rgba)
        for rgba in colormap(scaled_values, bytes=False)
    ]


def create_plotly_boxplot_trace(y_data, trace_color, trace_name):
    """Return a single Plotly box trace for ``y_data``.

    The box uses ``trace_color`` as its marker colour and ``trace_name`` as its
    name. ``boxpoints=False`` hides the individual sample points.
    """
    return go.Box(
        y=y_data,
        marker_color=trace_color,
        name=trace_name,
        boxpoints=False,
    )


def create_plotly_boxplot_layout(
    x_axis_label,
    x_axis_label_color,
    x_axis_font_size,
    y_axis_label,
    y_axis_label_color,
    y_axis_font_size,
    y_axis_grid_color,
    plot_background_color,
    plot_borders_color,
    figure_height,
    figure_width,
):
    """Build the Plotly layout used by the box-plot figures.

    The x axis shows only its title: no grid, zero line or tick labels. The
    y axis has no zero line and draws its grid in ``y_axis_grid_color``. The
    legend is horizontal and anchored just above the plot, centred
    horizontally.

    Args:
        x_axis_label, y_axis_label: Axis title texts.
        x_axis_label_color, y_axis_label_color: Axis title font colours.
        x_axis_font_size, y_axis_font_size: Axis title font sizes.
        y_axis_grid_color: Colour of the y-axis grid lines.
        plot_background_color: Background of the plotting area (``plot_bgcolor``).
        plot_borders_color: Background around the plotting area (``paper_bgcolor``).
        figure_height, figure_width: Figure dimensions.

    Returns:
        A ``plotly.graph_objs.Layout``.
    """
    # Axis titles use the nested ``title=dict(text=..., font=...)`` form.
    # The older ``titlefont`` shorthand is deprecated and was removed in
    # recent Plotly releases.
    return Layout(
        plot_bgcolor=plot_background_color,
        paper_bgcolor=plot_borders_color,
        xaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            title=dict(
                text=x_axis_label,
                font=dict(size=x_axis_font_size, color=x_axis_label_color),
            ),
        ),
        yaxis=dict(
            zeroline=False,
            gridcolor=y_axis_grid_color,
            title=dict(
                text=y_axis_label,
                font=dict(size=y_axis_font_size, color=y_axis_label_color),
            ),
        ),
        height=figure_height,
        width=figure_width,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5),
    )


def create_plotly_boxplot_figure(boxplot_data, boxplot_layout):
    """Combine a list of box traces and a layout into a ``go.Figure``."""
    return go.Figure(data=boxplot_data, layout=boxplot_layout)


def generate_rfm_box_plots(
    rfm_original_columns,
    original_columns_to_rfm_labels_mapping,
    final_rfm_dataframe,
    n_segments_per_axis,
    quantile_max=0.80,
    quantile_min=0.20,
):
    """Build one Plotly box-plot figure per RFM axis.

    For every column in ``rfm_original_columns``, the column's values are
    split by the matching RFM label column. One box trace is drawn per score
    ``1..n_segments_per_axis``. Score ``s`` selects rows whose label equals
    ``reformat_group(s - 1)``.

    Args:
        rfm_original_columns: Names of the raw columns to plot (one figure each).
        original_columns_to_rfm_labels_mapping: Maps each raw column name to
            the dataframe column holding its RFM label. That label is also
            used in trace names, the x-axis title and the result key.
        final_rfm_dataframe: Table indexable by column name that contains both
            the raw and the label columns (presumably a pandas DataFrame).
        n_segments_per_axis: Number of RFM scores per axis.
        quantile_max, quantile_min: Not used by this function; kept so that
            existing call sites keep working.

    Returns:
        dict mapping ``"<rfm label>_box_plot"`` to a ``go.Figure``.
    """
    colormap_id = "magma"
    rfm_scores = list(range(1, n_segments_per_axis + 1))

    # An extra score is added to the colour scale only, to make the highest
    # real score's colour more visible. It must not get a trace of its own.
    color_scale_scores = rfm_scores + [n_segments_per_axis + 1]
    color_scale_colors = from_scaled_values_to_hex_colors(
        min_max_scale_values(color_scale_scores, False), colormap_id, False
    )
    scores_to_colors = dict(zip(color_scale_scores, color_scale_colors))

    rfm_box_plots = {}
    for column in rfm_original_columns:
        rfm_axis_label = original_columns_to_rfm_labels_mapping[column]
        # Looked up once per axis instead of once per score.
        axis_labels = final_rfm_dataframe[rfm_axis_label]
        column_values = final_rfm_dataframe[column]

        boxplot_data = [
            create_plotly_boxplot_trace(
                list(column_values[axis_labels == reformat_group(rfm_score - 1)]),
                scores_to_colors[rfm_score],
                "{} {}".format(rfm_axis_label, rfm_score),
            )
            for rfm_score in rfm_scores
        ]
        boxplot_layout = create_plotly_boxplot_layout(
            rfm_axis_label,
            PLOTLY_AXES_LABEL_COLOR,
            PLOTLY_AXES_FONT_SIZE,
            column,
            PLOTLY_AXES_LABEL_COLOR,
            PLOTLY_AXES_FONT_SIZE,
            PLOTLY_Y_AXIS_GRID_COLOR,
            PLOTLY_PLOT_BACKGROUND_COLOR,
            PLOTLY_PLOT_BORDERS_COLOR,
            PLOTLY_FIGURE_HEIGHT,
            PLOTLY_FIGURE_WIDTH,
        )
        rfm_box_plots["{}_box_plot".format(rfm_axis_label)] = create_plotly_boxplot_figure(
            boxplot_data, boxplot_layout
        )
    return rfm_box_plots
