Add utility functions for text processing and visualization #17
Changes to the eval utils module:

```diff
@@ -1,15 +1,11 @@
 from collections.abc import Callable
-from typing import List, cast
+from typing import cast
 
 import torch
 from datasets import Dataset, load_dataset
 from jaxtyping import Float, Int
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
-ALLOWED_CHARS = set(
-    " \n\"'(),.:?!0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
-)
-
 
 def get_all_logprobs(
     model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
@@ -33,12 +29,39 @@ def gather_logprobs(
     return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)
 
 
-def get_next_logprobs(
-    model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
-) -> Float[torch.Tensor, "batch shorter_seq"]:
+def get_all_and_next_logprobs(
+    model: Callable,
+    input_ids: Int[torch.Tensor, "batch seq"],
+) -> tuple[
+    Float[torch.Tensor, "batch shorter_seq vocab"],
+    Float[torch.Tensor, "batch shorter_seq"],
+]:
     logprobs = get_all_logprobs(model, input_ids[:, :-1])
     next_tokens = input_ids[:, 1:]
-    return gather_logprobs(logprobs, next_tokens)
+    return logprobs, gather_logprobs(logprobs, next_tokens)
+
+
+def get_all_and_next_logprobs_single(
+    model: Callable,
+    input_ids: Int[torch.Tensor, "seq"],
+) -> tuple[
+    Float[torch.Tensor, "shorter_seq vocab"],
+    Float[torch.Tensor, "shorter_seq"],
+]:
+    all_logprobs, next_logprobs = get_all_and_next_logprobs(
+        model, input_ids.unsqueeze(0)
+    )
+    return all_logprobs[0], next_logprobs[0]
+
+
+def get_next_and_top_k_probs(
+    model: PreTrainedModel, input_ids: Int[torch.Tensor, "seq"], k: int = 3
+) -> tuple[Float[torch.Tensor, "shorter_seq"], torch.return_types.topk]:
+    all_logprobs, next_logprobs = get_all_and_next_logprobs_single(model, input_ids)
+    all_probs = torch.exp(all_logprobs)
+    next_probs = torch.exp(next_logprobs)
+    top_k = torch.topk(all_probs, k, dim=-1)
+    return next_probs, top_k
 
 
 def load_validation_dataset(dataset_name: str) -> Dataset:
```
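All the new helpers build on the same shift-by-one pattern: run the model on `input_ids[:, :-1]`, then score the predictions against `input_ids[:, 1:]`. A minimal standalone sketch of that pattern with random tensors (a toy stand-in, not code from this PR):

```python
import torch

def gather_logprobs(logprobs, tokens):
    # at each position, pick out the logprob assigned to the actual token
    return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)

batch, seq, vocab = 2, 5, 7
input_ids = torch.randint(vocab, (batch, seq))

# stand-in for get_all_logprobs(model, input_ids[:, :-1])
logits = torch.randn(batch, seq - 1, vocab)
logprobs = torch.log_softmax(logits, dim=-1)

next_tokens = input_ids[:, 1:]  # targets are the inputs shifted left by one
next_logprobs = gather_logprobs(logprobs, next_tokens)
assert next_logprobs.shape == (batch, seq - 1)  # "batch shorter_seq"
```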
Comment on lines +57 to +64 (`get_next_and_top_k_probs`):

> Are you using this function?

Reply:

> In the demo notebook
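Since the reply points to the demo notebook, here is a sketch of what such a call might look like. The import path and the checkpoint choice are assumptions inferred from the test file below, not code from this PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# assumed import path, inferred from the test file in this PR
from delphi.eval.utils import get_next_and_top_k_probs, tokenize

model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")

with torch.no_grad():
    sample_tok = tokenize(tokenizer, "Once upon a time there was a")
    next_probs, top_k = get_next_and_top_k_probs(model, sample_tok, k=3)

# probability the model assigned to each actual next token
print(next_probs)
# top-3 candidate tokens for the final position
print([tokenizer.decode(i) for i in top_k.indices[-1]])
```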
```diff
@@ -56,16 +79,6 @@ def load_validation_dataset(dataset_name: str) -> Dataset:
     return cast(Dataset, dataset)
 
 
-def load_text_from_dataset(dataset: Dataset) -> list[str]:
-    text = []
-    for sample_txt in dataset["story"]:
-        # encoding issues and rare weird prompts
-        if not set(sample_txt).issubset(ALLOWED_CHARS):
-            continue
-        text.append(sample_txt)
-    return text
-
-
 def tokenize(
     tokenizer: PreTrainedTokenizerBase, sample_txt: str
 ) -> Int[torch.Tensor, "seq"]:
@@ -74,22 +87,3 @@ def tokenize(
         Int[torch.Tensor, "seq"],
         tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
     )
-
-
-def get_correct_and_all_probs(
-    model: PreTrainedModel, sample_tok: Int[torch.Tensor, "seq"]
-) -> tuple[Float[torch.Tensor, "next_seq"], Float[torch.Tensor, "next_seq vocab"]]:
-    """Get probabilities for the actual next token and for all predictions"""
-    # remove the first token (the bos token)
-    probs = get_single_logprobs(model, sample_tok)[1:]
-    correct_probs = probs[range(len(probs)), sample_tok[1:]]
-    return correct_probs, probs
-
-
-def get_correct_and_top_probs(
-    model: PreTrainedModel, sample_tok: Int[torch.Tensor, "seq"], top_k: int = 3
-) -> tuple[Float[torch.Tensor, "next_seq"], torch.return_types.topk]:
-    """Get probabilities for the actual next token and for top k predictions"""
-    correct_probs, probs = get_correct_and_all_probs(model, sample_tok)
-    top_k_probs = torch.topk(probs, top_k, dim=-1)
-    return correct_probs, top_k_probs
```
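For readers tracking the API change: the probability-space helpers deleted here appear to be superseded by the new log-probability helpers added above (my reading of the diff, not something the PR states). Roughly, `get_correct_and_all_probs` maps to `get_all_and_next_logprobs_single`, which returns log-probabilities rather than probabilities, and `get_correct_and_top_probs` maps to `get_next_and_top_k_probs`, which exponentiates back to probabilities before taking the top k.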
New test file for `compare_models`:

```diff
@@ -0,0 +1,23 @@
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from delphi.eval.compare_models import NextTokenStats, compare_models
+from delphi.eval.utils import load_validation_dataset, tokenize
+
+
+def test_compare_models():
+    with torch.set_grad_enabled(False):
+        model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
+        model_instruct = AutoModelForCausalLM.from_pretrained(
+            "roneneldan/TinyStories-Instruct-1M"
+        )
+        ds_txt = load_validation_dataset("tinystories-v2-clean")["story"]
+        tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
+        sample_tok = tokenize(tokenizer, ds_txt[0])
+        K = 3
+        model_comparison = compare_models(model, model_instruct, sample_tok, top_k=K)
+        # ignore the first element comparison
+        assert model_comparison[0] is None
+        assert isinstance(model_comparison[1], NextTokenStats)
+        assert len(model_comparison) == sample_tok.shape[0]
+        assert len(model_comparison[1].topk) == K
```
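Note that this test downloads two TinyStories-1M checkpoints (and, presumably, the validation split) from the Hugging Face Hub, so the first run needs network access. Since the file's repo path isn't shown in this view, one way to run just this test is to select it by name:

```python
# select the test by name rather than by path
import pytest

pytest.main(["-k", "test_compare_models"])
```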
jaidhyani marked this conversation as resolved.
This file was deleted.
Comment:

> Is this probs or logprobs?

Reply:

> ! Good catch!
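A probs-vs-logprobs mix-up is easy to make: the deleted `get_correct_and_all_probs` above, for instance, called `get_single_logprobs` yet named the result `probs`. For reference, the relationship between the two (generic PyTorch, not code from this PR):

```python
import torch

logits = torch.randn(4, 10)                   # a row of logits per position
logprobs = torch.log_softmax(logits, dim=-1)  # what the "logprobs" helpers return
probs = torch.exp(logprobs)                   # what a "probs" helper should return

assert torch.allclose(probs, torch.softmax(logits, dim=-1))
assert torch.allclose(probs.sum(dim=-1), torch.ones(4))  # probabilities sum to 1
```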