# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import io
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import json

# Handle on the Dataiku folder that later receives the rendered heatmap image.
folder = dataiku.Folder("jGYaG5fn")

# Load the "preferences" dataset as a DataFrame; the code below reads its
# question_id, current_approach, new_approach and result columns.
preferences_dataset = dataiku.Dataset("preferences")
df = preferences_dataset.get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def topological_sort(d):
    """
    Order approaches so that every pairwise comparison in ``d`` is respected.

    Implements Kahn's algorithm: approaches with no unmet predecessor are
    emitted first-in, first-out, which keeps the output deterministic for a
    given insertion order of ``d``.

    Args:
    d (dict): Nested dictionary of pairwise outcomes. d[a][b] == 0 places a
        before b in the result and d[a][b] == 1 places b before a. Any other
        value adds no ordering constraint, but both approaches still appear
        in the result.

    Returns:
    list: Every approach found in d (as an outer or an inner key), ordered
        consistently with all constraints.

    Raises:
    ValueError: If the constraints contain a cycle.
    """
    # Local import so the function does not depend on a new file-level import;
    # deque gives O(1) pops from the left where list.pop(0) is O(n).
    from collections import deque

    # successors[a] lists the approaches that must come after a;
    # in_degree[a] counts how many approaches must come before a.
    successors = defaultdict(list)
    in_degree = {}
    for approach0, outcomes in d.items():
        in_degree.setdefault(approach0, 0)
        for approach1, outcome in outcomes.items():
            in_degree.setdefault(approach1, 0)
            if outcome == 0:
                successors[approach0].append(approach1)
                in_degree[approach1] += 1
            elif outcome == 1:
                successors[approach1].append(approach0)
                in_degree[approach0] += 1

    # Start from the approaches nothing has to precede, then release each
    # successor once all of its predecessors have been emitted.
    queue = deque(node for node, degree in in_degree.items() if degree == 0)
    sorted_approaches = []
    while queue:
        approach = queue.popleft()
        sorted_approaches.append(approach)
        for neighbor in successors[approach]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)

    # Approaches on a cycle never reach in-degree 0 and are therefore missing.
    if len(sorted_approaches) != len(in_degree):
        raise ValueError("The input dictionary contains a cycle, so a valid ordering cannot be determined.")

    return sorted_approaches

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# comparisons[q][new][current] holds the `result` value recorded for the row
# that compares approach `new` with approach `current` on question q
# (result == 1 is read as "`new` was judged better than `current`").
comparisons = {question: {} for question in set(df["question_id"])}
for row in df.index:
    question = df.at[row, "question_id"]
    current = df.at[row, "current_approach"]
    new = df.at[row, "new_approach"]
    result = df.at[row, "result"]
    # Create the inner dict on first sight of `new`, then record the outcome.
    comparisons[question].setdefault(new, {})[current] = result

# ranking[q] is the ordering of the approaches that topological_sort derives
# from the comparisons recorded for question q
ranking = {
    question: topological_sort(outcomes)
    for question, outcomes in comparisons.items()
}

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Use the ranking of the first question as the reference list of approaches.
for first_ranking in ranking.values():
    approaches = first_ranking
    break

# wins[a][b] counts the questions whose ranking places a after b.
wins = {name: defaultdict(int) for name in approaches}

# NOTE(review): every ranking is indexed with the length of the reference
# list, so all questions are assumed to rank the same approaches -- confirm.
n_approaches = len(approaches)
for order in ranking.values():
    for earlier_pos in range(n_approaches - 1):
        for later_pos in range(earlier_pos + 1, n_approaches):
            wins[order[later_pos]][order[earlier_pos]] += 1

# Turn the raw counts into rates over the number of questions.
n_questions = len(ranking)
for row in wins.values():
    for opponent in row:
        row[opponent] /= n_questions

# Approaches in ascending order of their summed win rates.
sorted_approaches = sorted(approaches, key=lambda name: sum(wins[name].values()))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Square matrix of pairwise win rates whose rows and columns follow
# sorted_approaches: cell [i, j] is wins[row approach][column approach].
# The matrix starts out as NaN so the diagonal (an approach against itself)
# stays undefined without relying on None being coerced to NaN.
n_approaches = len(approaches)
win_matrix = np.full((n_approaches, n_approaches), np.nan)

for i, row_approach in enumerate(sorted_approaches):
    for j, col_approach in enumerate(sorted_approaches):
        if i != j:
            win_matrix[i, j] = wins[row_approach][col_approach]

_ = sns.heatmap(
    win_matrix, annot=True, xticklabels=sorted_approaches, yticklabels=sorted_approaches
).set(title="Win Rate")

# Render the current figure into memory and upload the bytes to the folder.
# The format is passed explicitly so the content always matches the ".png"
# name, whatever rcParams["savefig.format"] is set to.
with io.BytesIO() as buf:
    plt.savefig(buf, format="png")
    folder.upload_data("heatmap.png", buf.getvalue())

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# One output row per question: its id and its ranking serialized as JSON.
ranking_rows = [
    (question_id, json.dumps(order)) for question_id, order in ranking.items()
]
ranking_df = pd.DataFrame(ranking_rows, columns=["question_id", "ranking"])

# Write the table to the "ranking" dataset, taking the schema from the DataFrame.
dataiku.Dataset("ranking").write_with_schema(ranking_df)