Add calc_model_group_stats function to calculate useful stats for visualization
jaidhyani committed Feb 23, 2024
1 parent 53e0eb0 commit 043c408
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/delphi/eval/calc_model_group_stats.py
@@ -0,0 +1,56 @@
from datasets import Dataset
import numpy as np


def calc_model_group_stats(
    tokenized_corpus_dataset: Dataset,
    logprob_datasets: dict[str, list[list[float]]],
    token_groups: dict[int, dict[str, bool]],
    models: list[str],
    token_labels: list[str],
) -> dict[tuple[str, str], dict[str, float]]:
"""
For each (model, token group) pair, calculate useful stats (for visualization)
args:
- tokenized_corpus_dataset: the tokenized corpus dataset, e.g. load_dataset(constants.tokenized_corpus_dataset))["validation"]
- logprob_datasets: a dict of lists of logprobs, e.g. {"llama2": load_dataset("transcendingvictor/llama2-validation-logprobs")["validation"]["logprobs"]}
- token_groups: a dict of token groups, e.g. {0: {"Is Noun": True, "Is Verb": False, ...}, 1: {...}, ...}
- models: a list of model names, e.g. ["llama2", "gpt2", ...]
- token_labels: a list of token group descriptions, e.g. ["Is Noun", "Is Verb", ...]
returns: a dict of (model, token group) pairs to a dict of stats,
e.g. {("llama2", "Is Noun"): {"mean": -0.5, "median": -0.4, "min": -0.1, "max": -0.9, "25th": -0.3, "75th": -0.7}, ...}
Technically `models` and `token_labels` are redundant, as they are also keys in `logprob_datasets` and `token_groups`,
but it's better to be explicit
stats calculated: mean, median, min, max, 25th percentile, 75th percentile
"""
    model_group_stats = {}
    for model in models:
        group_logprobs = {}
        print(f"Processing model {model}")
        dataset = logprob_datasets[model]
        for ix_doc_lp, document_lps in enumerate(dataset):
            tokens = tokenized_corpus_dataset[ix_doc_lp]["tokens"]
            for ix_token, token in enumerate(tokens):
                if ix_token == 0:  # skip the first token, which isn't predicted
                    continue
                logprob = document_lps[ix_token]
                # collect this token's logprob under every group label that applies to it
                for token_group_desc in token_labels:
                    if token_groups[token][token_group_desc]:
                        if token_group_desc not in group_logprobs:
                            group_logprobs[token_group_desc] = []
                        group_logprobs[token_group_desc].append(logprob)
        # summarize each group that actually occurred in the corpus for this model
        for token_group_desc in token_labels:
            if token_group_desc in group_logprobs:
                model_group_stats[(model, token_group_desc)] = {
                    "mean": np.mean(group_logprobs[token_group_desc]),
                    "median": np.median(group_logprobs[token_group_desc]),
                    "min": np.min(group_logprobs[token_group_desc]),
                    "max": np.max(group_logprobs[token_group_desc]),
                    "25th": np.percentile(group_logprobs[token_group_desc], 25),
                    "75th": np.percentile(group_logprobs[token_group_desc], 75),
                }
    return model_group_stats
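
A minimal usage sketch (not part of the commit): the toy corpus, the logprob values, and the import path delphi.eval.calc_model_group_stats below are assumptions for illustration, following the shapes described in the docstring.

from delphi.eval.calc_model_group_stats import calc_model_group_stats  # assumed import path

# Hypothetical toy corpus: two documents, each a dict with a "tokens" field of token ids.
tokenized_corpus = [
    {"tokens": [5, 9, 3]},
    {"tokens": [9, 5]},
]

# One logprob per token, aligned with the corpus; position 0 is skipped by the function.
logprob_datasets = {
    "llama2": [[0.0, -0.4, -1.2], [0.0, -0.8]],
    "gpt2": [[0.0, -0.9, -2.1], [0.0, -1.5]],
}

# Boolean group labels for every token id appearing in the corpus.
token_groups = {
    3: {"Is Noun": True, "Is Verb": False},
    5: {"Is Noun": False, "Is Verb": True},
    9: {"Is Noun": True, "Is Verb": True},
}

stats = calc_model_group_stats(
    tokenized_corpus,
    logprob_datasets,
    token_groups,
    models=["llama2", "gpt2"],
    token_labels=["Is Noun", "Is Verb"],
)
print(stats[("llama2", "Is Noun")])
# llama2 "Is Noun" collects logprobs [-0.4, -1.2], giving
# mean -0.8, median -0.8, min -1.2, max -0.4, 25th -1.0, 75th -0.6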
