Skip to content

Commit

Permalink
Merge pull request #31 from github/verify-tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikvanantwerpen authored Oct 17, 2024
2 parents 8ecf192 + efaf552 commit 5b127c9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 22 deletions.
33 changes: 17 additions & 16 deletions crates/bpe/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,31 +96,32 @@ Given a valid encoding sequence `e_0..e_i` and a valid encoding tuple `e_i e_j`,
## Novel Algorithm

At a first glance, it seems impossible to achieve `O(n)` complexity while preserving the encoding output of the original BPE algorithm, since the original BPE algorithm needs to first scan the full input before it can make any encoding decision.
For instance, the sequence `abac` would be encoded as `ab ac` when the dictionary contains the tokens `a b c ab cb ac` ordered by frequency. But appending a single character `abacb` would result in a pretty different tokenization: `ab a cb`. So without looking ahead it seems impossible to properly tokenize the text.
For instance, the sequence `abacb` would be encoded as `ab a cb` when the dictionary contains the tokens `a b c ab cb ac bb cbb acbb` ordered by frequency. But appending a single character `abacbb` would result in a pretty different tokenization: `ab acbb`. So without looking ahead it seems impossible to properly tokenize the text.

The solution is to track the encodings of ALL text prefixes. For our example `abacb` we would get:
The solution is to track the encodings of ALL text prefixes. For our example `abacbb` we would get:

- `a` ------> `a`
- `ab` -----> `ab`
- `aba` ----> `ab a`
- `abac` ---> `ab ac`
- `abacb` --> `ab a cb`
- `a` -------> `a`
- `ab` ------> `ab`
- `aba` -----> `ab a`
- `abac` ----> `ab ac`
- `abacb` ---> `ab a cb`
- `abacbb` --> `ab acbb`

This can be done much more efficiently thanks to Corollary IIa, since now only the last token of every prefix has to be remembered:

- `a` ------> `a`
- `ab` -----> `ab`
- `aba` ----> `a`
- `abac` ---> `ac`
- `abacb` --> `cb`
- `a` -------> `a`
- `ab` ------> `ab`
- `aba` -----> `a`
- `abac` ----> `ac`
- `abacb` ---> `cb`
- `abacbb` --> `acbb`

In order to reconstruct the full encoding for a specific prefix, one simply starts with the last token of that prefix, shortens the prefix by the extracted token and looks up the token associated with the shortened prefix and so on until the beginning of the text is reached.

For our example prefix `abacb`, this procedure executes the following steps and determines the correct encoding in reverse order:
For our example prefix `abacbb`, this procedure executes the following steps and determines the correct encoding in reverse order:

- `abacb` -> `cb`
- `aba` ---> `a`
- `ab` ----> `ab`
- `abacbb` --> `acbb`
- `ab` ------> `ab`
- `<empty>`

The actual challenge is to determine for every prefix this last token efficiently.
Expand Down
12 changes: 11 additions & 1 deletion crates/bpe/src/byte_pair_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ impl BytePairEncoding {
split_table.push((id as u32, id as u32));
}
}
Self {
let bpe = Self {
all_tokens,
token_starts,
bytes_hash_to_token,
Expand All @@ -314,7 +314,17 @@ impl BytePairEncoding {
pair_lookup,
split_table,
hash_factor,
};
for token_id in 0..bpe.num_tokens() as u32 {
let bytes = bpe.token_bytes(token_id);
let tokens = bpe.encode_via_bitfield(bytes);
assert_eq!(
tokens,
vec![token_id],
"token {token_id} with bytes {bytes:?} encodes to {tokens:?} instead of to itself"
);
}
bpe
}

/// Return the number of tokens in this BPE dictionary.
Expand Down
12 changes: 7 additions & 5 deletions crates/bpe/tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ mod tests {
/// This test produces the output for the encoding example in the README.
#[test]
fn readme_example() {
let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec());
let bpe = BytePairEncoding::from_dictionary(tokens, None);
let text = "abacb";
let tokens = ["a", "b", "c", "ab", "cb", "ac", "bb", "cbb", "acbb"];
let bpe = BytePairEncoding::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None);
let text = "abacbb";
let prefixes = (1..=text.len()).map(|end| &text[..end]).collect_vec();
let all_prefix_tokens = prefixes
.iter()
.map(|prefix| {
bpe.encode_via_backtracking(prefix.as_bytes())
.into_iter()
.map(|t| unsafe { String::from_utf8_unchecked(bpe.decode_tokens(&[t])) })
.map(|t| String::from_utf8(bpe.decode_tokens(&[t])).unwrap())
.collect_vec()
})
.collect_vec();
Expand All @@ -48,6 +48,8 @@ mod tests {
.map(|tokens| tokens.last().unwrap())
.collect_vec();

println!("Token set: `{}`\n", tokens.join(" "));

println!("All tokens for each prefix of `{text}`:\n");
for (prefix, tokens) in prefixes.iter().zip(&all_prefix_tokens) {
println!(
Expand All @@ -67,7 +69,7 @@ mod tests {
}
println!();

println!("Tokenization of `{text}`:\n");
println!("Encoding using last tokens of `{text}`:\n");
let mut remaining = text.len();
while remaining > 0 {
let prefix = &text[..remaining];
Expand Down

0 comments on commit 5b127c9

Please sign in to comment.