diff --git a/.vscode/settings.json b/.vscode/settings.json
index b5f4dfd1..8f9d001f 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -8,7 +8,8 @@
     },
     "python.analysis.typeCheckingMode": "basic",
     "isort.args": [
-        "--profile black"
+        "--profile",
+        "black"
     ],
     "black-formatter.importStrategy": "fromEnvironment",
 }
\ No newline at end of file
diff --git a/notebooks/end2end_demo.ipynb b/notebooks/end2end_demo.ipynb
new file mode 100644
index 00000000..f08aba38
--- /dev/null
+++ b/notebooks/end2end_demo.ipynb
@@ -0,0 +1,133 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "from typing import cast\n",
+    "import pickle\n",
+    "from collections import defaultdict\n",
+    "\n",
+    "from datasets import load_dataset, Dataset\n",
+    "\n",
+    "from delphi.constants import STATIC_ASSETS_DIR\n",
+    "from delphi.eval import utils\n",
+    "from delphi.eval import constants\n",
+    "from delphi.eval.vis_per_token_model import visualize_per_token_category\n",
+    "\n",
+    "# from delphi.eval.calc_model_group_stats import calc_model_group_stats\n",
+    "from delphi.eval.token_labelling import TOKEN_LABELS"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# load data\n",
+    "tokenized_corpus_dataset = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))[\"validation\"]\n",
+    "\n",
+    "# TODO: convert to use static paths\n",
+    "# with open(\"../src/delphi/eval/labelled_token_ids_dict.pkl\", \"rb\") as f:\n",
+    "#     token_groups = pickle.load(f)\n",
+    "# model_group_stats = calc_model_group_stats(\n",
+    "#     tokenized_corpus_dataset, logprob_datasets, token_groups, token_groups[0].keys()\n",
+    "# )\n",
+    "with open(f\"{STATIC_ASSETS_DIR}/model_group_stats.pkl\", \"rb\") as f:\n",
+    "    model_group_stats = pickle.load(f)\n",
+    "\n",
+    "logprob_datasets = utils.load_logprob_datasets(\"validation\")\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Visualization"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "0f8846898fbb4a1b9e872ff6511acd3d",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "VBox(children=(Dropdown(description='Token Category:', options=('Capitalized', 'Is Determiner', 'Is Interjunct…"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "performance_data = defaultdict(dict)\n",
+    "for model in constants.LLAMA2_MODELS:\n",
+    "    for token_group_desc in TOKEN_LABELS:\n",
+    "        if (model, token_group_desc) not in model_group_stats:\n",
+    "            continue\n",
+    "        stats = model_group_stats[(model, token_group_desc)]\n",
+    "        performance_data[model][token_group_desc] = (\n",
+    "            -stats[\"median\"],\n",
+    "            -stats[\"75th\"],\n",
+    "            -stats[\"25th\"],\n",
+    "        )\n",
+    "\n",
+    "visualize_per_token_category(\n",
+    "    performance_data,\n",
+    "    log_scale=True,\n",
+    "    bg_color=\"LightGrey\",\n",
+    "    line_color=\"Red\",\n",
+    "    marker_color=\"Orange\",\n",
+    "    bar_color=\"Green\",\n",
+    ")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "tinyevals",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
"ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/requirements.txt b/requirements.txt index 5fdc84c0..e14ef757 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,6 @@ isort==5.13.2 spacy==3.7.2 chardet==5.2.0 sentencepiece==0.1.99 -protobuf==4.25.2 \ No newline at end of file +protobuf==4.25.2 +plotly==5.18.0 +spacy-transformers==1.3.4 \ No newline at end of file diff --git a/scripts/map_tokens.py b/scripts/map_tokens.py old mode 100644 new mode 100755 index 2acea2da..832ac0da --- a/scripts/map_tokens.py +++ b/scripts/map_tokens.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import argparse +import os import pickle from delphi.constants import STATIC_ASSETS_DIR diff --git a/src/delphi/eval/calc_model_group_stats.py b/src/delphi/eval/calc_model_group_stats.py new file mode 100644 index 00000000..d9c5d4c1 --- /dev/null +++ b/src/delphi/eval/calc_model_group_stats.py @@ -0,0 +1,54 @@ +import numpy as np + + +def calc_model_group_stats( + tokenized_corpus_dataset: list, + logprobs_by_dataset: dict[str, list[list[float]]], + token_labels_by_token: dict[int, dict[str, bool]], + token_labels: list[str], +) -> dict[tuple[str, str], dict[str, float]]: + """ + For each (model, token group) pair, calculate useful stats (for visualization) + + args: + - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"] + - logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]} + - token_groups: a dict of token groups, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...} + - models: a list of model names, e.g. constants.LLAMA2_MODELS + - token_labels: a list of token group descriptions, e.g. ["Is Noun", "Is Verb", ...] + + returns: a dict of (model, token group) pairs to a dict of stats, + e.g. 
{("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...} + + Technically `models` and `token_labels` are redundant, as they are also keys in `logprob_datasets` and `token_groups`, + but it's better to be explicit + + stats calculated: mean, median, min, max, 25th percentile, 75th percentile + """ + model_group_stats = {} + for model in logprobs_by_dataset: + group_logprobs = {} + print(f"Processing model {model}") + dataset = logprobs_by_dataset[model] + for ix_doc_lp, document_lps in enumerate(dataset): + tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"] + for ix_token, token in enumerate(tokens): + if ix_token == 0: # skip the first token, which isn't predicted + continue + logprob = document_lps[ix_token] + for token_group_desc in token_labels: + if token_labels_by_token[token][token_group_desc]: + if token_group_desc not in group_logprobs: + group_logprobs[token_group_desc] = [] + group_logprobs[token_group_desc].append(logprob) + for token_group_desc in token_labels: + if token_group_desc in group_logprobs: + model_group_stats[(model, token_group_desc)] = { + "mean": np.mean(group_logprobs[token_group_desc]), + "median": np.median(group_logprobs[token_group_desc]), + "min": np.min(group_logprobs[token_group_desc]), + "max": np.max(group_logprobs[token_group_desc]), + "25th": np.percentile(group_logprobs[token_group_desc], 25), + "75th": np.percentile(group_logprobs[token_group_desc], 75), + } + return model_group_stats diff --git a/src/delphi/eval/constants.py b/src/delphi/eval/constants.py new file mode 100644 index 00000000..30b3e36b --- /dev/null +++ b/src/delphi/eval/constants.py @@ -0,0 +1,14 @@ +corpus_dataset = "delphi-suite/tinystories-v2-clean" +tokenized_corpus_dataset = "delphi-suite/tinystories-v2-clean-tokenized-v0" + +LLAMA2_MODELS = [ + "delphi-llama2-100k", + "delphi-llama2-200k", + "delphi-llama2-400k", + "delphi-llama2-800k", + "delphi-llama2-1.6m", + "delphi-llama2-3.2m", + "delphi-llama2-6.4m", + "delphi-llama2-12.8m", + "delphi-llama2-25.6m", +] diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py index 16b26e4c..d0611afb 100644 --- a/src/delphi/eval/utils.py +++ b/src/delphi/eval/utils.py @@ -6,6 +6,8 @@ from jaxtyping import Float, Int from transformers import PreTrainedModel, PreTrainedTokenizerBase +from delphi.eval import constants + def get_all_logprobs( model: Callable, input_ids: Int[torch.Tensor, "batch seq"] @@ -87,3 +89,14 @@ def tokenize( Int[torch.Tensor, "seq"], tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0], ) + + +def load_logprob_dataset(model: str) -> Dataset: + return load_dataset(f"transcendingvictor/{model}-validation-logprobs") # type: ignore + + +def load_logprob_datasets(split: str = "validation") -> dict[str, list[list[float]]]: + return { + model: cast(dict, load_logprob_dataset(model)[split])["logprobs"] + for model in constants.LLAMA2_MODELS + } diff --git a/src/delphi/eval/vis_per_token_model.py b/src/delphi/eval/vis_per_token_model.py index a8e269fe..618840b0 100644 --- a/src/delphi/eval/vis_per_token_model.py +++ b/src/delphi/eval/vis_per_token_model.py @@ -10,6 +10,9 @@ def visualize_per_token_category( categories = list(input[model_names[0]].keys()) category = categories[0] + def get_hovertexts(mid: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> list[str]: + return [f"Loss: {m:.3f} ({l:.3f}, {h:.3f})" for m, l, h in zip(mid, lo, hi)] + def get_plot_values(category: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]: x = 
        x = np.array([input[name][category] for name in model_names]).T
         means, err_lo, err_hi = x[0], x[1], x[2]
@@ -32,6 +35,8 @@
                 size=15,
                 line=dict(color=kwargs.get("line_color", "MediumPurple"), width=2),
             ),
+            hovertext=get_hovertexts(means, err_lo, err_hi),
+            hoverinfo="text+x",
         ),
         layout=go.Layout(
             yaxis=dict(
@@ -55,6 +60,7 @@
     def response(change):
         g.data[0].y = means
         g.data[0].error_y["array"] = err_hi
         g.data[0].error_y["arrayminus"] = err_lo
+        g.data[0].hovertext = get_hovertexts(means, err_lo, err_hi)
 
     selected_category.observe(response, names="value")
diff --git a/src/delphi/static/README.md b/src/delphi/static/README.md
new file mode 100644
index 00000000..815b0c42
--- /dev/null
+++ b/src/delphi/static/README.md
@@ -0,0 +1,10 @@
+# TODO: move this to delphi/static
+# Static Data Files
+
+
+## `token_map.pkl`
+Pickle file: all locations of all tokens; a dict mapping token id to a list of (doc, pos) pairs.
+
+## `model_group_stats.pkl`
+Useful statistics for visualizing (model, token group) pairs; a dict mapping (model, token group) to a dict of stat name (str) to value (float),
+e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.9, "max": -0.1, "25th": -0.7, "75th": -0.3}, ...}
\ No newline at end of file
diff --git a/src/delphi/static/model_group_stats.pkl b/src/delphi/static/model_group_stats.pkl
new file mode 100644
index 00000000..7f5297c4
Binary files /dev/null and b/src/delphi/static/model_group_stats.pkl differ
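
For reference, the new pieces compose end to end: `utils.load_logprob_datasets` pulls per-model logprobs from the hub, `calc_model_group_stats` aggregates them per (model, token group), and `visualize_per_token_category` plots the result. Below is a minimal sketch of that flow, recomputing the stats rather than loading the checked-in `model_group_stats.pkl`; it is illustrative only, and the `labelled_token_ids_dict.pkl` path is hypothetical, borrowed from the notebook's commented-out TODO.

```python
# Sketch only, mirroring notebooks/end2end_demo.ipynb and its commented-out TODO.
# Assumes a labelled_token_ids_dict.pkl mapping token id -> {label: bool} exists.
import pickle
from collections import defaultdict
from typing import cast

from datasets import Dataset, load_dataset

from delphi.eval import constants, utils
from delphi.eval.calc_model_group_stats import calc_model_group_stats
from delphi.eval.token_labelling import TOKEN_LABELS
from delphi.eval.vis_per_token_model import visualize_per_token_category

tokenized_corpus = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))[
    "validation"
]
logprob_datasets = utils.load_logprob_datasets("validation")

with open("labelled_token_ids_dict.pkl", "rb") as f:  # hypothetical path
    token_labels_by_token = pickle.load(f)

# one stats dict per (model, token group): mean/median/min/max/25th/75th of logprobs
model_group_stats = calc_model_group_stats(
    tokenized_corpus, logprob_datasets, token_labels_by_token, TOKEN_LABELS
)

# negate logprobs to plot loss, with the 75th/25th percentiles as the error band
performance_data = defaultdict(dict)
for model in constants.LLAMA2_MODELS:
    for label in TOKEN_LABELS:
        if (model, label) in model_group_stats:
            stats = model_group_stats[(model, label)]
            performance_data[model][label] = (
                -stats["median"],
                -stats["75th"],
                -stats["25th"],
            )

visualize_per_token_category(performance_data, log_scale=True)
```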