End-to-end evaluation demo #32

Merged: 38 commits (from 13-first-end2end-eval-demo into main), Feb 27, 2024

Commits (38)
67a6a11
First version of visualization using dummy data
jaidhyani Feb 14, 2024
fc781a8
Add plotly library to requirements.txt
jaidhyani Feb 14, 2024
c9a95b6
wip
jaidhyani Feb 15, 2024
ed105bc
Add token_map.pkl file
jaidhyani Feb 16, 2024
bb9c613
Create data dir if it doesn't exist in map_tokens.py
jaidhyani Feb 16, 2024
7a0348f
script + output aggregating token/model logprobs
jaidhyani Feb 16, 2024
3c31782
README for data files
jaidhyani Feb 16, 2024
8692a3c
token groups hack
jaidhyani Feb 16, 2024
2c27923
useful constants
jaidhyani Feb 16, 2024
fb1bb2f
update requirements for spacy support
jaidhyani Feb 16, 2024
1d4ed64
Add GenericPreTrainedTransformer Union type for convenience
jaidhyani Feb 16, 2024
7074890
full token_model_stats.csv
jaidhyani Feb 16, 2024
f30d91a
wip
jaidhyani Feb 19, 2024
27ad021
wip
jaidhyani Feb 19, 2024
74514fa
experiments with different ways to aggregate logprobs
jaidhyani Feb 23, 2024
1992c52
checkpoint before cleanup
jaidhyani Feb 23, 2024
787e3bf
post-cleanup checkpoint
jaidhyani Feb 23, 2024
08aee12
add hovertext
jaidhyani Feb 23, 2024
74f68a3
add model_group_stats pkl (tmp)
jaidhyani Feb 23, 2024
f047c2f
Delete gather_token_model_stats.py script
jaidhyani Feb 23, 2024
f70b60c
delete more unused scripts
jaidhyani Feb 23, 2024
bcdc49b
Remove hack_token_label.py
jaidhyani Feb 23, 2024
7683ee9
Remove unused constants in delphi/eval/constants.py
jaidhyani Feb 23, 2024
d304942
Add calc_model_group_stats function to calculate useful stats for vis…
jaidhyani Feb 23, 2024
d2431c2
Add load_logprob_dataset and load_logprob_datasets functions
jaidhyani Feb 23, 2024
88001d2
Cleanup on end2end
jaidhyani Feb 23, 2024
1fc1d1d
Delete unused data files
jaidhyani Feb 23, 2024
5b0346a
update data README
jaidhyani Feb 23, 2024
e9281e0
TODO: Move README.md to delphi/static
jaidhyani Feb 23, 2024
f17c77b
Refactor calc_model_group_stats function
jaidhyani Feb 23, 2024
10dc4a0
remove unused GenericPretrainedTransformer
jaidhyani Feb 23, 2024
b36eeaf
formatting issues
jaidhyani Feb 23, 2024
440ea0e
formatting
jaidhyani Feb 23, 2024
ff64510
formatting fixes
jaidhyani Feb 23, 2024
02138df
use static assets dir
jaidhyani Feb 27, 2024
007d96d
CR
jaidhyani Feb 27, 2024
81795cf
Merge branch 'main' into 13-first-end2end-eval-demo
jaidhyani Feb 27, 2024
6569b71
Merge branch 'main' into 13-first-end2end-eval-demo
jaidhyani Feb 27, 2024
3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -8,7 +8,8 @@
 },
 "python.analysis.typeCheckingMode": "basic",
 "isort.args": [
-    "--profile black"
+    "--profile",
+    "black"
 ],
 "black-formatter.importStrategy": "fromEnvironment",
 }
10 changes: 10 additions & 0 deletions data/README.md
Contributor: move?

jaidhyani (Collaborator, Author): Yup, now that the static dir PR is in.
@@ -0,0 +1,10 @@
# TODO: move this to delphi/static
# Static Data Files


## `token_map.pkl`
Pickle file mapping every token to all of its locations in the corpus: a dict from token to a list of (doc, pos) pairs.

## `model_group_stats.pkl`
Useful statistics for visualizing (model, token group) pairs: a dict from (model, token group) to a dict of (str, float) stats,
e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.9, "max": -0.1, "25th": -0.7, "75th": -0.3}, ...}
Binary file added: data/model_group_stats.pkl

Collaborator: Should be removed and added to static.

Binary file added: data/token_map.pkl

Collaborator: Should be removed and added to static.
131 changes: 131 additions & 0 deletions notebooks/end2end_demo.ipynb
@@ -0,0 +1,131 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from typing import cast\n",
"import pickle\n",
"from collections import defaultdict\n",
"\n",
"from datasets import load_dataset, Dataset\n",
"\n",
"from delphi.eval import utils\n",
"from delphi.eval import constants\n",
"from delphi.eval.vis_per_token_model import visualize_per_token_category\n",
"\n",
"# from delphi.eval.calc_model_group_stats import calc_model_group_stats\n",
"from delphi.eval.token_labelling import TOKEN_LABELS"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# load data\n",
"tokenized_corpus_dataset = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))[\"validation\"]\n",
"\n",
"# TODO: convert to use static paths\n",
"# with open(\"../src/delphi/eval/labelled_token_ids_dict.pkl\", \"rb\") as f:\n",
"# token_groups = pickle.load(f)\n",
"# model_group_stats = calc_model_group_stats(\n",
"# tokenized_corpus_dataset, logprob_datasets, token_groups, token_groups[0].keys()\n",
"# )\n",
"with open(\"../data/model_group_stats.pkl\", \"rb\") as f:\n",
" model_group_stats = pickle.load(f)\n",
"\n",
"logprob_datasets = utils.load_logprob_datasets(\"validation\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualization"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9323c96bf6834ceb91f751f45512ea85",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Dropdown(description='Token Category:', options=('Capitalized', 'Is Determiner', 'Is Interjunct…"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"performance_data = defaultdict(dict)\n",
"for model in constants.LLAMA2_MODELS:\n",
" for token_group_desc in TOKEN_LABELS:\n",
" if (model, token_group_desc) not in model_group_stats:\n",
" continue\n",
" stats = model_group_stats[(model, token_group_desc)]\n",
" performance_data[model][token_group_desc] = (\n",
" -stats[\"median\"],\n",
" -stats[\"75th\"],\n",
" -stats[\"25th\"],\n",
" )\n",
"\n",
"visualize_per_token_category(\n",
" performance_data,\n",
" log_scale=True,\n",
" bg_color=\"LightGrey\",\n",
" line_color=\"Red\",\n",
" marker_color=\"Orange\",\n",
" bar_color=\"Green\",\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tinyevals",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
4 changes: 3 additions & 1 deletion requirements.txt
@@ -13,4 +13,6 @@ isort==5.13.2
 spacy==3.7.2
 chardet==5.2.0
 sentencepiece==0.1.99
-protobuf==4.25.2
+protobuf==4.25.2
+plotly==5.18.0
+spacy-transformers==1.3.4

Contributor: what is this?

jaidhyani (Collaborator, Author): Sorry, I meant to reply to this, but the VS Code inline thingy is kind of buggy and it ended up as a separate comment. Anyway: spacy-transformers is a dependency I needed to get the token labelling code to work; I think it just got missed by accident earlier (it's a hidden dependency that only shows up at runtime).
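A minimal illustration of that hidden runtime dependency, assuming the token labelling code loads a transformer-based spaCy pipeline (the pipeline name below is illustrative, not confirmed by this PR):

```python
import spacy

# spacy-transformers registers the "transformer" pipeline component with
# spaCy when it is installed. Without it, loading a transformer-based
# pipeline fails only at runtime, never at install time.
nlp = spacy.load("en_core_web_trf")  # illustrative pipeline name
doc = nlp("The quick brown fox jumps over the lazy dog.")
print([(tok.text, tok.pos_) for tok in doc])
```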
1 change: 1 addition & 0 deletions scripts/map_tokens.py
100644 → 100755
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3

 import argparse
+import os
 import pickle

 from delphi.constants import STATIC_ASSETS_DIR
54 changes: 54 additions & 0 deletions src/delphi/eval/calc_model_group_stats.py
@@ -0,0 +1,54 @@
import numpy as np


def calc_model_group_stats(
    tokenized_corpus_dataset: list,
    logprob_datasets: dict[str, list[list[float]]],
    token_groups: dict[int, dict[str, bool]],
    token_labels: list[str],
) -> dict[tuple[str, str], dict[str, float]]:
    """
    For each (model, token group) pair, calculate useful stats (for visualization).

    args:
    - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset)["validation"]
    - logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
    - token_groups: a dict mapping token ids to their group membership, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...}
    - token_labels: a list of token group descriptions, e.g. ["Is Noun", "Is Verb", ...]

    returns: a dict mapping (model, token group) pairs to a dict of stats,
    e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.9, "max": -0.1, "25th": -0.7, "75th": -0.3}, ...}

    Technically `token_labels` is redundant, as the labels are also keys in `token_groups`,
    but it's better to be explicit.

    stats calculated: mean, median, min, max, 25th percentile, 75th percentile
    """
model_group_stats = {}
for model in logprob_datasets:
group_logprobs = {}
print(f"Processing model {model}")
dataset = logprob_datasets[model]
for ix_doc_lp, document_lps in enumerate(dataset):
tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"]
for ix_token, token in enumerate(tokens):
if ix_token == 0: # skip the first token, which isn't predicted
continue
logprob = document_lps[ix_token]
for token_group_desc in token_labels:
if token_groups[token][token_group_desc]:
if token_group_desc not in group_logprobs:
group_logprobs[token_group_desc] = []
group_logprobs[token_group_desc].append(logprob)
for token_group_desc in token_labels:
if token_group_desc in group_logprobs:
model_group_stats[(model, token_group_desc)] = {
"mean": np.mean(group_logprobs[token_group_desc]),
"median": np.median(group_logprobs[token_group_desc]),
"min": np.min(group_logprobs[token_group_desc]),
"max": np.max(group_logprobs[token_group_desc]),
"25th": np.percentile(group_logprobs[token_group_desc], 25),
"75th": np.percentile(group_logprobs[token_group_desc], 75),
}
return model_group_stats
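Putting the pieces together, a sketch of the intended call, mirroring the commented-out cell in the notebook above (the labelled-token pickle path is the one that cell references, here taken relative to the repo root):

```python
import pickle
from typing import cast

from datasets import Dataset, load_dataset

from delphi.eval import constants, utils
from delphi.eval.calc_model_group_stats import calc_model_group_stats

tokenized_corpus_dataset = cast(
    Dataset, load_dataset(constants.tokenized_corpus_dataset)
)["validation"]
logprob_datasets = utils.load_logprob_datasets("validation")

# token id -> {"Is Noun": bool, ...}, as produced by the token labelling code
with open("src/delphi/eval/labelled_token_ids_dict.pkl", "rb") as f:
    token_groups = pickle.load(f)

model_group_stats = calc_model_group_stats(
    tokenized_corpus_dataset,
    logprob_datasets,
    token_groups,
    list(token_groups[0].keys()),
)
```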
14 changes: 14 additions & 0 deletions src/delphi/eval/constants.py
@@ -0,0 +1,14 @@
corpus_dataset = "delphi-suite/tinystories-v2-clean"
tokenized_corpus_dataset = "delphi-suite/tinystories-v2-clean-tokenized-v0"

LLAMA2_MODELS = [
"delphi-llama2-100k",
"delphi-llama2-200k",
"delphi-llama2-400k",
"delphi-llama2-800k",
"delphi-llama2-1.6m",
"delphi-llama2-3.2m",
"delphi-llama2-6.4m",
"delphi-llama2-12.8m",
"delphi-llama2-25.6m",
]
15 changes: 15 additions & 0 deletions src/delphi/eval/utils.py
@@ -6,6 +6,8 @@
 from jaxtyping import Float, Int
 from transformers import PreTrainedModel, PreTrainedTokenizerBase

+from delphi.eval import constants


def get_all_logprobs(
model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
@@ -87,3 +89,16 @@ def tokenize(
Int[torch.Tensor, "seq"],
tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
)


def load_logprob_dataset(model: str):
return cast(
Dataset, load_dataset(f"transcendingvictor/{model}-validation-logprobs")
)


def load_logprob_datasets(split: str = "validation") -> dict[str, list[list[float]]]:
    return {
        model: cast(dict, load_logprob_dataset(model)[split])["logprobs"]
        for model in constants.LLAMA2_MODELS
    }

Contributor: I'm worried how much we have to use casting with Dataset objects :/

jaidhyani (Collaborator, Author), Feb 27, 2024: Yeah, the datasets library clearly wasn't designed with type guarantees in mind. Building a library of functions that take care of the casting and other stuff we don't want to worry about in our day-to-day might be a good idea. A ticket would be something like "identify repetitive casts in the codebase and replace them with library functions".
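As a hedged sketch of the helper-library idea floated in that reply (hypothetical, not part of this PR; `load_split` is an invented name):

```python
from typing import cast

from datasets import Dataset, DatasetDict, load_dataset


def load_split(repo: str, split: str) -> Dataset:
    # load_dataset is typed as a broad union; for our repos it returns a
    # DatasetDict keyed by split, so centralize the cast in one place.
    dd = cast(DatasetDict, load_dataset(repo))
    return dd[split]


# Call sites then stay cast-free, e.g.:
# logprobs = load_split("some-org/some-logprobs-repo", "validation")["logprobs"]
```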
6 changes: 6 additions & 0 deletions src/delphi/eval/vis_per_token_model.py
@@ -10,6 +10,9 @@ def visualize_per_token_category(
categories = list(input[model_names[0]].keys())
category = categories[0]

def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]:
return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)]

def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
x = np.array([input[name][category] for name in model_names]).T
means, err_lo, err_hi = x[0], x[1], x[2]
@@ -32,6 +35,8 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
size=15,
line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2),
),
hovertext=get_hovertexts(means, err_lo, err_hi),
hoverinfo="text+x",
),
layout=go.Layout(
yaxis=dict(
@@ -55,6 +60,7 @@ def response(change):
g.data[0].y = means
g.data[0].error_y["array"] = err_hi
g.data[0].error_y["arrayminus"] = err_lo
g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi)

selected_category.observe(response, names="value")
