diff --git a/crates/bpe/README.md b/crates/bpe/README.md index 404e389..b0fc43c 100644 --- a/crates/bpe/README.md +++ b/crates/bpe/README.md @@ -290,9 +290,8 @@ This suggests that pre-tokenization is not necessary from a performance perspect ![encoding runtime comparison](./images/performance-comparison.svg) -The graph below shows encoding results for input that is particularly challenging for tiktoken. -The input consists of random ranges taken from the continuous list of all Unicode code points excluding whitespace. -The performance of tiktoken shows a quadratic growth with the input size. +The graph below shows encoding results when the input cannot be split in pre-tokenization and allows a better comparison of pure BPE performance. +This case is particularly challenging for tiktoken, which shows a quadratic growth with the input size. The Huggingface encoder scales better, but becomes slower and slower compared to our implementation as input size increases. ![worst-case encoding runtime comparison](./images/performance-worstcase.svg) diff --git a/crates/bpe/benchmarks/equivalence.rs b/crates/bpe/benchmarks/equivalence.rs index 7c71e4e..3d59d3b 100644 --- a/crates/bpe/benchmarks/equivalence.rs +++ b/crates/bpe/benchmarks/equivalence.rs @@ -7,7 +7,7 @@ const N: usize = 32; fn test_encoding_equivalence_without_pretokenization() { for (_, bpe, _, huggingface) in TOKENIZERS.iter() { let huggingface = without_pretokenizer(huggingface); - let text = create_test_string(&bpe.bpe, 20000); + let text = create_test_string(bpe, 20000, true); let inputs = (0..N) .map(|_| select_test_bytes(text.as_bytes(), 100)) .chain(std::iter::once( @@ -43,7 +43,7 @@ fn test_encoding_equivalence_without_pretokenization() { #[test] fn test_encoding_equivalence_with_pretokenization() { for (_, bpe, tiktoken, huggingface) in TOKENIZERS.iter() { - let text = create_test_string(&bpe.bpe, 20000); + let text = create_test_string(bpe, 20000, true); let inputs = (0..N) .map(|_| 
select_test_bytes(text.as_bytes(), 100)) .chain(std::iter::once( diff --git a/crates/bpe/benchmarks/lib.rs b/crates/bpe/benchmarks/lib.rs index f260ebd..f98aab5 100644 --- a/crates/bpe/benchmarks/lib.rs +++ b/crates/bpe/benchmarks/lib.rs @@ -1,6 +1,5 @@ use std::sync::LazyLock; -use bpe::byte_pair_encoding::BytePairEncoding; use bpe_openai::Tokenizer; use rand::{thread_rng, Rng}; use tiktoken_rs::CoreBPE as TiktokenTokenizer; @@ -41,19 +40,38 @@ pub fn is_char_boundary(b: u8) -> bool { b as i8 >= -0x40 // NB: b < 128 || b >= 192 } -pub fn create_test_string(bpe: &BytePairEncoding, tokens: usize) -> String { +/// Create a test string from the given number of random tokens. Note that re-tokenizing the string +/// may result in a different token count! It is possible to request a string that cannot be split +/// with the tokenizer's regex. Be aware that generating the string is slow in that case. +pub fn create_test_string(tok: &Tokenizer, tokens: usize, allow_splits: bool) -> String { use rand::{thread_rng, Rng}; let mut text = String::new(); - for _ in 0..tokens { - loop { - let i = thread_rng().gen_range(0..bpe.num_tokens()); - let s = bpe.token_bytes(i as u32); - if s.iter().all(|b| is_char_boundary(*b)) { - if let Ok(s) = std::str::from_utf8(s) { - text.push_str(s); - break; + let mut text_len = Vec::new(); + 'next_token: while text_len.len() < tokens { + // try a few times to find a token + for _ in 0..8 { + // ensure the token results in a valid string + loop { + let i = thread_rng().gen_range(0..tok.bpe.num_tokens()); + let s = tok.bpe.token_bytes(i as u32); + if s.iter().all(|b| is_char_boundary(*b)) { + if let Ok(s) = std::str::from_utf8(s) { + text_len.push(text.len()); + text.push_str(s); + break; + } } } + // if splits are allowed, or there are no splits, add the next token; otherwise drop the token and retry + if allow_splits || tok.split(&text).nth(1).is_none() { + continue 'next_token; + } else { + text.truncate(text_len.pop().expect("we just pushed 
a token")); + } + } + // we failed to find a token that doesn't result in a split, we backtrack to try different combinations + if let Some(len) = text_len.pop() { + text.truncate(len) } } text diff --git a/crates/bpe/benchmarks/performance.rs b/crates/bpe/benchmarks/performance.rs index 4ec973e..c3225ff 100644 --- a/crates/bpe/benchmarks/performance.rs +++ b/crates/bpe/benchmarks/performance.rs @@ -45,7 +45,7 @@ fn encoding_benchmark(c: &mut Criterion) { for (name, bpe, _, huggingface) in TOKENIZERS.iter() { let huggingface = without_pretokenizer(huggingface); - let text = create_test_string(&bpe.bpe, 20000); + let text = create_test_string(bpe, 20000, true); let input = text.as_bytes(); let mut group = c.benchmark_group(format!("encoding-{name}")); @@ -145,7 +145,7 @@ fn appending_benchmark(c: &mut Criterion) { fn comparison_benchmark(c: &mut Criterion) { for (name, bpe, tiktoken, huggingface) in TOKENIZERS.iter() { - let text = create_test_string(&bpe.bpe, 20000); + let text = create_test_string(bpe, 20000, true); let input = text.as_bytes(); let mut group = c.benchmark_group(format!("comparison-{name}")); @@ -188,18 +188,23 @@ fn comparison_benchmark(c: &mut Criterion) { fn worstcase_comparison_benchmark(c: &mut Criterion) { for (name, bpe, tiktoken, huggingface) in TOKENIZERS.iter() { - let text: String = ('\0'..char::MAX).filter(|c| !c.is_whitespace()).collect(); + let text = create_test_string(bpe, 20000, false); let input = text.as_bytes(); let mut group = c.benchmark_group(format!("worstcase-{name}")); - for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000, 75000, 100000] { + for bytes in [10, 100, 1000] { //, 5000, 10000, 25000, 50000, 75000, 100000] { group.throughput(criterion::Throughput::Bytes(bytes as u64)); group.bench_with_input( BenchmarkId::new("backtracking", bytes), &bytes, |b, bytes| { b.iter_batched( - || std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(), + || { + let text = + std::str::from_utf8(select_test_bytes(input, 
*bytes)).unwrap(); + assert!(bpe.split(text).nth(1).is_none()); + text + }, |text| bpe.encode(text), criterion::BatchSize::SmallInput, ) @@ -207,7 +212,11 @@ fn worstcase_comparison_benchmark(c: &mut Criterion) { ); group.bench_with_input(BenchmarkId::new("tiktoken", bytes), &bytes, |b, bytes| { b.iter_batched( - || std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(), + || { + let text = std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(); + assert!(bpe.split(text).nth(1).is_none()); + text + }, |text| tiktoken.encode_ordinary(text), criterion::BatchSize::SmallInput, ) @@ -217,7 +226,12 @@ fn worstcase_comparison_benchmark(c: &mut Criterion) { &bytes, |b, bytes| { b.iter_batched( - || std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(), + || { + let text = + std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(); + assert!(bpe.split(text).nth(1).is_none()); + text + }, |text| huggingface.encode_fast(text, false).unwrap(), criterion::BatchSize::SmallInput, ) diff --git a/crates/bpe/images/performance-worstcase.svg b/crates/bpe/images/performance-worstcase.svg index 03f6d3f..c85387f 100644 --- a/crates/bpe/images/performance-worstcase.svg +++ b/crates/bpe/images/performance-worstcase.svg @@ -4,24 +4,30 @@ - - - - - - + + + + + + + + + - - - - - - - + + + + + + + + + + - + @@ -46,40 +52,40 @@ - - - - - - - - - - + + + + + + + + + + - - - - - - - - - - + + + + + + + + + + - - - - - - - - - - + + + + + + + + + + - +