diff --git a/crates/bpe-openai/Cargo.toml b/crates/bpe-openai/Cargo.toml
index b4379c3..ea9ea21 100644
--- a/crates/bpe-openai/Cargo.toml
+++ b/crates/bpe-openai/Cargo.toml
@@ -17,6 +17,7 @@ bpe = { version = "0.1.0", path = "../bpe" }
 either = "1.13"
 fancy-regex = "0.13"
 rmp-serde = "1"
+regex = "1"
 
 [dev-dependencies]
 tiktoken-rs = "0.6"
diff --git a/crates/bpe-openai/src/lib.rs b/crates/bpe-openai/src/lib.rs
index fd2c7c8..92b6c64 100644
--- a/crates/bpe-openai/src/lib.rs
+++ b/crates/bpe-openai/src/lib.rs
@@ -1,28 +1,58 @@
+use std::ops::Range;
 use std::sync::LazyLock;
 
 use bpe::byte_pair_encoding::BytePairEncoding;
 use either::Either;
 use fancy_regex::Regex;
 
+pub use bpe::*;
+
 static BPE_R50K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
     let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_r50k_base.dict"));
     let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
-    let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
-    Tokenizer::new(bpe, Some(pat)).expect("valid regex")
+    let pat = [
+        "(?:'s|'t|'re|'ve|'m|'ll|'d)",
+        " ?\\p{L}+",
+        " ?\\p{N}+",
+        " ?[^\\s\\p{L}\\p{N}]+",
+        "\\s+", // "(?:\\s+(?!\\S)|\\s+)",
+    ]
+    .join("|");
+    let pre =
+        Pretokenizer::from_pat_and_trim(&pat, openai_trim_one_whitespace).expect("valid regex");
+    Tokenizer::new(bpe, Some(pre))
 });
 
 static BPE_P50K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
     let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_p50k_base.dict"));
     let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
-    let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
-    Tokenizer::new(bpe, Some(pat)).expect("valid regex")
+    let pat = [
+        "(?:'s|'t|'re|'ve|'m|'ll|'d)",
+        " ?\\p{L}+",
+        " ?\\p{N}+",
+        " ?[^\\s\\p{L}\\p{N}]+",
+        "\\s+", // "(?:\\s+(?!\\S)|\\s+)",
+    ]
+    .join("|");
+    let pre =
+        Pretokenizer::from_pat_and_trim(&pat, openai_trim_one_whitespace).expect("valid regex");
+    Tokenizer::new(bpe, Some(pre))
 });
 
 static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
     let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
     let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
-    let pat = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
-    Tokenizer::new(bpe, Some(pat)).expect("valid regex")
+    let pat = [
+        "(?i:'s|'t|'re|'ve|'m|'ll|'d)",
+        "[^\\r\\n\\p{L}\\p{N}]?\\p{L}+",
+        "\\p{N}{1,3}",
+        " ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*",
+        "\\s+", // "(?:\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+)",
+    ]
+    .join("|");
+    let pre = Pretokenizer::from_pat_and_trim(&pat, openai_trim_one_nonnewline_whitespace)
+        .expect("valid regex");
+    Tokenizer::new(bpe, Some(pre))
 });
 
 static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
@@ -33,15 +63,13 @@ static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
         "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
         "\\p{N}{1,3}",
         " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
-        "\\s*[\\r\\n]+",
-        "\\s+(?!\\S)",
-        "\\s+",
+        "\\s+", // "(?:\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+)",
     ].join("|");
-    Tokenizer::new(bpe, Some(&pat)).expect("valid regex")
+    let pre = Pretokenizer::from_pat_and_trim(&pat, openai_trim_one_nonnewline_whitespace)
+        .expect("valid regex");
+    Tokenizer::new(bpe, Some(pre))
 });
 
-pub use bpe::*;
-
 /// A byte-pair encoding tokenizer that supports a pre-tokenization regex.
 /// The direct methods on this type pre-tokenize the input text and should
 /// produce the same output as the tiktoken tokenizers. The type gives access
@@ -51,15 +79,23 @@ pub use bpe::*;
 pub struct Tokenizer {
     /// The byte-pair encoding for this tokenizer.
     pub bpe: BytePairEncoding,
-    /// The pattern regex used to split the input.
-    pub pat: Option<Regex>,
+    /// The pretokenizer used to split the input.
+    pub pre: Option<Pretokenizer>,
+}
+
+/// A trim function that, for the given haystack and match range, returns the number of bytes that
+/// should be discarded from the end of the match.
+pub type Trim = fn(&str, Range<usize>) -> usize;
+
+pub struct Pretokenizer {
+    pat: Regex,
+    trim: Option<Trim>,
 }
 
 impl Tokenizer {
     #[allow(clippy::result_large_err)]
-    pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> fancy_regex::Result<Self> {
-        let pat = pat.map(fancy_regex::Regex::new).transpose()?;
-        Ok(Self { bpe, pat })
+    pub fn new(bpe: BytePairEncoding, pre: Option<Pretokenizer>) -> Self {
+        Self { bpe, pre }
     }
 
     pub fn count(&self, text: &str) -> usize {
@@ -79,18 +115,49 @@ impl Tokenizer {
     }
 
     pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
-        match &self.pat {
-            Some(pat) => Either::Left(pat.find_iter(text).scan(0, |start, m| {
-                let m = m.expect("match succeeded");
-                assert_eq!(*start, m.start(), "pattern should match all input text");
-                *start = m.end();
-                Some(m.as_str())
-            })),
+        match &self.pre {
+            Some(pre) => Either::Left(pre.split(text)),
             None => Either::Right(std::iter::once(text)),
         }
     }
 }
 
+impl Pretokenizer {
+    #[allow(clippy::result_large_err)]
+    pub fn from_pat(pat: &str) -> fancy_regex::Result<Self> {
+        Ok(Self {
+            pat: Regex::new(pat)?,
+            trim: None,
+        })
+    }
+
+    #[allow(clippy::result_large_err)]
+    pub fn from_pat_and_trim(pat: &str, trim: Trim) -> fancy_regex::Result<Self> {
+        Ok(Self {
+            pat: Regex::new(pat)?,
+            trim: Some(trim),
+        })
+    }
+
+    pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
+        let mut start = 0;
+        std::iter::from_fn(move || {
+            self.pat
+                .find_from_pos(text, start)
+                .expect("can search from position")
+                .map(|m| {
+                    let mut range = m.range();
+                    if let Some(trim) = self.trim {
+                        range.end -= trim(text, range.clone());
+                    }
+                    assert!(range.end > start);
+                    start = range.end;
+                    &text[range]
+                })
+        })
+    }
+}
+
 pub fn r50k_base() -> &'static Tokenizer {
     &BPE_R50K_BASE
 }
@@ -107,6 +174,36 @@ pub fn o200k_base() -> &'static Tokenizer {
     &BPE_O200K_BASE
 }
 
+/// Allows using `\\s+` instead of `(?:\\s+(?!\\S)|\\s+)`.
+/// Assumes no other patterns match whitespace at the end.
+fn openai_trim_one_whitespace(text: &str, range: Range<usize>) -> usize {
+    if range.end == text.len() {
+        return 0;
+    }
+    let mut chars = text[range].chars();
+    match chars.next_back() {
+        Some(c) if c.is_whitespace() && chars.next_back().is_some() => c.len_utf8(),
+        _ => 0,
+    }
+}
+
+/// Allows using `\\s+` instead of `(?:\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+)`.
+/// Assumes no other patterns match non-[\r\n] whitespace at the end.
+fn openai_trim_one_nonnewline_whitespace(text: &str, range: Range<usize>) -> usize {
+    if range.end == text.len() {
+        return 0;
+    }
+    let mut chars = text[range].chars();
+    match chars.next_back() {
+        Some(c)
+            if c.is_whitespace() && !matches!(c, '\r' | '\n') && chars.next_back().is_some() =>
+        {
+            c.len_utf8()
+        }
+        _ => 0,
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use tiktoken_rs::cl100k_base_singleton;
diff --git a/crates/bpe/benchmarks/equivalence.rs b/crates/bpe/benchmarks/equivalence.rs
index 7c71e4e..03f71f2 100644
--- a/crates/bpe/benchmarks/equivalence.rs
+++ b/crates/bpe/benchmarks/equivalence.rs
@@ -1,7 +1,8 @@
 use bpe_benchmarks::*;
+use bpe_openai::{cl100k_base, Pretokenizer};
 
 #[cfg(test)]
-const N: usize = 32;
+const N: usize = 128;
 
 #[test]
 fn test_encoding_equivalence_without_pretokenization() {
@@ -42,7 +43,7 @@
 
 #[test]
 fn test_encoding_equivalence_with_pretokenization() {
-    for (_, bpe, tiktoken, huggingface) in TOKENIZERS.iter() {
+    for (name, bpe, tiktoken, huggingface) in [&TOKENIZERS[0]] {
         let text = create_test_string(&bpe.bpe, 20000);
         let inputs = (0..N)
            .map(|_| select_test_bytes(text.as_bytes(), 100))
@@ -64,12 +65,12 @@ fn test_encoding_equivalence_with_pretokenization() {
                 let huggingface_text = huggingface.decode(&huggingface_out, true).unwrap();
                 if tiktoken_text != huggingface_text {
                     panic!(
-                        "huggingface tokens and text differ: {:?} != {:?}",
+                        "{name}: huggingface tokens and text differ: {:?} != {:?}",
                         huggingface_text, tiktoken_text
                     );
                 } else {
                     panic!(
-                        "huggingface tokens differ: {:?} != {:?}",
+                        "{name}: huggingface tokens differ: {:?} != {:?}",
                         huggingface_out, tiktoken_out2
                     );
                 }
@@ -78,13 +79,44 @@ fn test_encoding_equivalence_with_pretokenization() {
                 let text = bpe.decode(&out).unwrap();
                 if tiktoken_text != text {
                     panic!(
-                        "bpe tokens and text differ: {:?} != {:?}",
+                        "{name}: bpe tokens and text differ: {:?} != {:?}",
                         text, tiktoken_text
                     );
                 } else {
-                    panic!("bpe tokens differ: {:?} != {:?}", out, tiktoken_out2);
+                    panic!(
+                        "{name}: bpe tokens differ: {:?} != {:?}",
+                        out, tiktoken_out2
+                    );
                 }
             }
         }
     }
 }
+
+#[test]
+fn test_pretokenization() {
+    let fast = cl100k_base().pre.as_ref().unwrap();
+
+    let slow_pat = [
+        "(?i:'s|'t|'re|'ve|'m|'ll|'d)",
+        "[^\\r\\n\\p{L}\\p{N}]?\\p{L}+",
+        "\\p{N}{1,3}",
+        " ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*",
+        "(?:\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+)",
+    ]
+    .join("|");
+    let slow = Pretokenizer::from_pat(&slow_pat).unwrap();
+
+    let text = create_test_string(&cl100k_base().bpe, 20000);
+    let inputs = (0..N)
+        .map(|_| select_test_bytes(text.as_bytes(), 100))
+        .chain(std::iter::once(
+            "You should see the Greek word 'kosme': \"κόσμε\"".as_bytes(),
+        ));
+    for input in inputs {
+        let text = std::str::from_utf8(input).unwrap();
+        let slow_out: Vec<_> = slow.split(text).collect();
+        let fast_out: Vec<_> = fast.split(text).collect();
+        assert_eq!(slow_out, fast_out);
+    }
+}
diff --git a/crates/bpe/benchmarks/performance.rs b/crates/bpe/benchmarks/performance.rs
index 4ec973e..6df3c8a 100644
--- a/crates/bpe/benchmarks/performance.rs
+++ b/crates/bpe/benchmarks/performance.rs
@@ -3,6 +3,7 @@ use std::time::Duration;
 use bpe::appendable_encoder::AppendableEncoder;
 use bpe::interval_encoding::IntervalEncoding;
 use bpe_benchmarks::*;
+use bpe_openai::{cl100k_base, Pretokenizer};
 use bpe_tests::create_test_bytes;
 use criterion::{
     criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration,
@@ -228,12 +229,49 @@ fn worstcase_comparison_benchmark(c: &mut Criterion) {
     }
 }
+fn pretok_benchmark(c: &mut Criterion) {
+    let fast = cl100k_base().pre.as_ref().unwrap();
+
+    let slow_pat = [
+        "(?i:'s|'t|'re|'ve|'m|'ll|'d)",
+        "[^\\r\\n\\p{L}\\p{N}]?\\p{L}+",
+        "\\p{N}{1,3}",
+        " ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*",
+        "(?:\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+)",
+    ]
+    .join("|");
+    let slow = Pretokenizer::from_pat(&slow_pat).unwrap();
+
+    let text = create_test_string(&cl100k_base().bpe, 20000);
+    let input = text.as_bytes();
+
+    let mut group = c.benchmark_group(format!("pretok-cl100k"));
+    for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000, 75000, 100000] {
+        group.throughput(criterion::Throughput::Bytes(bytes as u64));
+        group.bench_with_input(BenchmarkId::new("fast", bytes), &bytes, |b, bytes| {
+            b.iter_batched(
+                || std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(),
+                |text| fast.split(text).count(),
+                criterion::BatchSize::SmallInput,
+            )
+        });
+        group.bench_with_input(BenchmarkId::new("slow", bytes), &bytes, |b, bytes| {
+            b.iter_batched(
+                || std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(),
+                |text| slow.split(text).count(),
+                criterion::BatchSize::SmallInput,
+            )
+        });
+    }
+    group.finish();
+}
+
 criterion_group!(
     name = benches;
     config = Criterion::default()
         .warm_up_time(Duration::from_millis(500))
         .measurement_time(Duration::from_millis(4000))
         .nresamples(1000);
-    targets = counting_benchmark, encoding_benchmark, appending_benchmark, comparison_benchmark, worstcase_comparison_benchmark
+    targets = counting_benchmark, encoding_benchmark, appending_benchmark, comparison_benchmark, worstcase_comparison_benchmark, pretok_benchmark
 );
 criterion_main!(benches);
diff --git a/crates/bpe/src/byte_pair_encoding.rs b/crates/bpe/src/byte_pair_encoding.rs
index eede9fd..a28db64 100644
--- a/crates/bpe/src/byte_pair_encoding.rs
+++ b/crates/bpe/src/byte_pair_encoding.rs
@@ -15,7 +15,7 @@ use crate::bitfield::BitField;
 /// Representation of the byte pair dictionary.
 /// This struct provides various conversions.
 /// We put all of them into a single struct so that they can be reused by different implementations.
-#[derive(Serialize, Deserialize)]
+#[derive(Clone, Serialize, Deserialize)]
 pub struct BytePairEncoding {
     /// All the decoded tokens concatenated into
     all_tokens: Vec<u8>,
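
Usage sketch (not part of the patch): a minimal example of how the reworked API is meant to be consumed, assuming the `bpe-openai` crate is added as a dependency and the API lands as in the diff above. The built-in tokenizers now carry a `Pretokenizer` that pairs a plain `\s+` alternative with a trim callback instead of the backtracking look-ahead `\s+(?!\S)`; `Pretokenizer::from_pat` still accepts the full look-ahead pattern, which the new test and benchmark use as the slow reference.

    use bpe_openai::{cl100k_base, Pretokenizer};

    fn main() {
        let tok = cl100k_base();

        // Split with the trim-based pretokenizer: the run of spaces before "world"
        // keeps its last space attached to the following word, as with tiktoken's
        // look-ahead pattern.
        let pieces: Vec<&str> = tok.split("Hello   world!\n").collect();
        println!("{pieces:?}");
        println!("token count: {}", tok.count("Hello   world!\n"));

        // The original look-ahead pattern can still be used as a slow reference
        // pretokenizer, mirroring test_pretokenization and pretok_benchmark above.
        let slow = Pretokenizer::from_pat(
            "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
        )
        .expect("valid regex");
        assert_eq!(
            slow.split("Hello   world!\n").collect::<Vec<_>>(),
            tok.split("Hello   world!\n").collect::<Vec<_>>(),
        );
    }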