diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs index 9efbb0e..eede9fd 100644 --- a/crates/bpe/src/byte_pair_encoding.rs +++ b/crates/bpe/src/byte_pair_encoding.rs @@ -303,7 +303,7 @@ impl BytePairEncoding { split_table.push((id as u32, id as u32)); } } - Self { + let bpe = Self { all_tokens, token_starts, bytes_hash_to_token, @@ -314,7 +314,17 @@ impl BytePairEncoding { pair_lookup, split_table, hash_factor, + }; + for token_id in 0..bpe.num_tokens() as u32 { + let bytes = bpe.token_bytes(token_id); + let tokens = bpe.encode_via_bitfield(bytes); + assert_eq!( + tokens, + vec![token_id], + "token {token_id} with bytes {bytes:?} encodes to {tokens:?} instead of to itself" + ); } + bpe } /// Return the number of tokens in this BPE dictionary. diff --git a/crates/bpe/tests/src/lib.rs b/crates/bpe/tests/src/lib.rs index ed2ab81..9c02773 100644 --- a/crates/bpe/tests/src/lib.rs +++ b/crates/bpe/tests/src/lib.rs @@ -30,16 +30,16 @@ mod tests { /// This test produces the output for the encoding example in the README. #[test] fn readme_example() { - let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec()); - let bpe = BytePairEncoding::from_dictionary(tokens, None); - let text = "abacb"; + let tokens = ["a", "b", "c", "ab", "cb", "ac", "bb", "cbb", "acbb"]; + let bpe = BytePairEncoding::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None); + let text = "abacbb"; let prefixes = (1..=text.len()).map(|end| &text[..end]).collect_vec(); let all_prefix_tokens = prefixes .iter() .map(|prefix| { bpe.encode_via_backtracking(prefix.as_bytes()) .into_iter() - .map(|t| unsafe { String::from_utf8_unchecked(bpe.decode_tokens(&[t])) }) + .map(|t| String::from_utf8(bpe.decode_tokens(&[t])).unwrap()) .collect_vec() }) .collect_vec(); @@ -48,6 +48,8 @@ mod tests { .map(|tokens| tokens.last().unwrap()) .collect_vec(); + println!("Token set: `{}`\n", tokens.join(" ")); + println!("All tokens for each prefix of `{text}`:\n"); for (prefix, tokens) in prefixes.iter().zip(&all_prefix_tokens) { println!( @@ -67,7 +69,7 @@ mod tests { } println!(); - println!("Tokenization of `{text}`:\n"); + println!("Encoding using last tokens of `{text}`:\n"); let mut remaining = text.len(); while remaining > 0 { let prefix = &text[..remaining];