Skip to content

Commit

Permalink
Verify that tokens are valid
Browse files: browse the repository at this point in the history
  • Loading branch information
hendrikvanantwerpen committed Oct 16, 2024
1 parent 8ecf192 commit b3feef9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
12 changes: 11 additions & 1 deletion crates/bpe/src/byte_pair_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ impl BytePairEncoding {
split_table.push((id as u32, id as u32));
}
}
Self {
let bpe = Self {
all_tokens,
token_starts,
bytes_hash_to_token,
Expand All @@ -314,7 +314,17 @@ impl BytePairEncoding {
pair_lookup,
split_table,
hash_factor,
};
for token_id in 0..bpe.num_tokens() as u32 {
let bytes = bpe.token_bytes(token_id);
let tokens = bpe.encode_via_bitfield(bytes);
assert_eq!(
tokens,
vec![token_id],
"token {token_id} with bytes {bytes:?} encodes to {tokens:?} instead of to itself"
);
}
bpe
}

/// Return the number of tokens in this BPE dictionary.
Expand Down
12 changes: 7 additions & 5 deletions crates/bpe/tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ mod tests {
/// This test produces the output for the encoding example in the README.
#[test]
fn readme_example() {
let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec());
let bpe = BytePairEncoding::from_dictionary(tokens, None);
let text = "abacb";
let tokens = ["a", "b", "c", "ab", "cb", "ac", "bb", "cbb", "acbb"];
let bpe = BytePairEncoding::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None);
let text = "abacbb";
let prefixes = (1..=text.len()).map(|end| &text[..end]).collect_vec();
let all_prefix_tokens = prefixes
.iter()
.map(|prefix| {
bpe.encode_via_backtracking(prefix.as_bytes())
.into_iter()
.map(|t| unsafe { String::from_utf8_unchecked(bpe.decode_tokens(&[t])) })
.map(|t| String::from_utf8(bpe.decode_tokens(&[t])).unwrap())
.collect_vec()
})
.collect_vec();
Expand All @@ -48,6 +48,8 @@ mod tests {
.map(|tokens| tokens.last().unwrap())
.collect_vec();

println!("Token set: `{}`\n", tokens.join(" "));

println!("All tokens for each prefix of `{text}`:\n");
for (prefix, tokens) in prefixes.iter().zip(&all_prefix_tokens) {
println!(
Expand All @@ -67,7 +69,7 @@ mod tests {
}
println!();

println!("Tokenization of `{text}`:\n");
println!("Encoding using last tokens of `{text}`:\n");
let mut remaining = text.len();
while remaining > 0 {
let prefix = &text[..remaining];
Expand Down

0 comments on commit b3feef9

Please sign in to comment.