End-to-end evaluation demo #32
Changes from 34 commits
@@ -0,0 +1,10 @@
# TODO: move this to delphi/static
# Static Data Files

## `token_map.pkl`

pickle file: all locations of all tokens; a dict mapping each token to a list of (doc, pos) pairs.

## `model_group_stats.pkl`

useful statistics for data visualization of (model, tokengroup) pairs; a dict mapping (model, tokengroup) to a dict of (str, float), e.g.
{("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.9, "max": -0.1, "25th": -0.7, "75th": -0.3}, ...}
Review comment: Should be removed and added to static
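The two pickle formats described above can be illustrated with toy values (a sketch; the token ids and stats here are made up, only the shapes match the descriptions):

```python
import os
import pickle
import tempfile

# token_map.pkl: dict of token -> list of (doc, pos) pairs
token_map = {42: [(0, 3), (7, 15)], 101: [(2, 0)]}

# model_group_stats.pkl: dict of (model, tokengroup) -> dict of stat name -> float
model_group_stats = {
    ("llama2", "Is Noun"): {
        "mean": -0.5, "median": -0.4, "min": -0.9,
        "max": -0.1, "25th": -0.7, "75th": -0.3,
    },
}

# Round-trip through pickle the same way the real files are written and read.
path = os.path.join(tempfile.mkdtemp(), "model_group_stats.pkl")
with open(path, "wb") as f:
    pickle.dump(model_group_stats, f)
with open(path, "rb") as f:
    loaded = pickle.load(f)

print(loaded[("llama2", "Is Noun")]["median"])
```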
@@ -0,0 +1,131 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import cast\n",
    "import pickle\n",
    "from collections import defaultdict\n",
    "\n",
    "from datasets import load_dataset, Dataset\n",
    "\n",
    "from delphi.eval import utils\n",
    "from delphi.eval import constants\n",
    "from delphi.eval.vis_per_token_model import visualize_per_token_category\n",
    "\n",
    "# from delphi.eval.calc_model_group_stats import calc_model_group_stats\n",
    "from delphi.eval.token_labelling import TOKEN_LABELS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load data\n",
    "tokenized_corpus_dataset = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))[\"validation\"]\n",
    "\n",
    "# TODO: convert to use static paths\n",
    "# with open(\"../src/delphi/eval/labelled_token_ids_dict.pkl\", \"rb\") as f:\n",
    "#     token_groups = pickle.load(f)\n",
    "# model_group_stats = calc_model_group_stats(\n",
    "#     tokenized_corpus_dataset, logprob_datasets, token_groups, token_groups[0].keys()\n",
    "# )\n",
    "with open(\"../data/model_group_stats.pkl\", \"rb\") as f:\n",
    "    model_group_stats = pickle.load(f)\n",
    "\n",
    "logprob_datasets = utils.load_logprob_datasets(\"validation\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9323c96bf6834ceb91f751f45512ea85",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Dropdown(description='Token Category:', options=('Capitalized', 'Is Determiner', 'Is Interjunct…"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "performance_data = defaultdict(dict)\n",
    "for model in constants.LLAMA2_MODELS:\n",
    "    for token_group_desc in TOKEN_LABELS:\n",
    "        if (model, token_group_desc) not in model_group_stats:\n",
    "            continue\n",
    "        stats = model_group_stats[(model, token_group_desc)]\n",
    "        performance_data[model][token_group_desc] = (\n",
    "            -stats[\"median\"],\n",
    "            -stats[\"75th\"],\n",
    "            -stats[\"25th\"],\n",
    "        )\n",
    "\n",
    "visualize_per_token_category(\n",
    "    performance_data,\n",
    "    log_scale=True,\n",
    "    bg_color=\"LightGrey\",\n",
    "    line_color=\"Red\",\n",
    "    marker_color=\"Orange\",\n",
    "    bar_color=\"Green\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tinyevals",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -13,4 +13,6 @@ isort==5.13.2
 spacy==3.7.2
 chardet==5.2.0
 sentencepiece==0.1.99
-protobuf==4.25.2
+protobuf==4.25.2
+plotly==5.18.0
+spacy-transformers==1.3.4

Review comment: what is this?

Reply: Sorry, meant to reply to this but the VS Code inline thingy is kind of buggy and it ended up as a separate comment. Anyway: spacy-transformers is a dependency I needed to get the token labelling code to work; I think it just got missed by accident earlier (it's a hidden dependency that only shows up at runtime).
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

import argparse
import os
import pickle

from delphi.constants import STATIC_ASSETS_DIR
@@ -0,0 +1,54 @@
import numpy as np


def calc_model_group_stats(
    tokenized_corpus_dataset: list,
    logprob_datasets: dict[str, list[list[float]]],
    token_groups: dict[int, dict[str, bool]],
    token_labels: list[str],
) -> dict[tuple[str, str], dict[str, float]]:
    """
    For each (model, token group) pair, calculate useful stats (for visualization).

    args:
    - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset)["validation"]
    - logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
    - token_groups: a dict of token groups, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...}
    - token_labels: a list of token group descriptions, e.g. ["Is Noun", "Is Verb", ...]

    returns: a dict of (model, token group) pairs to a dict of stats,
    e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.9, "max": -0.1, "25th": -0.7, "75th": -0.3}, ...}

    Technically `token_labels` is redundant, as the labels are also keys in `token_groups`,
    but it's better to be explicit.

    stats calculated: mean, median, min, max, 25th percentile, 75th percentile
    """
    model_group_stats = {}
    for model in logprob_datasets:
        group_logprobs = {}
        print(f"Processing model {model}")
        dataset = logprob_datasets[model]
        for ix_doc_lp, document_lps in enumerate(dataset):
            tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"]
            for ix_token, token in enumerate(tokens):
                if ix_token == 0:  # skip the first token, which isn't predicted
                    continue
                logprob = document_lps[ix_token]
                for token_group_desc in token_labels:
                    if token_groups[token][token_group_desc]:
                        if token_group_desc not in group_logprobs:
                            group_logprobs[token_group_desc] = []
                        group_logprobs[token_group_desc].append(logprob)
        for token_group_desc in token_labels:
            if token_group_desc in group_logprobs:
                model_group_stats[(model, token_group_desc)] = {
                    "mean": np.mean(group_logprobs[token_group_desc]),
                    "median": np.median(group_logprobs[token_group_desc]),
                    "min": np.min(group_logprobs[token_group_desc]),
                    "max": np.max(group_logprobs[token_group_desc]),
                    "25th": np.percentile(group_logprobs[token_group_desc], 25),
                    "75th": np.percentile(group_logprobs[token_group_desc], 75),
                }
    return model_group_stats
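To show how the inputs and outputs of this function line up, here is a standalone sketch with toy data (the token ids, model name, and logprob values are made up; the grouping-and-stats logic is reproduced inline so the snippet runs without delphi installed):

```python
import numpy as np

# Toy inputs mirroring the shapes calc_model_group_stats expects:
tokenized_corpus_dataset = [{"tokens": [5, 7, 9]}]     # one document of three tokens
logprob_datasets = {"toy-model": [[0.0, -0.5, -1.5]]}  # one logprob per token
token_groups = {
    5: {"Is Noun": False},
    7: {"Is Noun": True},
    9: {"Is Noun": True},
}
token_labels = ["Is Noun"]

stats = {}
for model, dataset in logprob_datasets.items():
    group_logprobs = {}
    for ix_doc, document_lps in enumerate(dataset):
        tokens = tokenized_corpus_dataset[ix_doc]["tokens"]
        for ix_token, token in enumerate(tokens):
            if ix_token == 0:  # the first token isn't predicted, so it has no logprob
                continue
            for label in token_labels:
                if token_groups[token][label]:
                    group_logprobs.setdefault(label, []).append(document_lps[ix_token])
    for label, lps in group_logprobs.items():
        stats[(model, label)] = {
            "mean": float(np.mean(lps)),
            "median": float(np.median(lps)),
            "min": float(np.min(lps)),
            "max": float(np.max(lps)),
            "25th": float(np.percentile(lps, 25)),
            "75th": float(np.percentile(lps, 75)),
        }

print(stats[("toy-model", "Is Noun")])
```

Only tokens 7 and 9 are labelled "Is Noun", and the first token of the document is skipped, so the group collects the logprobs [-0.5, -1.5] and the stats are computed over those two values.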
@@ -0,0 +1,14 @@
corpus_dataset = "delphi-suite/tinystories-v2-clean"
tokenized_corpus_dataset = "delphi-suite/tinystories-v2-clean-tokenized-v0"

LLAMA2_MODELS = [
    "delphi-llama2-100k",
    "delphi-llama2-200k",
    "delphi-llama2-400k",
    "delphi-llama2-800k",
    "delphi-llama2-1.6m",
    "delphi-llama2-3.2m",
    "delphi-llama2-6.4m",
    "delphi-llama2-12.8m",
    "delphi-llama2-25.6m",
]
@@ -6,6 +6,8 @@
from jaxtyping import Float, Int
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from delphi.eval import constants


def get_all_logprobs(
    model: Callable, input_ids: Int[torch.Tensor, "batch seq"]

@@ -87,3 +89,16 @@ def tokenize(
        Int[torch.Tensor, "seq"],
        tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
    )


def load_logprob_dataset(model: str):
    return cast(
        Dataset, load_dataset(f"transcendingvictor/{model}-validation-logprobs")
    )


def load_logprob_datasets(split: str = "validation") -> dict[str, list[list[float]]]:
    return {
        model: cast(dict, load_logprob_dataset(model)[split])["logprobs"]
        for model in constants.LLAMA2_MODELS
    }

Review comment: I'm worried how much we have to use casting with Dataset objects :/

Reply: Yeah, the datasets library clearly wasn't designed with type guarantees in mind. Building a library of functions that take care of the casting and other stuff we don't want to worry about in our day-to-day might be a good idea. Ticket would be something like "identify repetitive casts in the codebase and replace them with library functions".
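One shape such a cast-wrapping library function could take (a sketch; `checked_cast` is a hypothetical helper, not part of delphi or the datasets library): unlike `typing.cast`, which is a no-op at runtime, it verifies the value's type before narrowing, so a schema surprise fails loudly at the call site instead of somewhere downstream.

```python
from typing import Any, TypeVar

T = TypeVar("T")


def checked_cast(tp: type[T], value: Any) -> T:
    # typing.cast only informs the type checker; this variant also
    # raises at runtime if the value isn't actually of the expected type.
    if not isinstance(value, tp):
        raise TypeError(f"expected {tp.__name__}, got {type(value).__name__}")
    return value


# A value whose static type has been lost (as with Dataset indexing):
nested: Any = {"logprobs": [[0.0, -0.5]]}
logprobs = checked_cast(dict, nested)["logprobs"]
```

The trade-off is a small runtime check per call in exchange for catching shape mismatches early; for container types it only checks the outer type, not the element types.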
Review comment: move?

Reply: Yup, now that the static dir PR is in