from graphviz import Digraph
import numpy as np
import pandas as pd
from webapp.formatting import human_format, formatted_human_readable
import bisect

# Palette for the 'frequency' view (blue). The scale is ordered from the
# lightest shade (index 0) to the darkest shade (last index).
border_color_frequency = '#48A8E9'
color_scale_frequency = [
    '#F1F8FD',
    '#E2F1FA',
    '#D4E9F8',
    '#C5E2F6',
    '#B7DBF3',
    '#A8D4F1',
    '#9ACCEF',
    '#8BC5EC',
    '#7DBEEA',
]

# Palette for the 'time' view (purple), same light-to-dark ordering.
border_color_time = '#8C62B5'
color_scale_time = [
    '#F5F0F9',
    '#EBE2F3',
    '#E1D3EC',
    '#D7C4E6',
    '#CCB6E0',
    '#C2A7DA',
    '#B898D3',
    '#AE8ACD',
    '#A47BC7',
]

def create_graph(workflow_mined, start_end, activity_count, transition='frequency'):
    """
    Create a Graphviz graph from the mined workflow data.

    Every step listed in ``activity_count`` becomes a box-shaped node and every
    row of ``workflow_mined`` becomes an edge. Start and end steps are drawn as
    plain white boxes with a black border; all other steps are filled with a
    shade from the active color scale, chosen by the step's count relative to
    the largest count. Edges are always black; their pen width (1-5) reflects
    where the edge weight falls between the smallest and largest weight.

    Args:
        workflow_mined (pd.DataFrame): DataFrame containing the mined workflow data with columns 'source', 'target', and 'weight'.
        start_end (pd.DataFrame): DataFrame containing the start and end steps with columns 'start_end' and 'step'.
        activity_count (pd.DataFrame): DataFrame containing the activity counts with columns 'step' and 'frequency'.
            Every step must also appear as a 'source' or 'target' in ``workflow_mined``.
        transition (str, optional): Type of transition to visualize, either 'frequency' or 'time'. Defaults to 'frequency'.
            Any other value uses the 'time' node styling, but no edges are drawn.

    Returns:
        tuple: A tuple containing:
            - dot.source (str): The source code of the generated Graphviz dot file.
            - unique_nodes (np.ndarray): Array of unique nodes in the graph.
            - max_count (int or str): The maximum frequency count or formatted human-readable maximum count if transition is "time".

    Raises:
        KeyError: If a step of ``activity_count`` does not occur in ``workflow_mined``.
    """
    if transition == 'frequency':
        color_scale = color_scale_frequency
        border_color = border_color_frequency
    else:
        color_scale = color_scale_time
        border_color = border_color_time

    # Start and end steps share exactly the same plain styling, so one set suffices.
    terminal_mask = start_end['start_end'].isin(['start', 'end'])
    terminal_steps = set(start_end[terminal_mask]['step'])

    max_count = activity_count['frequency'].max()
    if pd.isna(max_count):
        # Empty input: max() yields NaN, which would break the scaling below.
        max_count = 1
    max_count = max(max_count, 1)  # avoid 0 division

    # Node ids in the dot source are the positions of the steps in unique_nodes.
    unique_nodes = pd.concat([workflow_mined['source'], workflow_mined['target']]).unique()
    unique_nodes_dict = {node: str(position) for position, node in enumerate(unique_nodes)}
    dot = Digraph()

    # Build the step -> count lookup once (first occurrence wins) instead of
    # filtering the DataFrame for every node.
    first_rows = activity_count.drop_duplicates(subset='step', keep='first')
    count_by_step = dict(zip(first_rows['step'], first_rows['frequency']))
    last_color_index = len(color_scale) - 1

    for node in activity_count.sort_values('frequency')['step']:
        if node in terminal_steps:
            bordercolor = 'black'
            color = 'white'
            style = 'solid'
        else:
            count = count_by_step[node]
            # Clamp so that counts below 1 cannot produce a negative index
            # that silently wraps around to the darkest shade.
            raw_index = int((count - 1) * len(color_scale) / max_count)
            color = color_scale[min(max(raw_index, 0), last_color_index)]
            bordercolor = border_color
            style = 'filled'

        if transition == 'frequency' or node in ('START', 'END', 'SINK'):
            name = node
        else:
            # Look the count up for this very node: start/end steps skip the
            # branch above, so a shared local would be stale or undefined here.
            name = node + ' [' + formatted_human_readable(int(count_by_step[node])) + ']'

        dot.node(unique_nodes_dict[node], name, color=bordercolor, shape='box', style=style,
                 fillcolor=color, fontname='Helvetica Neue', fontcolor='black', penwidth='2',
                 tooltip=' ')

    # Five evenly spaced thresholds starting at the minimum weight; bisect
    # then maps any weight >= minimum to a pen width between 1 and 5.
    min_weight = workflow_mined['weight'].min()
    max_weight = workflow_mined['weight'].max()
    arrow_width_thresholds = [min_weight + i / 5 * (max_weight - min_weight) for i in range(5)]

    for _, edge in workflow_mined.iterrows():
        weight = edge['weight']
        if transition == 'frequency':
            label = '  ' + human_format(int(weight))
        elif transition == 'time':
            # Transitions out of START or into END carry no duration label.
            if edge['source'] == 'START' or edge['target'] == 'END':
                label = ""
            else:
                label = '  ' + formatted_human_readable(weight)
        else:
            # Unknown transition types never produced edges; keep it that way.
            continue
        dot.edge(unique_nodes_dict[edge['source']], unique_nodes_dict[edge['target']], color='black',
                 label=label, fontname='Helvetica Neue',
                 penwidth=str(bisect.bisect(arrow_width_thresholds, weight)), tooltip=' ')

    if transition == 'time':
        max_count = formatted_human_readable(int(max_count))

    return dot.source, unique_nodes, max_count