Skip to content

Commit

Permalink
Refactor calc_model_group_stats function
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidhyani committed Feb 23, 2024
1 parent 067c317 commit 3d9a888
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
2 changes: 1 addition & 1 deletion notebooks/end2end_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"# with open(\"../src/delphi/eval/labelled_token_ids_dict.pkl\", \"rb\") as f:\n",
"# token_groups = pickle.load(f)\n",
"# model_group_stats = calc_model_group_stats(\n",
"# tokenized_corpus_dataset, logprob_datasets, token_groups, constants.LLAMA2_MODELS, token_groups[0].keys()\n",
"# tokenized_corpus_dataset, logprob_datasets, token_groups, token_groups[0].keys()\n",
"# )\n",
"with open(\"../data/model_group_stats.pkl\", \"rb\") as f:\n",
" model_group_stats = pickle.load(f)\n",
Expand Down
5 changes: 2 additions & 3 deletions src/delphi/eval/calc_model_group_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ def calc_model_group_stats(
tokenized_corpus_dataset: list,
logprob_datasets: dict[str, list[list[float]]],
token_groups: dict[int, dict[str, bool]],
models: list[str],
token_labels: list[str],
) -> dict[tuple[str, str], dict[str, float]]:
"""
Expand All @@ -16,7 +15,7 @@ def calc_model_group_stats(
- tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"]
- logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
- token_groups: a dict of token groups, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...}
- models: a list of model names, e.g. ["llama2", "gpt2", ...]
- models: a list of model names, e.g. constants.LLAMA2_MODELS
- token_labels: a list of token group descriptions, e.g. ["Is Noun", "Is Verb", ...]
returns: a dict of (model, token group) pairs to a dict of stats,
Expand All @@ -28,7 +27,7 @@ def calc_model_group_stats(
stats calculated: mean, median, min, max, 25th percentile, 75th percentile
"""
model_group_stats = {}
for model in models:
for model in logprob_datasets:
group_logprobs = {}
print(f"Processing model {model}")
dataset = logprob_datasets[model]
Expand Down

0 comments on commit 3d9a888

Please sign in to comment.