Skip to content

Commit

Permalink
Cleanup on end2end
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidhyani committed Feb 23, 2024
1 parent 3005267 commit 53119ce
Showing 1 changed file with 39 additions and 188 deletions.
227 changes: 39 additions & 188 deletions notebooks/end2end_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,254 +4,105 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# The Important Part"
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jaidhyani/miniforge3/envs/tinyevals/lib/python3.10/site-packages/beartype/_util/hint/pep/utilpeptest.py:311: BeartypeDecorHintPep585DeprecationWarning: PEP 484 type hint typing.Callable deprecated by PEP 585. This hint is scheduled for removal in the first Python version released after October 5th, 2025. To resolve this, import this hint from \"beartype.typing\" rather than \"typing\". For further commentary and alternatives, see also:\n",
" https://beartype.readthedocs.io/en/latest/api_roar/#pep-585-deprecations\n",
" warn(\n"
]
}
],
"outputs": [],
"source": [
"from typing import cast\n",
"from dataclasses import dataclass\n",
"import pickle\n",
"import numpy as np\n",
"import pandas as pd\n",
"from collections import defaultdict\n",
"\n",
"from datasets import load_dataset, Dataset\n",
"\n",
"from delphi.eval import utils\n",
"from delphi.eval import constants\n",
"from delphi.eval.vis_per_token_model import visualize_per_token_category"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def calc_model_group_stats(\n",
" tokenized_corpus_dataset: list,\n",
" logprob_datasets: dict[str, Dataset],\n",
" token_groups: dict[int, dict[str, bool]],\n",
" models: list[str],\n",
" token_labels: list[str]\n",
") -> dict[tuple[str, str], dict[str, float]]:\n",
" \"\"\"\n",
" For each (model, token group) pair, calculate useful stats (for visualization)\n",
"\n",
" args:\n",
" - tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset))[\"validation\"]\n",
" - logprob_datasets: a dict of logprob datasets, e.g. {\"llama2\": load_dataset(\"transcendingvictor/llama2-validation-logprobs\")[\"validation\"][\"logprobs\"]}\n",
" - token_groups: a dict of token groups, e.g. {0: {\"Is Noun\": True, \"Is Verb\": False, ...}, 1: {...}, ...}\n",
" - models: a list of model names, e.g. [\"llama2\", \"gpt2\", ...]\n",
" - token_labels: a list of token group descriptions, e.g. [\"Is Noun\", \"Is Verb\", ...]\n",
"\n",
" returns: a dict of (model, token group) pairs to a dict of stats, \n",
" e.g. {(\"llama2\", \"Is Noun\"): {\"mean\": -0.5, \"median\": -0.4, \"min\": -0.1, \"max\": -0.9, \"25th\": -0.3, \"75th\": -0.7}, ...}\n",
"\n",
" Technically `models` and `token_labels` are redundant, as they are also keys in `logprob_datasets` and `token_groups`, \n",
" but it's better to be explicit\n",
" \n",
" stats calculated: mean, median, min, max, 25th percentile, 75th percentile\n",
" \"\"\"\n",
" model_group_stats = {}\n",
" for model in models:\n",
" group_logprobs = {}\n",
" print(f\"Processing model {model}\")\n",
" dataset = logprob_datasets[model]\n",
" for ix_doc_lp, document_lps in enumerate(dataset):\n",
" tokens = tokenized_corpus_dataset[ix_doc_lp][\"tokens\"]\n",
" for ix_token, token in enumerate(tokens):\n",
" if ix_token == 0: # skip the first token, which isn't predicted\n",
" continue\n",
" logprob = document_lps[ix_token]\n",
" for token_group_desc in token_labels:\n",
" if token_groups[token][token_group_desc]:\n",
" if token_group_desc not in group_logprobs:\n",
" group_logprobs[token_group_desc] = []\n",
" group_logprobs[token_group_desc].append(logprob)\n",
" for token_group_desc in token_labels:\n",
" if token_group_desc in group_logprobs:\n",
" model_group_stats[(model, token_group_desc)] = {\n",
" \"mean\": np.mean(group_logprobs[token_group_desc]),\n",
" \"median\": np.median(group_logprobs[token_group_desc]),\n",
" \"min\": np.min(group_logprobs[token_group_desc]),\n",
" \"max\": np.max(group_logprobs[token_group_desc]),\n",
" \"25th\": np.percentile(group_logprobs[token_group_desc], 25),\n",
" \"75th\": np.percentile(group_logprobs[token_group_desc], 75),\n",
" }\n",
" print()\n",
" return model_group_stats"
"from delphi.eval.vis_per_token_model import visualize_per_token_category\n",
"# from delphi.eval.calc_model_group_stats import calc_model_group_stats\n",
"from delphi.eval.token_labelling import TOKEN_LABELS"
]
},
{
"cell_type": "code",
"execution_count": 4,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"tokenized_corpus_dataset = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))[\"validation\"]"
"# Data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"#TODO: convert to use static paths\n",
"# load data\n",
"tokenized_corpus_dataset = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))[\"validation\"]\n",
"\n",
"with open(\"../src/delphi/eval/labelled_token_ids_dict.pkl\", \"rb\") as f:\n",
" token_groups = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# token groups is a dict of (int -> dict(str -> bool)). The top keys are the token ids, and the values are dicts of token group names to boolean values\n",
"# of whether the token is in that group. We want to turn this into a dict of (str -> list(int)) where the keys are the token group names and the values are lists of token ids\n",
"# TODO: convert to use static paths\n",
"# with open(\"../src/delphi/eval/labelled_token_ids_dict.pkl\", \"rb\") as f:\n",
"# token_groups = pickle.load(f)\n",
"# model_group_stats = calc_model_group_stats(\n",
"# tokenized_corpus_dataset, logprob_datasets, token_groups, constants.LLAMA2_MODELS, token_groups[0].keys()\n",
"# )\n",
"with open(\"../data/model_group_stats.pkl\", \"rb\") as f:\n",
" model_group_stats = pickle.load(f)\n",
"\n",
"token_group_members = {}\n",
"for token_id, group_dict in token_groups.items():\n",
" for group_name, is_member in group_dict.items():\n",
" if is_member:\n",
" if group_name not in token_group_members:\n",
" token_group_members[group_name] = []\n",
" token_group_members[group_name].append(token_id)\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"logprob_datasets = {}\n",
"for model in constants.LLAMA2_MODELS:\n",
" logprob_datasets[model] = cast(\n",
" dict,\n",
" cast(Dataset, load_dataset(f\"transcendingvictor/{model}-validation-logprobs\"))[\n",
" \"validation\"\n",
" ],\n",
" )[\"logprobs\"]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"token_group_descriptions = list(token_groups[0].keys())"
"logprob_datasets = utils.load_logprob_datasets(\"validation\")\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"cell_type": "markdown",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processing model delphi-llama2-100k\n",
"\n",
"Processing model delphi-llama2-200k\n",
"\n",
"Processing model delphi-llama2-400k\n",
"\n",
"Processing model delphi-llama2-800k\n",
"\n",
"Processing model delphi-llama2-1.6m\n",
"\n",
"Processing model delphi-llama2-3.2m\n",
"\n",
"Processing model delphi-llama2-6.4m\n",
"\n",
"Processing model delphi-llama2-12.8m\n",
"\n",
"Processing model delphi-llama2-25.6m\n",
"\n"
]
}
],
"source": [
"tokenized_corpus_dataset = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))[\"validation\"]\n",
"\n",
"model_group_stats = calc_model_group_stats(\n",
" tokenized_corpus_dataset, logprob_datasets, token_groups, constants.LLAMA2_MODELS, token_groups[0].keys()\n",
")"
"# Visualization"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "06cfd94029aa4804b8ae8fb3a71d13f7",
"model_id": "9323c96bf6834ceb91f751f45512ea85",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Dropdown(description='Token Category:', options=('Capitalized', 'Is Determiner', 'Is Interjunct…"
]
},
"execution_count": 12,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"performance_data = defaultdict(dict)\n",
"for model in constants.LLAMA2_MODELS:\n",
" for token_group_desc in token_group_descriptions:\n",
" for token_group_desc in TOKEN_LABELS:\n",
" if (model, token_group_desc) not in model_group_stats:\n",
" continue\n",
" stats = model_group_stats[(model, token_group_desc)]\n",
" performance_data[model][token_group_desc] = (-stats[\"median\"], -stats[\"75th\"], -stats[\"25th\"])\n",
" performance_data[model][token_group_desc] = (\n",
" -stats[\"median\"],\n",
" -stats[\"75th\"],\n",
" -stats[\"25th\"],\n",
" )\n",
"\n",
"visualize_per_token_category(performance_data, log_scale=True, bg_color='LightGrey', line_color=\"Red\", marker_color='Orange', bar_color='Green')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# with open(\"../data/model_group_stats.pkl\", \"wb\") as f:\n",
"# pickle.dump(model_group_stats, f)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# import pickle\n",
"# with open(\"../data/model_group_stats.pkl\", \"rb\") as f:\n",
"# model_group_stats = pickle.load(f)"
"visualize_per_token_category(\n",
" performance_data,\n",
" log_scale=True,\n",
" bg_color=\"LightGrey\",\n",
" line_color=\"Red\",\n",
" marker_color=\"Orange\",\n",
" bar_color=\"Green\",\n",
")"
]
}
],
Expand Down

0 comments on commit 53119ce

Please sign in to comment.