import dataiku
import dash
from dash import dcc, html, dash_table, no_update  # added no_update
from dash.dependencies import Input, Output, State, ALL
import sklearn
import io
import pickle
import base64  # used when serializing kmeans pipeline
import pandas as pd  # used extensively as pd
from flask import request
import dash_bootstrap_components as dbc
from main_functions import *
from layout_components_style import *

# Shared style objects, resolved once at import time.
# NOTE(review): the get_*_style helpers are not defined in the visible part of this file; they
# presumably come from the star imports above (layout_components_style) — confirm.
table_styles = get_table_styles()  # unpacked with ** into the DataTable constructors below
button_style = get_button_style()
button_style_space = {**button_style, "margin-right": "7px"}  # copy of button_style with a 7px right margin
visible_tab_style = get_visible_tab_style()  # used as selected_style of the left-panel dcc.Tab entries
hidden_style = {'display': 'none'}  # used as style of the left-panel dcc.Tab entries
enter_box_style = get_enter_box_style()
tooltip_style = get_tooltip_style()
question_mark_style = get_question_mark_style()  # style of the '?' spans that tooltips are attached to
header_h2_style = get_header_h2_style()
subheader_style = get_subheader_style()
description_style = get_description_style()
dropdown_style = get_dropdown_style()
main_style = get_main_style()  # style of the container holding the left and right panels
left_panel_style = get_left_panel_style()
right_panel_style = get_right_panel_style()
tab_style = get_tab_style()  # style of the right-panel dcc.Tab entries
tab_selected_style = get_tab_selected_style()  # selected_style of the right-panel dcc.Tab entries
div_message_style = get_div_message_style()
refresh_button_style = get_refresh_button_style()  # style of the div wrapping the REFRESH button



# Tab values in navigation order. right_tab_nav uses the position to decide which of the
# Previous/Next buttons is visible, and route_tabs steps through this list on button clicks.
TAB_ORDER = ['scoping-tab', 'building-tab', 'results-tab']

def right_tab_nav(tab_value):
    """Build the Previous/Next button row shown inside a right-panel tab.

    Both buttons are always created, with ids ``nav-prev-<tab_value>`` and
    ``nav-next-<tab_value>``. A button that has nowhere to go (Previous on the
    first tab, Next on the last tab) stays in the layout but is rendered with
    ``visibility: hidden``.

    Args:
        tab_value: One of the values in ``TAB_ORDER``.

    Returns:
        An ``html.Div`` holding the two buttons, right-aligned in a flex row.

    Raises:
        ValueError: If ``tab_value`` is not in ``TAB_ORDER``.
    """
    position = TAB_ORDER.index(tab_value)

    def _nav_button(label, button_id, enabled):
        # Disabled directions are hidden rather than omitted, so the ids always exist.
        visibility = 'visible' if enabled else 'hidden'
        return html.Button(
            label,
            id=button_id,
            n_clicks=0,
            style={**button_style_space, 'visibility': visibility}
        )

    previous_button = _nav_button("◀ Previous", f"nav-prev-{tab_value}", position > 0)
    next_button = _nav_button("Next ▶", f"nav-next-{tab_value}", position < len(TAB_ORDER) - 1)

    # NOTE: no tooltips are attached to these two buttons.
    row_style = {'display': 'flex', 'justifyContent': 'flex-end', 'gap': '8px', 'margin': '10px 0'}
    return html.Div([previous_button, next_button], style=row_style)


# REFRESH button (id "clear-icon") with its tooltip. The wrapper's placement on the page comes
# from refresh_button_style, which is obtained from get_refresh_button_style().
restart_button_div = html.Div([
    html.Button("REFRESH", id="clear-icon", style={**button_style, 'padding': '4px 10px'}),
    dbc.Tooltip(
        "Reset the interface to its default state. This won't delete any saved sessions.",
        target="clear-icon",
        placement="left",
        style=tooltip_style
    ),
    # Hidden html.Div (not a dcc.Store) kept next to the button.
    # NOTE(review): presumably the target of the refresh callback — confirm against the callbacks.
    html.Div(id='refresh-trigger', style={'display': 'none'})
], style=refresh_button_style)

# Page layout: the REFRESH control followed by one container holding two panels.
#   * Left panel  - dcc.Tabs 'left-panel-tabs' with three tabs (scoping / building / results) whose
#                   labels are empty and whose `style` is hidden_style; it holds the input controls
#                   and most dcc.Store components.
#   * Right panel - dcc.Tabs 'right-panel-tabs' with the labelled tabs "1. Filtering", "2. Building",
#                   "3. Insights"; it holds the tables, graphs and the Previous/Next navigation.
# Both Tabs components use the same three values (see TAB_ORDER) and start on 'scoping-tab'.
# NOTE(review): `app` is not created in the visible part of this file; it presumably comes from the
# star imports or from the hosting webapp runtime — confirm.
app.layout = html.Div([
    restart_button_div,
    # Main container using flexbox to align left and right panels
    html.Div([
        # Left Panel with dcc.Tabs (content changes depending on the selected tab)
        html.Div([
            dcc.Tabs(id='left-panel-tabs', value='scoping-tab',
                     children=[
                         dcc.Tab(
                             label='', value='scoping-tab',
                             children=[
                                 html.Div([
                                     html.H2(['Select Dataset in Dataiku Flow',
                                              html.Span('?', id='tooltip-select-dataset', style=question_mark_style)
                                              ], style=header_h2_style),
                                     dbc.Tooltip(
                                         "Select a dataset and make sure that it includes a column account_id with other features.",
                                         target="tooltip-select-dataset",
                                         placement="right",
                                         style=tooltip_style
                                     ),
                                     dcc.Dropdown(
                                         id='select-dataset',
                                         placeholder="Select a dataset",
                                         style=dropdown_style
                                     ),
                                     dcc.Store(id='loaded-dataset'),  # Store dataset
                                     html.H2(['Filtering',
                                              html.Span('?', id='tooltip-filtering', style=question_mark_style)
                                              ], style=header_h2_style),
                                     dbc.Tooltip(
                                         "Build your cohorts by applying the filters below. Once you select features and APPLY, the dataset on the right will be updated accordingly. If the select feature dropdown is empty, the APPLY button displays the entire original dataset.",
                                         target="tooltip-filtering",
                                         placement="right",
                                         style=tooltip_style
                                     ),
                                     dcc.Dropdown(
                                         id="filter-column-dropdown",
                                         multi=True,
                                         placeholder="Select columns to filter",
                                         style=dropdown_style,
                                     ),
                                     html.Div(id="filter-inputs"),
                                     html.Div([
                                         html.Button("Apply", id="apply-filters", style=button_style_space),
                                     ], style={'margin-top': '10px'}),
                                     dcc.Store(id="filtered-data"),
                                 ])
                             ],
                             style=hidden_style,
                             selected_style=visible_tab_style
                         ),
                         dcc.Tab(
                             label='', value='building-tab',
                             children=[
                                 html.Div([
                                     html.H2(['Select a Method',
                                              html.Span('?', id='tooltip-select-method', style=question_mark_style)
                                              ], style=header_h2_style),
                                     dbc.Tooltip(
                                         "Select a method to display the description and different selection parameters.",
                                         target="tooltip-select-method",
                                         placement="right",
                                         style=tooltip_style
                                     ),
                                     html.Div([
                                         html.Div([
                                             html.Img(src=f"data:image/png;base64,{kmeans_image_encoded}", id="kmeans-icon",
                                                      style={'cursor': 'pointer', 'max-width': '15%', 'height': 'auto', 'border': '1px solid #ccc',
                                                             'display': 'block', 'margin-left': 'auto', 'margin-right': 'auto'}),
                                             html.Label("Machine Learning Clustering", style={'text-align': 'center', 'fontSize': '12px', 'margin-top': '5px', 'min-height': '35px'})
                                         ], style={'text-align': 'center', 'flex': '1', 'margin': '0 0px'}),

                                         html.Div([
                                             html.Img(src=f"data:image/png;base64,{rule_based_image_encoded}", id="rule-based-icon",
                                                      style={'cursor': 'pointer', 'max-width': '15%', 'height': 'auto', 'border': '1px solid #ccc',
                                                             'display': 'block', 'margin-left': 'auto', 'margin-right': 'auto'}),
                                             html.Label("Rule-Based Segmentation", style={'text-align': 'center', 'fontSize': '12px', 'margin-top': '5px', 'min-height': '35px'})
                                         ], style={'text-align': 'center', 'flex': '1', 'margin': '0 0px'}),

                                     ], style={'display': 'flex', 'justify-content': 'center', 'align-items': 'flex-start', 'flex-wrap': 'wrap', 'width': '100%', 'gap': '0px'}),

                                     html.Div(id="method-description", style={'marginTop': '20px', 'fontSize': '14px', 'color': '#333', 'textAlign': 'center'}),

                                     dcc.Store(id="clustering-method-store"),

                                     html.H2(['Parameters Selection',
                                              html.Span('?', id='tooltip-select-parameters', style=question_mark_style)
                                              ], style=header_h2_style),
                                     dbc.Tooltip(
                                         "The number of segments should be an integer positive number of at least 2. The option for Rule-Based Segmentation features is limited to only numerical ones, and the weight options should all set to positive integer numbers that control the importance of each feature in the segmentation process.",
                                         target="tooltip-select-parameters",
                                         placement="right",
                                         style=tooltip_style
                                     ),
                                     html.Div(id="method-selection"),
                                     # Number of Clusters for KMeans
                                     html.Div(id="num-clusters-container", children=[
                                         html.Label("Number of Segments", style=subheader_style),
                                         dcc.Input(id="num-clusters", type="number", value=3, style=enter_box_style)
                                     ], style={'margin-bottom': '10px', 'fontSize': '14px', 'display': 'none'}),

                                     # Number of Bins for Rule-Based method
                                     html.Div(id="rule-based-parameters", children=[
                                         html.Label("Number of Segments", style=subheader_style),
                                         dcc.Input(id="num-bins", type="number", value=3, min=2, style=enter_box_style),
                                     ], style={'margin-bottom': '10px', 'fontSize': '14px', 'display': 'none'}),

                                     # Feature Selection for both methods
                                     html.Div([html.Label("Select Features", style=subheader_style),
                                               dcc.Dropdown(id="feature-selection", multi=True, style=dropdown_style),
                                               ], id='feature-selection-container', style={'margin-bottom': '10px', 'fontSize': '14px', 'display': 'none'}),

                                     # Dynamic weights for Rule-Based method
                                     html.Div(id="weights-container", style={'margin-bottom': '10px'}),
                                     html.Button("Run", id="apply-segmentation", style=button_style),
                                 ])
                             ],
                             style=hidden_style,
                             selected_style=visible_tab_style
                         ),
                         dcc.Tab(
                             label='', value='results-tab',
                             children=[
                                 html.Div([
                                     html.H2(['Remap Segment Names',
                                              html.Span('?', id='tooltip-remap-segments', style=question_mark_style)
                                              ], style=header_h2_style),
                                     dbc.Tooltip(
                                         "Review the segmentation insights analysis and remap the segment names below to give a meaningful interpretation to your results.",
                                         target="tooltip-remap-segments",
                                         placement="right",
                                         style=tooltip_style
                                     ),
                                     html.Div(id="mapping-container"),
                                     html.Button("Remap Segments", id="remap-button", n_clicks=0, style=button_style),
                                     dcc.Store(id='remap-cluster-names-store'),
                                     dcc.Store(id='remap-cluster-dict'),
                                     html.Div(id="remap-message"),
                                     # Stores
                                     dcc.Store(id='method-changed-store', data=False),
                                     dcc.Store(id='filter-dictionary-store'),
                                     dcc.Store(id='kmeans-pipeline-store'),
                                     dcc.Store(id='rulebased-bin-bounds'),
                                     dcc.Store(id='rulebased-weights'),

                                     # Export and Save
                                     html.H2(['Segmentation Processing',
                                              html.Span('?', id='tooltip-segments-process', style=question_mark_style)
                                              ], style=header_h2_style),
                                     dbc.Tooltip(
                                         "If you want to SAVE this segmentation session you must provide a name and segmentation description. These information are particularly useful metadata for future updates.",
                                         target="tooltip-segments-process",
                                         placement="right",
                                         style=tooltip_style
                                     ),
                                     dcc.Input(id='session-name', type='text', placeholder='Enter a session name', style=enter_box_style),
                                     dcc.Textarea(
                                         id='metadata-description',
                                         placeholder='Enter a segmentation description..',
                                         style=description_style
                                     ),
                                     dcc.Store(id='description-dictionary'),

                                     html.Div([
                                         html.Button("Save", id="save-icon", style=button_style_space),
                                         dbc.Tooltip("SAVE button will create a csv file with the new Segmentation Results in the output_data_folder and a new record in the metadata_dataset with your updated session information.",
                                                     target="save-icon",
                                                     placement="bottom",
                                                     style=tooltip_style
                                                     ),
                                         html.Button("Export", id="export-icon", style=button_style_space),
                                         dbc.Tooltip("EXPORT button will download a csv file with the segmentation resuts locally.",
                                                     target="export-icon",
                                                     placement="bottom",
                                                     style=tooltip_style
                                                     ),

                                         html.Span('?', id='tooltip-buttons', style=question_mark_style)
                                     ]),
                                     dbc.Tooltip(
                                         "Hover over the buttons to see functionality details.",
                                         target="tooltip-buttons",
                                         placement="right",
                                         style=tooltip_style
                                     ),
                                     html.Div(id="error-message", style={'color': 'black'}),
                                     dcc.Download(id="download-dataframe-csv"),
                                     html.Div(id='save-output'),
                                 ])
                             ],
                             style=hidden_style,
                             selected_style=visible_tab_style
                         ),
                     ], style={'padding': '10px'}),  # Tabs for left panel content
        ], style=left_panel_style),

        # Right Panel (Tabs at the top of the right panel)
        html.Div([
            dcc.Tabs(id='right-panel-tabs', value='scoping-tab', children=[
                dcc.Tab(
                    label='1. Filtering',
                    value='scoping-tab',
                    style=tab_style, selected_style=tab_selected_style,
                    children=[
                        html.Div([
                            right_tab_nav('scoping-tab'),  # ← nav placed at top
                            html.H2(['Data for Segmentation',
                                     html.Span('?', id='tooltip-data-segmentation', style=question_mark_style)
                                     ], style=header_h2_style),
                            dbc.Tooltip(
                                "Once you select a dataset the metric below will show that name of the dataset and the number of records updating with the filtering conditions.",
                                target="tooltip-data-segmentation",
                                placement="right",
                                style=tooltip_style
                            ),
                            html.H5(id='dataset-title'),
                            dcc.Loading(
                                id="loading-data-upload",
                                type="circle",
                                children=[
                                    # Placeholder table; load_selected_dataset replaces the children of
                                    # 'output-data-upload' with a populated table carrying the same id.
                                    html.Div(id='output-data-upload', children=[
                                        dash_table.DataTable(
                                            id='output-data-upload-table',
                                            page_size=15,
                                            sort_action='native',
                                            filter_action='native',
                                            page_action='native',
                                            page_current=0,
                                            **table_styles
                                        )
                                    ])
                                ]
                            ),
                        ], style={'padding': '20px'})
                    ]
                ),
                dcc.Tab(
                    label='2. Building', value='building-tab',
                    style=tab_style, selected_style=tab_selected_style,
                    children=[
                        html.Div([
                            right_tab_nav('building-tab'),  # ← nav placed at top
                            html.Div([
                                # Timer is created disabled
                                dcc.Interval(id='progress-interval', interval=1000, n_intervals=0, disabled=True),
                            ], style={'margin-top': '20px'}),

                            html.Div(id="output-div"),

                            html.Div([
                                html.H2(['Segmentation Results',
                                         html.Span('?', id='tooltip-segmentation-results', style=question_mark_style)
                                         ], style=header_h2_style),
                                dbc.Tooltip(
                                    "Once you select a method and RUN segmentation, you will see below the filtered data table, with an extra cluster column with the segmentation results.",
                                    target="tooltip-segmentation-results",
                                    placement="right",
                                    style=tooltip_style
                                ),
                                html.Div(
                                    "Check the new column 'cluster' that shows how data are segmented into different groups based on the specified number of segments. Explore and interpret these groups in the Insights tab.",
                                    id="segmentation-help-message",
                                    style={**div_message_style, 'display': 'none', 'marginBottom': '15px'}
                                ),

                                dcc.Loading(
                                    id="loading-table",
                                    type="circle",
                                    children=[
                                        html.Div([
                                            # Results table starts hidden (display: none on its container)
                                            html.Div(
                                                dash_table.DataTable(
                                                    id='output-table',
                                                    sort_action='native',
                                                    filter_action='native',
                                                    page_size=15,
                                                    **table_styles
                                                ),
                                                id='table-container',
                                                style={'display': 'none'}
                                            ),
                                            html.Div(
                                                id="no-data-message",
                                                style={'textAlign': 'center', 'fontSize': 16, 'color': 'gray'}
                                            )
                                        ])
                                    ]
                                ),
                            ], style={'margin-top': '20px'}),

                        ], style={'padding': '20px'})
                    ]
                ),
                dcc.Tab(
                    label='3. Insights', value='results-tab',
                    style=tab_style, selected_style=tab_selected_style,
                    children=[
                        html.Div([
                            right_tab_nav('results-tab'),  # ← nav placed at top
                            html.H2(['Segmentation Insights',
                                     html.Span('?', id='tooltip-segmentation-insights', style=question_mark_style)
                                     ], style=header_h2_style),
                            dbc.Tooltip(
                                "Once you built and RUN a segmentation session, the graphs below will support the explainability of your results.",
                                target="tooltip-segmentation-insights",
                                placement="right",
                                style=tooltip_style
                            ),

                            html.Div([
                                html.Div([
                                    html.H3("Number of Records in Each Segment", style=subheader_style),
                                    dcc.Graph(id='cluster-histogram'),
                                ], style={'width': '49%', 'display': 'inline-block'}),
                                html.Div([
                                    html.H3(['Variable Importance',
                                             html.Span('?', id='tooltip-variable-importance', style=question_mark_style)
                                             ], style=subheader_style),
                                    dbc.Tooltip(
                                        "For machine learning, variable importance identifies which features most influence the formation of segments. Higher values (closer to 1) mean the feature contributed more to distinguishing between clusters and lower values (near 0) mean the feature had little impact in splitting the data for classification. In rule-based segmentation, feature weights directly control how much each feature contributes to the segmentation outcome.",
                                        target="tooltip-variable-importance",
                                        placement="right",
                                        style=tooltip_style
                                    ),
                                    dcc.Graph(id='importance-histogram'),
                                ], style={'width': '49%', 'display': 'inline-block', 'float': 'right'}),
                            ]),

                            html.H2(['Feature Distribution Analysis',
                                     html.Span('?', id='tooltip-feature-dist', style=question_mark_style)
                                     ], style=header_h2_style),
                            dbc.Tooltip(
                                "In case you update the segmentation method or features and rerun the segmentation in the previous tab, you should deselect and select the features of your interest below to update the displayed results.",
                                target="tooltip-feature-dist",
                                placement="right",
                                style=tooltip_style
                            ),
                            html.Div([
                                html.Div([
                                    html.H3("Feature Distribution by Segment", style=subheader_style),
                                    dcc.Dropdown(id='feature-dropdown', placeholder="Select a feature to display", style=dropdown_style),
                                    html.Div(id="feature-histograms"),
                                ], style={'width': '49%', 'display': 'inline-block'}),
                                html.Div([
                                    html.H3(["Average Feature Values by Segment",
                                             html.Span('?', id='tooltip-heatmap', style=question_mark_style)],
                                            style=subheader_style),
                                    dbc.Tooltip(
                                        "Compare how each selected numerical feature differs across segments. Darker colors indicate relatively higher averages in that segment. Hover over a cell to view the actual average value.",
                                        target="tooltip-heatmap",
                                        placement="right",
                                        style=tooltip_style
                                    ),
                                    dcc.Dropdown(id='numerical-feature-multiselect', multi=True, placeholder="Select numerical features", style=dropdown_style),
                                    html.Div(id='mean-heatmap'),
                                ], style={'width': '49%', 'display': 'inline-block', 'float': 'right'}),
                            ]),
                            html.Div(id="message-div", style=subheader_style)
                        ], style={'padding': '20px'})
                    ]
                ),
            ]),
        ], style=right_panel_style),

    ], style=main_style)
])

# ---------- Unified router for tab syncing + Prev/Next navigation (no duplicate outputs) ----------
def _step(order, current, delta):
    i = order.index(current)
    j = max(0, min(len(order) - 1, i + delta))
    return order[j]

@app.callback(
    [Output('left-panel-tabs', 'value'),
     Output('right-panel-tabs', 'value')],
    [
        # Header tab clicks
        Input('left-panel-tabs', 'value'),
        Input('right-panel-tabs', 'value'),
        # Prev/Next buttons from each right tab
        Input('nav-prev-scoping-tab', 'n_clicks'),
        Input('nav-prev-building-tab', 'n_clicks'),
        Input('nav-prev-results-tab', 'n_clicks'),
        Input('nav-next-scoping-tab', 'n_clicks'),
        Input('nav-next-building-tab', 'n_clicks'),
        Input('nav-next-results-tab', 'n_clicks'),
    ],
    State('right-panel-tabs', 'value'),
    prevent_initial_call=True
)
def route_tabs(left_val, right_val,
               prev_scoping, prev_building, prev_results,
               next_scoping, next_building, next_results,
               current_right):
    """Keep the left and right tab selections identical and handle Prev/Next clicks.

    The component that fired the callback decides the outcome:

    * a change of either Tabs value is copied to both Tabs components;
    * a Previous/Next button moves one step through ``TAB_ORDER`` starting
      from the right panel's current value (clamped at both ends);
    * anything else leaves both values untouched.

    The ``n_clicks`` arguments are only used as triggers; their counts are
    never read.

    Returns:
        A ``(left_value, right_value)`` pair, or ``no_update`` for both.
    """
    fired = dash.callback_context.triggered
    if not fired:
        return no_update, no_update

    # Only the first triggering property is considered; keep the component id part.
    source_id = fired[0]['prop_id'].split('.')[0]

    # A change of either Tabs value is mirrored onto both panels.
    mirrored_values = {
        'left-panel-tabs': left_val,
        'right-panel-tabs': right_val,
    }
    if source_id in mirrored_values:
        selected = mirrored_values[source_id]
        return selected, selected

    # Navigation buttons → direction of travel through TAB_ORDER.
    step_by_button = {
        'nav-prev-scoping-tab': -1,
        'nav-prev-building-tab': -1,
        'nav-prev-results-tab': -1,
        'nav-next-scoping-tab': +1,
        'nav-next-building-tab': +1,
        'nav-next-results-tab': +1,
    }
    delta = step_by_button.get(source_id)
    if delta is None:
        return no_update, no_update

    target = _step(TAB_ORDER, current_right, delta)
    return target, target
# ---------- end router ----------

# The dropdown's own value is the only Input, and prevent_initial_call is not set, so this
# callback also runs when the page first loads; that initial run is what fills the option list.
# It runs again on every selection change.
@app.callback(
    Output('select-dataset', 'options'),
    Input('select-dataset', 'value')
)
def populate_dataset_dropdown(_):
    """Return the options for the 'select-dataset' dropdown.

    The triggering value is ignored; the result is whatever ``get_datasets()``
    returns (defined outside this file's visible code — presumably a list of
    dropdown options; confirm).
    """
    return get_datasets()

@app.callback(
    [Output('output-data-upload', 'children'),
     Output('filter-column-dropdown', 'options'),
     Output('loaded-dataset', 'data')],
    Input('select-dataset', 'value')
)
def load_selected_dataset(dataset_name):
    """Load the chosen Dataiku dataset and publish it to the Filtering tab.

    Args:
        dataset_name: Value of the 'select-dataset' dropdown; falsy when
            nothing is selected.

    Returns:
        A 3-tuple ``(children, column_options, records)``:

        * ``children`` - a populated ``DataTable`` wrapped in a div, or a
          message div when there is nothing to show (no selection, empty
          dataset, missing ``account_id`` column, or a load error);
        * ``column_options`` - one dropdown option per dataframe column, or
          ``[]`` when no table is shown;
        * ``records`` - the dataframe as a list of row dicts for the
          'loaded-dataset' store, or ``None`` when no table is shown.
    """
    if not dataset_name:
        return html.Div(['Please select a dataset.'], style=subheader_style), [], None

    try:
        df = dataiku.Dataset(dataset_name).get_dataframe()

        if df.empty:
            return (html.Div(['The selected dataset is empty.'], style=subheader_style), [], None)

        # Check if "account_id" column is present
        if "account_id" not in df.columns:
            return (html.Div([f'Column key account_id in {dataset_name} is missing. Please prepare the data in the Dataiku flow and include the required column.'], style=div_message_style), [], None)

        options = [{"label": col, "value": col} for col in df.columns]

        # Serialize the dataframe once and reuse it for both the table and the store;
        # converting twice doubles the work on large datasets for an identical result.
        records = df.to_dict('records')

        table = dash_table.DataTable(
            id='output-data-upload-table',
            data=records,
            columns=[{'name': i, 'id': i} for i in df.columns],
            page_size=15,
            page_action='native',
            sort_action='native',
            filter_action='native',
            page_current=0,
            **table_styles
        )
        return html.Div([table]), options, records
    except Exception as e:
        # Deliberately broad: this is the UI boundary, and any load failure is reported
        # to the user in place of the table instead of failing the callback.
        return (html.Div([f"Error loading dataset: {str(e)}"], style=div_message_style), [], None)

@app.callback(
    Output("filter-inputs", "children"),
    [Input("filter-column-dropdown", "value")],
    [State('select-dataset', 'value')]
)
def generate_filter_inputs(selected_columns, dataset_name):
    """Build one filter control per column selected in 'filter-column-dropdown'.

    Args:
        selected_columns: Column names chosen in the multi-select dropdown;
            may be ``None`` or empty when nothing is selected.
        dataset_name: Currently selected dataset name, or ``None``.

    Returns:
        A list with the truthy results of ``generate_filter_input(df, col)``,
        in the order of ``selected_columns``. The list is empty when no
        dataset or no column is selected.
    """
    # Nothing to build: skip the dataset load entirely. This also covers a None
    # selection, which would otherwise fail when iterated.
    if dataset_name is None or not selected_columns:
        return []

    df = dataiku.Dataset(dataset_name).get_dataframe()

    inputs = []
    for col in selected_columns:
        # Build each control once and reuse it for both the check and the result.
        control = generate_filter_input(df, col)
        if control:
            inputs.append(control)
    return inputs

@app.callback(
    [Output("output-data-upload-table", "data"),
     Output("dataset-title", "children"),
     Output("filter-dictionary-store", "data")],
    Input("apply-filters", "n_clicks"),
    [State({"type": "filter-dropdown", "index": ALL}, "value"),
     State({"type": "filter-slider", "index": ALL}, "value"),
     State({"type": "filter-date-picker", "index": ALL}, "start_date"),
     State({"type": "filter-date-picker", "index": ALL}, "end_date"),
     State("filter-column-dropdown", "value"),
     State('loaded-dataset', 'data'),
     State('select-dataset', 'value')]
)
def apply_filters_on_button_click(apply_n_clicks, filter_values, filter_ranges, start_dates, end_dates, selected_columns, loaded_data, dataset_name):
    """Filter the loaded dataset when the "apply-filters" button is clicked.

    Args:
        apply_n_clicks: Click counter of the apply button (trigger only).
        filter_values, filter_ranges, start_dates, end_dates: Current values
            of the dynamically generated filter controls.
        selected_columns: Columns chosen in the filter-column dropdown.
        loaded_data: Full dataset as stored by the load callback (row dicts).
        dataset_name: Name of the selected dataset, used in the title.

    Returns:
        ``(table_records, title, filter_dictionary)``. The unfiltered data is
        returned when the callback was not triggered by a click or when no
        filter column is selected. On any error the table is emptied, the
        title is blanked and the error is logged.
    """
    import logging  # local import so this block does not depend on the file header

    try:
        # Nothing to show until a dataset has been selected and loaded.
        if not dataset_name or loaded_data is None:
            return [], "", {}

        # Rebuild the full dataset from the stored records.
        df = pd.DataFrame(loaded_data)

        # Initial call (no trigger) or no filter column chosen: show everything.
        if not dash.callback_context.triggered or not selected_columns:
            return df.to_dict("records"), f"{dataset_name} with {len(df)} records.", {}

        filtered_df, selected_columns_date, selected_columns_str, selected_columns_num = filter_dataframe(
            df, selected_columns, filter_values, filter_ranges, start_dates, end_dates
        )

        # Keep a description of the applied filters for later use.
        filter_dictionary = build_filter_dictionary(
            selected_columns_date, selected_columns_str, selected_columns_num,
            start_dates, end_dates, filter_values, filter_ranges
        )

        title = f"{dataset_name} with {len(filtered_df)} records."
        return filtered_df.to_dict("records"), title, filter_dictionary

    except Exception:
        # Keep the UI responsive, but leave a trace instead of discarding the error.
        logging.getLogger(__name__).exception(
            "Applying filters to dataset %r failed", dataset_name
        )
        return [], " ", {}

@app.callback(
    Output('feature-selection', 'options'),
    [Input('clustering-method-store', 'data'), State('select-dataset', 'value')]
)
def update_feature_selection(method, dataset_name):
    """List the features selectable for the current clustering method.

    For ``"rule_based"`` only ``int64``/``float64`` columns are offered; for
    any other method every column is. A "Select All" entry is always placed
    first. Returns an empty list when no dataset is selected.
    """
    if not dataset_name:
        return []

    df = dataiku.Dataset(dataset_name).get_dataframe()

    if method == "rule_based":
        candidate_columns = df.select_dtypes(include=['int64', 'float64']).columns.tolist()
    else:
        candidate_columns = df.columns

    select_all_entry = {"label": "Select All", "value": "select_all"}
    column_entries = [{"label": name, "value": name} for name in candidate_columns]
    return [select_all_entry] + column_entries

@app.callback(
    Output('feature-selection', 'value'),
    Input('feature-selection', 'value'),
    State('feature-selection', 'options')
)
def handle_select_all_feature_selection(selected_values, available_options):
    """Expand the "Select All" choice into every real feature option.

    Returns an empty list when either argument is ``None``; otherwise the
    selection is passed through unchanged unless it contains ``"select_all"``,
    in which case all option values except ``"select_all"`` are returned.
    """
    if selected_values is None or available_options is None:
        return []

    if "select_all" not in selected_values:
        return selected_values

    expanded = []
    for entry in available_options:
        if entry['value'] != 'select_all':
            expanded.append(entry['value'])
    return expanded

@app.callback(
    [Output("kmeans-icon", "style"),
     Output("rule-based-icon", "style")],
    [Input("kmeans-icon", "n_clicks"),
     Input("rule-based-icon", "n_clicks")]
)
def update_clustering_method_visual(kmeans_clicks, rule_based_clicks):
    """Highlight whichever method icon was clicked last.

    Returns ``(kmeans_style, rule_based_style)``. The icon that triggered the
    callback gets ``selected_style``; the other one (or both, when nothing
    triggered the callback) gets ``unselected_style``.
    """
    ctx = dash.callback_context
    clicked_id = ctx.triggered[0]['prop_id'].split('.')[0] if ctx.triggered else None

    kmeans_style = selected_style if clicked_id == "kmeans-icon" else unselected_style
    rule_based_style = selected_style if clicked_id == "rule-based-icon" else unselected_style
    return kmeans_style, rule_based_style

@app.callback(
    [Output("num-clusters-container", "style"),
     Output("apply-segmentation", "style"),
     Output("method-selection", "children"),
     Output("clustering-method-store", "data"),
     Output("rule-based-parameters", "style"),
     Output("feature-selection-container", "style"),
     Output("method-description", "children"),
     Output("method-changed-store", "data")
     ],
    [Input("kmeans-icon", "n_clicks"),
     Input("rule-based-icon", "n_clicks")]
)
def display_parameters(kmeans_clicks, rule_based_clicks):
    """Show the parameter panel and description of the clicked method.

    Returns eight values, in output order: number-of-clusters container
    style, run-button style, method-selection children (always ``None``),
    the stored method name (``"kmeans"``, ``"rule_based"`` or ``None``),
    rule-based panel style, feature-selection container style, the method
    description text, and the method-changed flag (always ``True``).
    """
    shown = {'display': 'block'}
    concealed = {'display': 'none'}

    ctx = dash.callback_context
    clicked_id = ctx.triggered[0]['prop_id'].split('.')[0] if ctx.triggered else None

    # No click yet, or an unexpected trigger: keep every panel hidden.
    if clicked_id not in ("kmeans-icon", "rule-based-icon"):
        return concealed, concealed, None, None, concealed, concealed, " ", True

    run_button_style = {**button_style, 'display': 'block'}

    if clicked_id == "kmeans-icon":
        kmeans_description = "K-means is an unsupervised machine learning algorithm that partitions data into a predefined number of clusters (segments) based on the similarity between data points. It groups data points so that each point belongs to the cluster with the nearest centroid, which is the mean of the data points in that cluster."
        return shown, run_button_style, None, "kmeans", concealed, shown, kmeans_description, True

    rule_based_description = "Rule-based method segments data based on predefined rules or thresholds, typically set by the user. In this approach, data are first normalized to a common scale and then divided into discrete intervals (bins) according to the number of segments (clusters) specified. Each data point is assigned to a bin based on its normalized value, resulting in a fixed number of user-defined segments."
    return concealed, run_button_style, None, "rule_based", shown, shown, rule_based_description, True

@app.callback(
    Output("weights-container", "children"),
    [Input("clustering-method-store", "data"),
     Input("feature-selection", "value")]
)
def toggle_feature_weights_input(method, selected_features):
    """Render one numeric weight input per selected feature (rule-based only).

    Returns an empty list unless the method is ``"rule_based"`` and at least
    one feature is selected. Each input defaults to 1 and every row is as
    wide as the longest feature name plus two characters.
    """
    if method != "rule_based" or not selected_features:
        return []

    # All rows share the width of the longest feature name.
    widest_name = max(len(feature) for feature in selected_features)
    row_width = f"{widest_name + 2}ch"

    return [
        html.Div(
            [
                html.Label(f"{feature}", style={"textAlign": "center", 'fontSize': '12px'}),
                dcc.Input(
                    id={'type': 'feature-weight', 'index': feature},
                    type="number",
                    value=1,
                    style=enter_box_style,
                ),
            ],
            style={"margin-bottom": "20px", "textAlign": "left", "width": row_width},
        )
        for feature in selected_features
    ]

@app.callback(
    [Output("table-container", "style"),
     Output("no-data-message", "children"),
     Output("no-data-message", "style"),
     Output("output-table", "data"),
     Output("output-table", "columns"),
     Output("feature-dropdown", "options"),
     Output("kmeans-pipeline-store", "data"),
     Output("rulebased-bin-bounds", "data"),
     Output("rulebased-weights", "data"),
     Output("message-div", "children"),
     Output("message-div", "style"),
     Output('apply-segmentation', 'children'),
     Output('apply-segmentation', 'disabled'),
     Output('progress-interval', 'disabled'),
     Output("segmentation-help-message", "style")],
    [Input("apply-segmentation", "n_clicks"),
     Input("progress-interval", "n_intervals"),
     Input("method-changed-store", "data")],
    [State("clustering-method-store", "data"),
     State("feature-selection", "value"),
     State("num-clusters", "value"),
     State("num-bins", "value"),
     State("output-data-upload-table", "data"),
     State({'type': 'feature-weight', 'index': ALL}, 'value')]
)
def apply_segmentation(n_clicks, intervals, method_changed, method, selected_features, num_clusters, num_bins, filtered_data, feature_weights):
    """Run the selected segmentation method and publish its results.

    Triggers: the Run button, the progress interval, and the method-changed
    flag. Every return path yields exactly 15 values, one per declared
    output, in this order: table container style, no-data message text and
    style, result table data and columns, feature dropdown options,
    serialized k-means pipeline, rule-based bin bounds, rule-based weights,
    message text and style, Run button label, Run button disabled flag,
    progress interval disabled flag, help message style.

    Args:
        n_clicks: Click counter of the Run button (``None`` before any click).
        intervals: Tick counter of the progress interval.
        method_changed: Flag written whenever a method icon is clicked.
        method: ``"kmeans"`` or ``"rule_based"``.
        selected_features: Feature names chosen by the user.
        num_clusters: Number of clusters passed to the k-means helper.
        num_bins: Number of bins passed to the rule-based helper.
        filtered_data: Rows currently shown in the data preview table.
        feature_weights: Values of the per-feature weight inputs.
    """
    ctx = dash.callback_context

    hidden_style_local = {'display': 'none'}
    visible_style = {}
    bold_banner_style = {"color": "black", "font-weight": "bold"}

    def no_result(message, banner="", banner_style=None):
        """Build the 15-value response that hides the table and shows *message*."""
        return (
            hidden_style_local, message, visible_style,
            [], [], [],
            None, None, None,
            banner, banner_style if banner_style is not None else {"color": "black"},
            "Run", False, True,
            {'display': 'none'},
        )

    if n_clicks is None:
        return no_result("No segmentation applied yet.")

    if not selected_features:
        return no_result("Please select at least one feature.")

    triggered_id = ctx.triggered[0]['prop_id'].split('.')[0] if ctx.triggered else None

    # A method switch invalidates whatever results are currently displayed.
    if triggered_id == "method-changed-store" and method_changed:
        return no_result("Select Parameters and click RUN to generate new results.")

    # Progress ticks only touch the Run button and the interval itself.
    if ctx.triggered and 'progress-interval.n_intervals' in ctx.triggered[0]['prop_id']:
        # The first 11 outputs (everything up to the message style) stay as
        # they are; the tuple must still contain one value per output.
        untouched = (no_update,) * 11
        if (intervals or 0) < 5:  # Simulate progress
            return (*untouched, "Running...", True, False, {'display': 'none'})
        return (*untouched, "Run", False, True, {'display': 'none'})

    # Without a recognised method there is nothing to compute; previously
    # this fell through to an unbound result variable.
    if method not in ("kmeans", "rule_based"):
        return no_result("Please select a segmentation method.")

    df_filtered = pd.DataFrame(filtered_data)
    numerical_columns = df_filtered.select_dtypes(include=['number']).columns.tolist()
    categorical_columns = df_filtered.select_dtypes(include=['object', 'category']).columns.tolist()
    # account_id is the record key, not a feature to chart.
    columns_list = [col for col in (numerical_columns + categorical_columns) if col not in ['account_id']]

    cluster_name = 'cluster'
    kmeans_pipeline_serialized = None
    rb_bounds = None
    rb_weights = None

    if method == "kmeans":
        df_result, kmeans_pipeline, preprocessor, error_message = kmeans_clustering(
            df_filtered, selected_features, num_clusters, cluster_name
        )
        if df_result is None:
            return no_result(error_message, error_message, bold_banner_style)

        # Present clusters as 1-based "Segment_<n>" labels.
        df_result[cluster_name] = 'Segment_' + (df_result[cluster_name] + 1).astype(str)
        kmeans_pipeline_serialized = base64.b64encode(pickle.dumps(kmeans_pipeline)).decode('utf-8')
        # NOTE(review): the value returned by this call is not used in this
        # callback. The call is kept only in case the helper has side
        # effects; consider removing it once that is confirmed.
        feature_importance_rfclassifier(df_result, selected_features, preprocessor, cluster_name)

    else:  # method == "rule_based"
        # Pair each feature with its weight input. A blank input, or a
        # missing one when fewer inputs than features exist, counts as 1.
        weights_in = feature_weights or []
        feature_weights_dict = {}
        for idx, feature in enumerate(selected_features):
            weight = weights_in[idx] if idx < len(weights_in) else None
            feature_weights_dict[feature] = 1 if weight is None else weight

        df_result, rb_bounds, rb_weights, warning_message = rule_based(
            df_filtered, feature_weights_dict, selected_features, cluster_name, num_bins
        )
        if warning_message:
            return no_result(warning_message, warning_message, bold_banner_style)

    # Reorder columns and prepare table data.
    df_result = move_columns(df_result, 1, cluster_name)
    data = df_result.to_dict('records')
    columns = [{"name": col, "id": col} for col in df_result.columns]
    feature_options = [{"label": col, "value": col} for col in columns_list]

    # Table visible, placeholder message hidden, help message shown.
    return (
        visible_style, "", hidden_style_local,
        data, columns, feature_options,
        kmeans_pipeline_serialized, rb_bounds, rb_weights,
        "", bold_banner_style,
        "Run", False, True,
        {**div_message_style, 'display': 'block', 'marginBottom': '15px'},
    )

@app.callback(
    Output("mapping-container", "children"),
    [Input("output-table", "data")]
)
def generate_remap_inputs(output_data):
    """Render one text input per cluster so the user can rename segments.

    Args:
        output_data: Rows of the segmentation result table.

    Returns:
        A list of labelled text inputs, one per unique value of the first
        column whose name contains ``"cluster"``; each input is pre-filled
        with the current cluster name. Empty when there is no data or no
        such column.
    """
    if not output_data:
        return []

    df = pd.DataFrame(output_data)
    cluster_columns = [col for col in df.columns if 'cluster' in str(col)]
    # Without a cluster column there is nothing to rename; returning no
    # inputs avoids an IndexError on the lookup below.
    if not cluster_columns:
        return []
    cluster_column = cluster_columns[0]

    # Rows are sized to fit the 'Remap Cluster Names' caption.
    row_width = f"{len('Remap Cluster Names') + 2}ch"

    inputs = []
    for cluster in df[cluster_column].unique():
        cluster_str = str(cluster)
        inputs.append(html.Div([
            html.Label(f"'{cluster_str}'", style={"textAlign": "center"}),
            dcc.Input(id={'type': 'remap-input', 'index': cluster_str}, type='text', value=cluster_str,
                      style=enter_box_style)
        ], style={"margin-bottom": "20px", "textAlign": "left", "width": row_width}))
    return inputs

@app.callback(
    Output('numerical-feature-multiselect', 'options'),
    [Input('output-table', 'data')]
)
def update_numerical_feature_multiselect_output(table_data):
    """Refresh the numerical-feature multiselect from the result table.

    Delegates to ``update_numerical_feature_multiselect`` and returns its
    result unchanged.
    """
    multiselect_options = update_numerical_feature_multiselect(table_data)
    return multiselect_options

@app.callback(
    Output('numerical-feature-multiselect', 'value'),
    Input('numerical-feature-multiselect', 'value'),
    State('numerical-feature-multiselect', 'options')
)
def handle_select_all_numerical_feature_multiselect(selected_values, available_options):
    """Expand "Select All" into every real option of the numerical multiselect.

    Returns an empty list when either argument is ``None``; otherwise the
    selection is passed through unless it contains ``"select_all"``, in which
    case all option values except ``"select_all"`` are returned.
    """
    if selected_values is None or available_options is None:
        return []

    if "select_all" not in selected_values:
        return selected_values

    expanded = []
    for entry in available_options:
        if entry['value'] != 'select_all':
            expanded.append(entry['value'])
    return expanded

@app.callback(
    Output('cluster-histogram', 'figure'),
    [Input('output-table', 'data'),
     Input('remap-cluster-names-store', 'data')]
)
def update_cluster_pie_chart_output(table_data, remap_data):
    """Draw the cluster distribution chart for the current results.

    Args:
        table_data: Rows of the segmentation result table.
        remap_data: Content of the remap store (renamed clusters, if any).

    Returns:
        The figure produced by ``update_cluster_pie_chart`` for the data
        returned by ``process_table_data_with_remap``, or the shared
        "No data selected yet." placeholder figure when that data is missing
        or empty.
    """
    df = process_table_data_with_remap(table_data, remap_data)

    if df is None or df.empty:
        # Use the module's shared placeholder instead of a duplicated literal.
        return create_empty_figure()

    return update_cluster_pie_chart(df.to_dict('records'))

def create_empty_figure(message="No data selected yet."):
    """Return a blank figure dict that shows *message* as its only content.

    The figure has no traces, both axes hidden, and a single arrow-less
    annotation positioned in paper coordinates with font size 12.
    """
    annotation = {
        'text': message,
        'xref': "paper",
        'yref': "paper",
        'showarrow': False,
        'font': {'size': 12},
    }
    layout = {
        'xaxis': {'visible': False},
        'yaxis': {'visible': False},
        'annotations': [annotation],
    }
    return {'data': [], 'layout': layout}

@app.callback(
    Output('importance-histogram', 'figure'),
    [Input('output-table', 'data'),
     Input('kmeans-pipeline-store', 'data'),
     Input('rulebased-weights', 'data'),
     Input('clustering-method-store', 'data'),
     Input('feature-selection', 'value')]
)
def update_importance_histogram(table_data, kmeans_pipeline_serialized, rb_weights, method, selected_features):
    """Draw the feature-importance histogram for the current segmentation.

    Args:
        table_data: Rows of the segmentation result table.
        kmeans_pipeline_serialized: Base64 text of the pickled k-means
            pipeline (used when ``method == "kmeans"``).
        rb_weights: Feature-to-weight mapping (used when
            ``method == "rule_based"``).
        method: ``"kmeans"`` or ``"rule_based"``.
        selected_features: Features chosen for the segmentation.

    Returns:
        A histogram of the ten highest importances, or a placeholder figure
        when inputs are missing, the method has no matching payload, or the
        importance computation fails (the failure is logged).
    """
    import logging  # local import so this block does not depend on the file header

    logger = logging.getLogger(__name__)

    if not table_data or not method or not selected_features:
        return create_empty_figure()

    df = pd.DataFrame(table_data)

    if method == "kmeans" and kmeans_pipeline_serialized:
        try:
            # SECURITY NOTE(review): this payload arrives as a callback input,
            # i.e. with the browser's request, and unpickling can execute
            # arbitrary code. Confirm the store cannot be tampered with, or
            # keep the pipeline server-side and pass only a reference.
            kmeans_pipeline = pickle.loads(base64.b64decode(kmeans_pipeline_serialized.encode('utf-8')))
            preprocessor = kmeans_pipeline.named_steps.get('preprocessor', None)
            feature_importance_df = feature_importance_rfclassifier(df, selected_features, preprocessor, 'cluster')
        except Exception:
            logger.exception("Computing k-means feature importance failed")
            return create_empty_figure("Error processing importance data")

    elif method == "rule_based" and rb_weights:
        try:
            # For the rule-based method the user-supplied weights are shown
            # as the importances.
            feature_importance_df = pd.DataFrame(
                list(rb_weights.items()), columns=['Feature', 'Importance']
            ).sort_values(by='Importance', ascending=False)
        except Exception:
            logger.exception("Building the rule-based importance table failed")
            return create_empty_figure("Error processing importance data")
    else:
        return create_empty_figure()

    # Only the ten most important features are charted.
    feature_importance_df = feature_importance_df.sort_values('Importance', ascending=False).head(10)
    return histogram_figure(feature_importance_df, "Importance", "Feature",
                            {'Importance': 'Importance Score', 'Feature': 'Features'})

@app.callback(
    Output("feature-histograms", "children"),
    [Input("feature-dropdown", "value"),
     Input('remap-cluster-names-store', 'data')],
    [State("output-table", "data")]
)
def update_histograms_output(selected_feature, remap_data, table_data):
    """Render the per-cluster chart for the feature picked in the dropdown.

    Args:
        selected_feature: Feature chosen in the feature dropdown.
        remap_data: Content of the remap store (renamed clusters, if any).
        table_data: Rows of the segmentation result table.

    Returns:
        The result of ``update_mixed_graph`` for the processed data, or an
        empty list when no feature is selected or no data is available.
    """
    # The output is a ``children`` property, so "nothing to show" has to be
    # an empty list; a bare dict is not a valid children value.
    if not selected_feature:
        return []

    df = process_table_data_with_remap(table_data, remap_data)
    if df is None:
        return []

    return update_mixed_graph(selected_feature, df.to_dict('records'))

@app.callback(
    Output('mean-heatmap', 'children'),
    [Input('numerical-feature-multiselect', 'value'),
     Input('remap-cluster-names-store', 'data')],
    [State('output-table', 'data')]
)
def update_heatmap(selected_numerical_features, remap_data, table_data):
    """Render the per-cluster average heatmap for the selected features.

    Args:
        selected_numerical_features: Features chosen in the numerical
            multiselect.
        remap_data: Content of the remap store (renamed clusters, if any).
        table_data: Rows of the segmentation result table.

    Returns:
        The result of ``create_average_heatmap`` for the processed data, or
        an empty list when no feature is selected or no data is available.
    """
    # The output is a ``children`` property, so "nothing to show" has to be
    # an empty list; a bare dict is not a valid children value.
    if not selected_numerical_features:
        return []

    df = process_table_data_with_remap(table_data, remap_data)
    if df is None:
        return []

    return create_average_heatmap(df, selected_numerical_features, cluster_column_name='cluster')

@app.callback(
    [Output("remap-cluster-names-store", "data"),
     Output("remap-message", "children")],
    [Input("remap-button", "n_clicks")],
    [State("output-table", "data"),
     State({'type': 'remap-input', 'index': ALL}, 'value')]
)
def remap_cluster_names(n_clicks, output_data, new_cluster_names):
    """Rename the clusters in the result table using the remap inputs.

    Args:
        n_clicks: Click counter of the remap button (``None`` or 0 before
            the first click).
        output_data: Rows of the segmentation result table.
        new_cluster_names: Values of the remap text inputs, paired by
            position with the unique cluster values in order of appearance.

    Returns:
        ``(store, message)``. ``store`` always carries ``data``,
        ``mapping_dict`` and ``message`` keys, plus ``error`` when the remap
        failed. ``mapping_dict`` is empty unless a remap was applied. A blank
        or missing new name keeps the cluster's current name.
    """
    # No click yet or nothing to remap: pass the table through untouched.
    if not n_clicks or not output_data:
        return {'data': output_data, 'mapping_dict': {}, 'message': ''}, ""

    def failure(reason):
        """Build the error response and show *reason* to the user."""
        store = {'error': reason, 'data': output_data, 'mapping_dict': {}, 'message': ''}
        return store, html.P(f"Error remapping segments: {reason}", style=div_message_style)

    try:
        df = pd.DataFrame(output_data)
        cluster_columns = [col for col in df.columns if 'cluster' in col]
        if not cluster_columns:
            return failure("No cluster column found in the provided data.")
        cluster_column = cluster_columns[0]

        original_names = [str(cluster) for cluster in df[cluster_column].unique()]
        supplied_names = list(new_cluster_names or [])

        mapping_dict = {}
        for position, original in enumerate(original_names):
            candidate = supplied_names[position] if position < len(supplied_names) else None
            # A cleared input would otherwise turn the whole cluster into
            # missing values, so fall back to the current name.
            is_blank = candidate is None or not str(candidate).strip()
            mapping_dict[original] = original if is_blank else candidate

        df[cluster_column] = df[cluster_column].astype(str).map(mapping_dict)

        return (
            {'data': df.to_dict('records'), 'mapping_dict': mapping_dict, 'message': ''},
            html.P("Successful segments remap.", style=div_message_style),
        )
    except Exception as e:
        return failure(str(e))

@app.callback(
    [Output("download-dataframe-csv", "data"),
     Output("error-message", "children")],
    Input("export-icon", "n_clicks"),
    [State('session-name', 'value'),
     State("remap-cluster-names-store", "data")],
    prevent_initial_call=True,
)
def export_data_as_csv(n_clicks, session_name, remap_data):
    """Download the remapped results as ``<session_name>_results.csv``.

    Args:
        n_clicks: Click counter of the export icon.
        session_name: Name typed by the user; used as the file-name prefix.
        remap_data: Content of the remap store; its ``data`` entry holds the
            rows to export.

    Returns:
        ``(download, message)``. ``download`` is ``None`` whenever nothing is
        exported; ``message`` explains a missing session name or missing
        data, and is empty otherwise.
    """
    if not session_name:
        return None, html.P("Please enter a session name.", style=div_message_style)

    if not n_clicks or remap_data is None:
        return None, ""

    remapped_data = remap_data.get('data')
    # An empty store would otherwise produce an empty CSV download.
    if not remapped_data:
        return None, html.P("No data available to export.", style=div_message_style)

    df = pd.DataFrame(remapped_data)
    return dcc.send_data_frame(df.to_csv, session_name + "_results.csv"), ""

@app.callback(
    Output("description-dictionary", "data"),
    Input("metadata-description", "value")
)
def store_description(description):
    """Wrap the description text in a dict for the description store.

    Returns ``{"description": description}`` for a non-empty description and
    an empty dict otherwise.
    """
    return {"description": description} if description else {}


@app.callback(
    Output('save-output', 'children'),
    [Input('save-icon', 'n_clicks')],
    [State('remap-cluster-names-store', 'data'),
     State('session-name', 'value'),
     State('select-dataset', 'value'),
     State('clustering-method-store', 'data'),
     State('feature-selection', 'value'),
     State('num-clusters', 'value'),
     State('filter-dictionary-store', 'data'),
     State('kmeans-pipeline-store', 'data'),
     State('rulebased-bin-bounds', 'data'),
     State('rulebased-weights', 'data'),
     State('description-dictionary', 'data')]
)
def save_data(save_clicks, remap_data, session_name, original_dataset, method,
              selected_features, num_clusters, filter_dictionary,
              kmeans_pipeline_serialized, rb_bounds, rb_weights, description_dict):
    """Validate the session and persist the segmentation through Dataiku.

    Runs only when triggered by the save icon. The checks are applied in
    order and the first failure is returned as an error paragraph:
    remapped cluster data present, mapping applied, session name given and
    made of letters/digits/underscores, session name not already in
    ``metadata_df``, description given, and method-specific model data
    present. On success a metadata record is created, the data is saved with
    ``save_data_to_dataiku``, both dependent webapps are restarted, and the
    component returned by the save helper is shown.

    Returns:
        ``""`` when not triggered by the save icon, otherwise an error
        paragraph or the save helper's result component.
    """
    def failure(text):
        """Wrap *text* in the standard message paragraph."""
        return html.P(text, style=div_message_style)

    triggered = dash.callback_context.triggered
    if not triggered:
        return ""
    if triggered[0]['prop_id'].split('.')[0] != 'save-icon':
        return ""

    # --- Input validation ---
    if remap_data is None:
        return failure("Error: No cluster data available. Please run segmentation first.")

    if not remap_data.get('mapping_dict'):
        return failure("Error: You must remap the cluster names before saving the data.")

    if not session_name:
        return failure("Error: Please add a session name.")

    if any(not (ch.isalnum() or ch == '_') for ch in session_name):
        return failure("Error: Session name can only contain letters, numbers, and underscores.")

    # metadata_df is expected at module level; a NameError means it was
    # never loaded.
    try:
        name_taken = 'session_name' in metadata_df.columns and session_name in metadata_df["session_name"].values
    except NameError:
        return failure("Error: Metadata database not available. Please contact administrator.")
    if name_taken:
        return failure("Error: Session name must be unique. This session name already exists.")

    description = description_dict.get("description", "") if description_dict else ""
    if not description:
        return failure("Error: You must provide a description before saving the data.")

    if method == "kmeans" and not kmeans_pipeline_serialized:
        return failure("Error: KMeans model data is missing. Please re-run the segmentation.")

    if method == "rule_based" and not (rb_bounds and rb_weights):
        return failure("Error: Rule-based parameters are missing. Please re-run the segmentation.")

    # --- Save ---
    try:
        remapped_data = remap_data['data']
        mapping_dict = remap_data['mapping_dict']

        if not remapped_data:
            return failure("Error: No data to save. The segmentation result appears to be empty.")

        try:
            # New sessions are recorded as version 0 with status 'active'.
            metadata_record = create_metadata_record(
                session_name, 0, 'active', original_dataset, method,
                description, filter_dictionary, selected_features, mapping_dict
            )
        except Exception as record_error:
            return failure(f"Error creating metadata record: {str(record_error)}")

        outcome = save_data_to_dataiku(
            session_name, remapped_data, metadata_record, method,
            kmeans_pipeline_serialized, rb_bounds, rb_weights
        )

        # Restart the two other webapps after the save.
        restart_app_silently(seg_manager_webapp_id)
        restart_app_silently(multi_seg_explorer_webapp_id)
        return outcome

    except Exception as e:
        return failure(f"Error preparing data for save: {str(e)}")

@app.callback(
    Output('refresh-trigger', 'children'),
    Input('clear-icon', 'n_clicks'),
    prevent_initial_call=True
)
def trigger_clear(n_clicks):
    """Emit the "REFRESH" marker when the clear icon has been clicked.

    Leaves the output untouched when ``n_clicks`` is falsy.
    """
    return "REFRESH" if n_clicks else dash.no_update

# Browser-side callback listening on the 'refresh-trigger' children, which
# trigger_clear sets to "REFRESH". When that marker arrives, the page is
# reloaded after a 500 ms delay. The function always returns an empty string,
# which is written to the 'refresh-trigger' style output.
app.clientside_callback(
    """
    function(trigger) {
        if (trigger === "REFRESH") {
            // Wait 0.5 seconds to show the success message before refreshing
            setTimeout(function() {
                window.location.reload();
            }, 500);
        }
        return '';
    }
    """,
    Output('refresh-trigger', 'style'),
    Input('refresh-trigger', 'children'),
    prevent_initial_call=True
)