End-to-end evaluation demo (#32)
end-to-end evals visualization demo
jaidhyani authored Feb 27, 2024
1 parent 43b76f7 commit a48370b
Showing 10 changed files with 236 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -8,7 +8,8 @@
    },
    "python.analysis.typeCheckingMode": "basic",
    "isort.args": [
        "--profile black"
        "--profile",
        "black"
    ],
    "black-formatter.importStrategy": "fromEnvironment",
}
133 changes: 133 additions & 0 deletions notebooks/end2end_demo.ipynb
@@ -0,0 +1,133 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from typing import cast\n",
"import pickle\n",
"from collections import defaultdict\n",
"\n",
"from datasets import load_dataset, Dataset\n",
"\n",
"from delphi.constants import STATIC_ASSETS_DIR\n",
"from delphi.eval import utils\n",
"from delphi.eval import constants\n",
"from delphi.eval.vis_per_token_model import visualize_per_token_category\n",
"\n",
"# from delphi.eval.calc_model_group_stats import calc_model_group_stats\n",
"from delphi.eval.token_labelling import TOKEN_LABELS"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# load data\n",
"tokenized_corpus_dataset = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))[\"validation\"]\n",
"\n",
"# TODO: convert to use static paths\n",
"# with open(\"../src/delphi/eval/labelled_token_ids_dict.pkl\", \"rb\") as f:\n",
"# token_groups = pickle.load(f)\n",
"# model_group_stats = calc_model_group_stats(\n",
"# tokenized_corpus_dataset, logprob_datasets, token_groups, token_groups[0].keys()\n",
"# )\n",
"with open(f\"{STATIC_ASSETS_DIR}/model_group_stats.pkl\", \"rb\") as f:\n",
" model_group_stats = pickle.load(f)\n",
"\n",
"logprob_datasets = utils.load_logprob_datasets(\"validation\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Visualization"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0f8846898fbb4a1b9e872ff6511acd3d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Dropdown(description='Token Category:', options=('Capitalized', 'Is Determiner', 'Is Interjunct…"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"performance_data = defaultdict(dict)\n",
"for model in constants.LLAMA2_MODELS:\n",
" for token_group_desc in TOKEN_LABELS:\n",
" if (model, token_group_desc) not in model_group_stats:\n",
" continue\n",
" stats = model_group_stats[(model, token_group_desc)]\n",
" performance_data[model][token_group_desc] = (\n",
" -stats[\"median\"],\n",
" -stats[\"75th\"],\n",
" -stats[\"25th\"],\n",
" )\n",
"\n",
"visualize_per_token_category(\n",
" performance_data,\n",
" log_scale=True,\n",
" bg_color=\"LightGrey\",\n",
" line_color=\"Red\",\n",
" marker_color=\"Orange\",\n",
" bar_color=\"Green\",\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tinyevals",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
4 changes: 3 additions & 1 deletion requirements.txt
@@ -13,4 +13,6 @@ isort==5.13.2
spacy==3.7.2
chardet==5.2.0
sentencepiece==0.1.99
protobuf==4.25.2
protobuf==4.25.2
plotly==5.18.0
spacy-transformers==1.3.4
1 change: 1 addition & 0 deletions scripts/map_tokens.py
100644 → 100755
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

import argparse
import os
import pickle

from delphi.constants import STATIC_ASSETS_DIR
54 changes: 54 additions & 0 deletions src/delphi/eval/calc_model_group_stats.py
@@ -0,0 +1,54 @@
import numpy as np


def calc_model_group_stats(
    tokenized_corpus_dataset: list,
    logprobs_by_dataset: dict[str, list[list[float]]],
    token_labels_by_token: dict[int, dict[str, bool]],
    token_labels: list[str],
) -> dict[tuple[str, str], dict[str, float]]:
    """
    For each (model, token group) pair, calculate useful stats (for visualization).

    args:
    - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset)["validation"]
    - logprobs_by_dataset: a dict mapping model name to lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
    - token_labels_by_token: a dict mapping token id to its labels, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...}
    - token_labels: a list of token group descriptions, e.g. ["Is Noun", "Is Verb", ...]

    returns: a dict mapping each (model, token group) pair to a dict of stats,
    e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...}

    Technically `token_labels` is redundant, as the labels are also keys in `token_labels_by_token`,
    but it's better to be explicit.

    stats calculated: mean, median, min, max, 25th percentile, 75th percentile
    """
    model_group_stats = {}
    for model in logprobs_by_dataset:
        group_logprobs = {}
        print(f"Processing model {model}")
        dataset = logprobs_by_dataset[model]
        for ix_doc_lp, document_lps in enumerate(dataset):
            tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"]
            for ix_token, token in enumerate(tokens):
                if ix_token == 0:  # skip the first token, which isn't predicted
                    continue
                logprob = document_lps[ix_token]
                for token_group_desc in token_labels:
                    if token_labels_by_token[token][token_group_desc]:
                        if token_group_desc not in group_logprobs:
                            group_logprobs[token_group_desc] = []
                        group_logprobs[token_group_desc].append(logprob)
        for token_group_desc in token_labels:
            if token_group_desc in group_logprobs:
                model_group_stats[(model, token_group_desc)] = {
                    "mean": np.mean(group_logprobs[token_group_desc]),
                    "median": np.median(group_logprobs[token_group_desc]),
                    "min": np.min(group_logprobs[token_group_desc]),
                    "max": np.max(group_logprobs[token_group_desc]),
                    "25th": np.percentile(group_logprobs[token_group_desc], 25),
                    "75th": np.percentile(group_logprobs[token_group_desc], 75),
                }
    return model_group_stats
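
For reference, a minimal end-to-end sketch of calling this function with the loaders added elsewhere in this PR. The pickle path is an assumption taken from the commented-out notebook cell above, not a finalized location:

import pickle
from typing import cast

from datasets import Dataset, load_dataset

from delphi.eval import constants, utils
from delphi.eval.calc_model_group_stats import calc_model_group_stats

# tokenized validation split of the corpus, as in the notebook
tokenized_corpus_dataset = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))["validation"]

# per-model logprobs keyed by model name
logprob_datasets = utils.load_logprob_datasets("validation")

# assumed path, mirroring the notebook's commented-out TODO; not yet a static asset
with open("../src/delphi/eval/labelled_token_ids_dict.pkl", "rb") as f:
    token_groups = pickle.load(f)

model_group_stats = calc_model_group_stats(
    tokenized_corpus_dataset, logprob_datasets, token_groups, list(token_groups[0].keys())
)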
14 changes: 14 additions & 0 deletions src/delphi/eval/constants.py
@@ -0,0 +1,14 @@
corpus_dataset = "delphi-suite/tinystories-v2-clean"
tokenized_corpus_dataset = "delphi-suite/tinystories-v2-clean-tokenized-v0"

LLAMA2_MODELS = [
    "delphi-llama2-100k",
    "delphi-llama2-200k",
    "delphi-llama2-400k",
    "delphi-llama2-800k",
    "delphi-llama2-1.6m",
    "delphi-llama2-3.2m",
    "delphi-llama2-6.4m",
    "delphi-llama2-12.8m",
    "delphi-llama2-25.6m",
]
13 changes: 13 additions & 0 deletions src/delphi/eval/utils.py
@@ -6,6 +6,8 @@
from jaxtyping import Float, Int
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from delphi.eval import constants


def get_all_logprobs(
    model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
@@ -87,3 +89,14 @@ def tokenize(
        Int[torch.Tensor, "seq"],
        tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
    )


def load_logprob_dataset(model: str) -> Dataset:
    return load_dataset(f"transcendingvictor/{model}-validation-logprobs")  # type: ignore


def load_logprob_datasets(split: str = "validation") -> dict[str, list[list[float]]]:
    return {
        model: cast(dict, load_logprob_dataset(model)[split])["logprobs"]
        for model in constants.LLAMA2_MODELS
    }
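
A quick usage sketch of the new loaders (the model name is one entry of constants.LLAMA2_MODELS; indexing follows the dict[str, list[list[float]]] shape above):

from delphi.eval.utils import load_logprob_datasets

logprobs = load_logprob_datasets("validation")
# logprobs[model][doc_ix][token_ix] is the model's logprob for the token
# at position token_ix of validation document doc_ix
doc0_lps = logprobs["delphi-llama2-100k"][0]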
6 changes: 6 additions & 0 deletions src/delphi/eval/vis_per_token_model.py
@@ -10,6 +10,9 @@ def visualize_per_token_category(
    categories = list(input[model_names[0]].keys())
    category = categories[0]

    def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]:
        return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)]

    def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        x = np.array([input[name][category] for name in model_names]).T
        means, err_lo, err_hi = x[0], x[1], x[2]
@@ -32,6 +35,8 @@ def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
                size=15,
                line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2),
            ),
            hovertext=get_hovertexts(means, err_lo, err_hi),
            hoverinfo="text+x",
        ),
        layout=go.Layout(
            yaxis=dict(
@@ -55,6 +60,7 @@ def response(change):
        g.data[0].y = means
        g.data[0].error_y["array"] = err_hi
        g.data[0].error_y["arrayminus"] = err_lo
        g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi)

    selected_category.observe(response, names="value")

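For reference, a minimal sketch of driving the updated widget outside the notebook; performance_data is assumed to have the shape built in the notebook cell above, i.e. {model: {category: (mid, lo, hi)}}, and the color kwargs are optional:

from delphi.eval.vis_per_token_model import visualize_per_token_category

# illustrative values; shape assumed from the notebook: model -> category -> (mid, lo, hi)
performance_data = {"delphi-llama2-100k": {"Is Noun": (1.2, 1.0, 1.5)}}

# returns an ipywidgets VBox; render it as the last expression of a Jupyter cell
visualize_per_token_category(performance_data, log_scale=True)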
10 changes: 10 additions & 0 deletions src/delphi/static/README.md
@@ -0,0 +1,10 @@
# TODO: move this to delphi/static
# Static Data Files


## `token_map.pkl`
Pickle file with all locations of all tokens: a dict mapping each token to a list of (doc, pos) pairs.

## `model_group_stats.pkl`
Pickle file of summary statistics for visualizing (model, token group) pairs: a dict mapping (model, token group) to a dict of stat name to float,
e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...}
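
A minimal loading sketch, assuming the package is installed so that STATIC_ASSETS_DIR resolves to this directory (as in the notebook in this PR):

import pickle

from delphi.constants import STATIC_ASSETS_DIR

with open(f"{STATIC_ASSETS_DIR}/model_group_stats.pkl", "rb") as f:
    model_group_stats = pickle.load(f)  # (model, token group) -> stats dict

with open(f"{STATIC_ASSETS_DIR}/token_map.pkl", "rb") as f:
    token_map = pickle.load(f)  # token -> list of (doc, pos) pairs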
Binary file added src/delphi/static/model_group_stats.pkl
Binary file not shown.
