Skip to content

Commit

Permalink
Update plotting by using Plotly
Browse files: browse the repository at this point in the history
  • Loading branch information
Siwei Li authored and menamerai committed Feb 21, 2024
1 parent 94c6b49 commit 2feaeef
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 74 deletions.
75 changes: 75 additions & 0 deletions notebooks/per_token_plot.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6825ba6543ba45438b18cf0949576020",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Dropdown(description='Token Category:', options=('nouns', 'verbs', 'prepositions', 'adjectives'…"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from collections import defaultdict\n",
"import math\n",
"import random\n",
"import numpy as np\n",
"\n",
"from delphi.eval.vis_per_token_model import visualize_per_token_category\n",
"\n",
"\n",
"random.seed(0)\n",
"\n",
"# generate mock data\n",
"model_names = ['llama2-100k', 'llama2-200k', 'llama2-1m', 'llama2-10m']\n",
"categories = ['nouns', 'verbs', 'prepositions', 'adjectives']\n",
"entries = [200, 100, 150, 300]\n",
"performance_data = defaultdict()\n",
"for i, model in enumerate(model_names):\n",
" performance_data[model] = defaultdict()\n",
" for cat in categories:\n",
" x = [math.log2(random.random()) for _ in range(entries[i])]\n",
" means = np.mean(x)\n",
" err_low = means - np.percentile(x, 25)\n",
" err_hi = np.percentile(x, 75) - means\n",
" performance_data[model][cat] = (-means, err_low, err_hi)\n",
"\n",
"\n",
"visualize_per_token_category(performance_data)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
47 changes: 0 additions & 47 deletions src/delphi/dataset/mock_per_token_performance.py

This file was deleted.

67 changes: 40 additions & 27 deletions src/delphi/eval/vis_per_token_model.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,50 @@
import ipywidgets as widgets
import matplotlib.pyplot as plt
import ipywidgets
import numpy as np
from ipywidgets import interact
import plotly.graph_objects as go
from beartype.typing import Dict


def visualize_per_token_category(input):
def visualize_per_token_category(input: Dict[str, Dict[str, tuple]]) -> ipywidgets.VBox:
model_names = list(input.keys())
categories = list(list(input.values())[0].keys())
categories = list(input[model_names[0]].keys())
category = categories[0]

def _f(category):
def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
x = np.array([input[name][category] for name in model_names]).T
means = np.mean(x, axis=0)
median = np.median(x, axis=0)
q1 = np.quantile(x, 0.25, axis=0)
q3 = np.quantile(x, 0.75, axis=0)

ax = plt.gca()
ax.set_ylim([-5, 5]) # TODO

plt.plot(model_names, means)
plt.errorbar(model_names, median, yerr=[median - q1, q3 - median], fmt="o")

interact(
_f,
category=widgets.Dropdown(
options=categories,
placeholder="",
description="Token Category:",
disabled=False,
means, err_lo, err_hi = x[0], x[1], x[2]
return means, err_lo, err_hi

means, err_low, err_hi = get_plot_values(category)
g = go.FigureWidget(
data=go.Scatter(
x=model_names,
y=means,
error_y=dict(
type="data",
symmetric=False,
array=err_hi,
arrayminus=err_low,
color="purple",
),
),
layout=go.Layout(yaxis=dict(title="Loss")),
)

selected_category = ipywidgets.Dropdown(
options=categories,
placeholder="",
description="Token Category:",
disabled=False,
)

def response(change):
if selected_category.value:
means, err_lo, err_hi = get_plot_values(selected_category.value)
with g.batch_update():
g.data[0].y = means
g.data[0].error_y["array"] = err_hi
g.data[0].error_y["arrayminus"] = err_lo

selected_category.observe(response, names="value")

# Usage:
# from dataset.mock_per_token_performance import performance_data
# visualize_per_token_category(performance_data)
return ipywidgets.VBox([selected_category, g])

0 comments on commit 2feaeef

Please sign in to comment.