diff --git a/crates/bpe/README.md b/crates/bpe/README.md index 404e389..d083fd4 100644 --- a/crates/bpe/README.md +++ b/crates/bpe/README.md @@ -96,31 +96,32 @@ Given a valid encoding sequence `e_0..e_i` and a valid encoding tuple `e_i e_j`, ## Novel Algorithm At a first glance, it seems impossible to achieve `O(n)` complexity while preserving the encoding output of the original BPE algorithm, since the original BPE algorithm needs to first scan the full input before it can make any encoding decision. -For instance, the sequence `abac` would be encoded as `ab ac` when the dictionary contains the tokens `a b c ab cb ac` ordered by frequency. But appending a single character `abacb` would result in a pretty different tokenization: `ab a cb`. So without looking ahead it seems impossible to properly tokenize the text. +For instance, the sequence `abacb` would be encoded as `ab a cb` when the dictionary contains the tokens `a b c ab cb ac bb cbb acbb` ordered by frequency. But appending a single character `abacbb` would result in a pretty different tokenization: `ab acbb`. So without looking ahead it seems impossible to properly tokenize the text. -The solution is to track the encodings of ALL text prefixes. For our example `abacb` we would get: +The solution is to track the encodings of ALL text prefixes. For our example `abacbb` we would get: -- `a` ------> `a` -- `ab` -----> `ab` -- `aba` ----> `ab a` -- `abac` ---> `ab ac` -- `abacb` --> `ab a cb` +- `a` -------> `a` +- `ab` ------> `ab` +- `aba` -----> `ab a` +- `abac` ----> `ab ac` +- `abacb` ---> `ab a cb` +- `abacbb` --> `ab acbb` This can be done much more efficiently thanks to Corollary IIa, since now only the last token of every prefix has to be remembered: -- `a` ------> `a` -- `ab` -----> `ab` -- `aba` ----> `a` -- `abac` ---> `ac` -- `abacb` --> `cb` +- `a` -------> `a` +- `ab` ------> `ab` +- `aba` -----> `a` +- `abac` ----> `ac` +- `abacb` ---> `cb` +- `abacbb` --> `acbb` In order to reconstruct the full encoding for a specific prefix, one simply starts with the last token of that prefix, shortens the prefix by the extracted token and looks up the token associated with the shortened prefix and so on until the beginning of the text is reached. -For our example prefix `abacb`, this procedure executes the following steps and determines the correct encoding in reverse order: +For our example prefix `abacbb`, this procedure executes the following steps and determines the correct encoding in reverse order: -- `abacb` -> `cb` -- `aba` ---> `a` -- `ab` ----> `ab` +- `abacbb` --> `acbb` +- `ab` ------> `ab` - `` The actual challenge is to determine for every prefix this last token efficiently. diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs index 9efbb0e..eede9fd 100644 --- a/crates/bpe/src/byte_pair_encoding.rs +++ b/crates/bpe/src/byte_pair_encoding.rs @@ -303,7 +303,7 @@ impl BytePairEncoding { split_table.push((id as u32, id as u32)); } } - Self { + let bpe = Self { all_tokens, token_starts, bytes_hash_to_token, @@ -314,7 +314,17 @@ impl BytePairEncoding { pair_lookup, split_table, hash_factor, + }; + for token_id in 0..bpe.num_tokens() as u32 { + let bytes = bpe.token_bytes(token_id); + let tokens = bpe.encode_via_bitfield(bytes); + assert_eq!( + tokens, + vec![token_id], + "token {token_id} with bytes {bytes:?} encodes to {tokens:?} instead of to itself" + ); } + bpe } /// Return the number of tokens in this BPE dictionary. diff --git a/crates/bpe/tests/src/lib.rs b/crates/bpe/tests/src/lib.rs index ed2ab81..9c02773 100644 --- a/crates/bpe/tests/src/lib.rs +++ b/crates/bpe/tests/src/lib.rs @@ -30,16 +30,16 @@ mod tests { /// This test produces the output for the encoding example in the README. #[test] fn readme_example() { - let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec()); - let bpe = BytePairEncoding::from_dictionary(tokens, None); - let text = "abacb"; + let tokens = ["a", "b", "c", "ab", "cb", "ac", "bb", "cbb", "acbb"]; + let bpe = BytePairEncoding::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None); + let text = "abacbb"; let prefixes = (1..=text.len()).map(|end| &text[..end]).collect_vec(); let all_prefix_tokens = prefixes .iter() .map(|prefix| { bpe.encode_via_backtracking(prefix.as_bytes()) .into_iter() - .map(|t| unsafe { String::from_utf8_unchecked(bpe.decode_tokens(&[t])) }) + .map(|t| String::from_utf8(bpe.decode_tokens(&[t])).unwrap()) .collect_vec() }) .collect_vec(); @@ -48,6 +48,8 @@ mod tests { .map(|tokens| tokens.last().unwrap()) .collect_vec(); + println!("Token set: `{}`\n", tokens.join(" ")); + println!("All tokens for each prefix of `{text}`:\n"); for (prefix, tokens) in prefixes.iter().zip(&all_prefix_tokens) { println!( @@ -67,7 +69,7 @@ mod tests { } println!(); - println!("Tokenization of `{text}`:\n"); + println!("Encoding using last tokens of `{text}`:\n"); let mut remaining = text.len(); while remaining > 0 { let prefix = &text[..remaining];