Skip to content

Commit

Permalink
Add visualization function for per token model comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
Siwei Li committed Feb 15, 2024
1 parent 714e90b commit 10e557b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
File renamed without changes.
37 changes: 37 additions & 0 deletions src/delphi/eval/vis_per_token_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact


def visualize_per_token_category(input, *, ylim=(-5, 5)):
    """Show an interactive, per-category comparison plot across models.

    Creates a dropdown of categories through ``ipywidgets.interact``. For the
    selected category it draws, for every model, the mean of that model's
    values as a line and the median as markers whose error bars reach from
    the first to the third quartile.

    Args:
        input: Mapping of ``{model_name: {category: values}}``. ``values`` is
            a sequence of numbers (presumably per-token performance values;
            confirm against the caller). The sequences of all models for one
            category are stacked into a single array, so they are expected to
            have equal length. The selectable categories are taken from the
            first model only. The parameter name shadows the builtin but is
            kept so existing callers keep working.
        ylim: ``(bottom, top)`` limits for the y axis. Pass ``None`` to let
            matplotlib autoscale. Defaults to ``(-5, 5)``, the range that was
            previously hard-coded.

    Raises:
        ValueError: If ``input`` contains no models.
    """
    model_names = list(input.keys())
    if not model_names:
        # Fail early with a clear message instead of an IndexError when
        # looking up the first model's categories below.
        raise ValueError("input must contain at least one model")

    categories = list(input[model_names[0]].keys())

    def _f(category):
        # Transpose so that each column holds one model's values; all
        # statistics below are then taken down the columns (axis=0).
        x = np.array([input[name][category] for name in model_names]).T
        means = np.mean(x, axis=0)
        median = np.median(x, axis=0)
        q1 = np.quantile(x, 0.25, axis=0)
        q3 = np.quantile(x, 0.75, axis=0)

        ax = plt.gca()
        if ylim is not None:
            ax.set_ylim(ylim)

        plt.plot(model_names, means)
        # Asymmetric error bars: distance from the median down to Q1 and up
        # to Q3.
        plt.errorbar(model_names, median, yerr=[median - q1, q3 - median], fmt="o")

    interact(
        _f,
        category=widgets.Dropdown(
            options=categories,
            placeholder="",
            description="Token Category:",
            disabled=False,
        ),
    )


# Usage:
# from dataset.mock_per_token_performance import performance_data
# visualize_per_token_category(performance_data)

0 comments on commit 10e557b

Please sign in to comment.