Introduce Pretokenizer to allow eliminating look-ahead in the regex
hendrikvanantwerpen committed Oct 17, 2024
1 parent 5b127c9 commit 3f40ce9
Showing 5 changed files with 200 additions and 32 deletions.
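
The key idea of the change: instead of the look-ahead alternatives `\s+(?!\S)|\s+` used in the OpenAI patterns, the pretokenizer matches a plain `\s+` and then trims one trailing whitespace character whenever the match stops before the end of the input, so the next alternative (which begins with an optional space) still picks that character up. Below is a minimal sketch of that idea using the plain `regex` crate; the simplified pattern, the helper names, and the example string are illustrative and not part of this commit.

use regex::Regex;
use std::ops::Range;

// Split `text` with a look-ahead-free pattern and a trim rule that hands one
// trailing whitespace character back to the following match.
fn split_with_trim<'a>(
    pat: &'a Regex,
    trim: fn(&str, Range<usize>) -> usize,
    text: &'a str,
) -> impl Iterator<Item = &'a str> + 'a {
    let mut start = 0;
    std::iter::from_fn(move || {
        pat.find_at(text, start).map(|m| {
            let mut range = m.range();
            range.end -= trim(text, range.clone());
            start = range.end;
            &text[range]
        })
    })
}

fn main() {
    // Simplified r50k-style pattern with the `\s+(?!\S)` alternative dropped.
    let pat =
        Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+").unwrap();
    // Same rule as `openai_trim_one_whitespace` in the diff below: give back one
    // trailing whitespace character unless the match ends the input or consists
    // of a single character.
    let trim = |text: &str, range: Range<usize>| -> usize {
        if range.end == text.len() {
            return 0;
        }
        let mut chars = text[range].chars();
        match chars.next_back() {
            Some(c) if c.is_whitespace() && chars.next_back().is_some() => c.len_utf8(),
            _ => 0,
        }
    };
    let pieces: Vec<_> = split_with_trim(&pat, trim, "hello   world").collect();
    assert_eq!(pieces, vec!["hello", "  ", " world"]);
}
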
1 change: 1 addition & 0 deletions crates/bpe-openai/Cargo.toml
@@ -17,6 +17,7 @@ bpe = { version = "0.1.0", path = "../bpe" }
either = "1.13"
fancy-regex = "0.13"
rmp-serde = "1"
regex = "1"

[dev-dependencies]
tiktoken-rs = "0.6"
145 changes: 121 additions & 24 deletions crates/bpe-openai/src/lib.rs
@@ -1,28 +1,58 @@
use std::ops::Range;
use std::sync::LazyLock;

use bpe::byte_pair_encoding::BytePairEncoding;
use either::Either;
use fancy_regex::Regex;

pub use bpe::*;

static BPE_R50K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_r50k_base.dict"));
let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
Tokenizer::new(bpe, Some(pat)).expect("valid regex")
let pat = [
"(?:'s|'t|'re|'ve|'m|'ll|'d)",
" ?\\p{L}+",
" ?\\p{N}+",
" ?[^\\s\\p{L}\\p{N}]+",
"\\s+", // "(:?\\s+(?!\\S)|\\s+)",
]
.join("|");
let pre =
Pretokenizer::from_pat_and_trim(&pat, openai_trim_one_whitespace).expect("valid regex");
Tokenizer::new(bpe, Some(pre))
});

static BPE_P50K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_p50k_base.dict"));
let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
let pat = "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+";
Tokenizer::new(bpe, Some(pat)).expect("valid regex")
let pat = [
"(?:'s|'t|'re|'ve|'m|'ll|'d)",
" ?\\p{L}+",
" ?\\p{N}+",
" ?[^\\s\\p{L}\\p{N}]+",
"\\s+", // "(:?\\s+(?!\\S)|\\s+)",
]
.join("|");
let pre =
Pretokenizer::from_pat_and_trim(&pat, openai_trim_one_whitespace).expect("valid regex");
Tokenizer::new(bpe, Some(pre))
});

static BPE_CL100K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k_base.dict"));
let bpe = rmp_serde::from_slice(bytes).expect("valid bpe data");
let pat = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
Tokenizer::new(bpe, Some(pat)).expect("valid regex")
let pat = [
"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
"[^\\r\\n\\p{L}\\p{N}]?\\p{L}+",
"\\p{N}{1,3}",
" ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*",
"\\s+", // "(?:\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+)",
]
.join("|");
let pre = Pretokenizer::from_pat_and_trim(&pat, openai_trim_one_nonnewline_whitespace)
.expect("valid regex");
Tokenizer::new(bpe, Some(pre))
});

static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
@@ -33,15 +63,13 @@ static BPE_O200K_BASE: LazyLock<Tokenizer> = LazyLock::new(|| {
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
"\\p{N}{1,3}",
" ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
"\\s*[\\r\\n]+",
"\\s+(?!\\S)",
"\\s+",
"\\s+", // "(?:\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+)",
].join("|");
Tokenizer::new(bpe, Some(&pat)).expect("valid regex")
let pre = Pretokenizer::from_pat_and_trim(&pat, openai_trim_one_nonnewline_whitespace)
.expect("valid regex");
Tokenizer::new(bpe, Some(pre))
});

pub use bpe::*;

/// A byte-pair encoding tokenizer that supports a pre-tokenization regex.
/// The direct methods on this type pre-tokenize the input text and should
/// produce the same output as the tiktoken tokenizers. The type gives access
@@ -51,15 +79,23 @@ pub use bpe::*;
pub struct Tokenizer {
/// The byte-pair encoding for this tokenizer.
pub bpe: BytePairEncoding,
/// The pattern regex used to split the input.
pub pat: Option<Regex>,
/// The pretokenizer used to split the input.
pub pre: Option<Pretokenizer>,
}

/// A trim function that, for the given haystack and match range, returns the number of bytes that should
/// be discarded from the end of the match.
pub type Trim = fn(&str, Range<usize>) -> usize;

pub struct Pretokenizer {
pat: Regex,
trim: Option<Trim>,
}

impl Tokenizer {
#[allow(clippy::result_large_err)]
pub fn new(bpe: BytePairEncoding, pat: Option<&str>) -> fancy_regex::Result<Self> {
let pat = pat.map(fancy_regex::Regex::new).transpose()?;
Ok(Self { bpe, pat })
pub fn new(bpe: BytePairEncoding, pre: Option<Pretokenizer>) -> Self {
Self { bpe, pre }
}

pub fn count(&self, text: &str) -> usize {
@@ -79,18 +115,49 @@ impl Tokenizer {
}

pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
match &self.pat {
Some(pat) => Either::Left(pat.find_iter(text).scan(0, |start, m| {
let m = m.expect("match succeeded");
assert_eq!(*start, m.start(), "pattern should match all input text");
*start = m.end();
Some(m.as_str())
})),
match &self.pre {
Some(pre) => Either::Left(pre.split(text)),
None => Either::Right(std::iter::once(text)),
}
}
}

impl Pretokenizer {
#[allow(clippy::result_large_err)]
pub fn from_pat(pat: &str) -> fancy_regex::Result<Self> {
Ok(Self {
pat: Regex::new(pat)?,
trim: None,
})
}

#[allow(clippy::result_large_err)]
pub fn from_pat_and_trim(pat: &str, trim: Trim) -> fancy_regex::Result<Self> {
Ok(Self {
pat: Regex::new(pat)?,
trim: Some(trim),
})
}

pub fn split<'a>(&'a self, text: &'a str) -> impl Iterator<Item = &str> + 'a {
let mut start = 0;
std::iter::from_fn(move || {
self.pat
.find_from_pos(text, start)
.expect("can search from position")
.map(|m| {
let mut range = m.range();
if let Some(trim) = self.trim {
range.end -= trim(text, range.clone());
}
assert!(range.end > start);
start = range.end;
&text[range]
})
})
}
}

pub fn r50k_base() -> &'static Tokenizer {
&BPE_R50K_BASE
}
@@ -107,6 +174,36 @@ pub fn o200k_base() -> &'static Tokenizer {
&BPE_O200K_BASE
}

/// Allows using `\\s+` instead of `(?:\\s+(?!\\S)|\\s+)`.
/// Assumes no other patterns match whitespace at the end.
fn openai_trim_one_whitespace(text: &str, range: Range<usize>) -> usize {
if range.end == text.len() {
return 0;
}
let mut chars = text[range].chars();
match chars.next_back() {
Some(c) if c.is_whitespace() && chars.next_back().is_some() => c.len_utf8(),
_ => 0,
}
}

/// Allows using `\\s+` instead of `(?:\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+)`.
/// Assumes no other patterns match non-[\r\n] whitespace at the end.
fn openai_trim_one_nonnewline_whitespace(text: &str, range: Range<usize>) -> usize {
if range.end == text.len() {
return 0;
}
let mut chars = text[range].chars();
match chars.next_back() {
Some(c)
if c.is_whitespace() && !matches!(c, '\r' | '\n') && chars.next_back().is_some() =>
{
c.len_utf8()
}
_ => 0,
}
}

#[cfg(test)]
mod tests {
use tiktoken_rs::cl100k_base_singleton;
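
For orientation, a small usage sketch of the public API after this change, relying only on the `cl100k_base()`, `split`, and `count` functions shown in this diff (the input string and the printed output are illustrative):

use bpe_openai::cl100k_base;

fn main() {
    let tok = cl100k_base();
    let text = "They're saying   it works!";
    // The pretokenizer splits the text into pieces; BPE then encodes each piece.
    let pieces: Vec<_> = tok.split(text).collect();
    println!("pieces: {pieces:?}");
    println!("token count: {}", tok.count(text));
}
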
44 changes: 38 additions & 6 deletions crates/bpe/benchmarks/equivalence.rs
@@ -1,7 +1,8 @@
use bpe_benchmarks::*;
use bpe_openai::{cl100k_base, Pretokenizer};

#[cfg(test)]
const N: usize = 32;
const N: usize = 128;

#[test]
fn test_encoding_equivalence_without_pretokenization() {
@@ -42,7 +43,7 @@ fn test_encoding_equivalence_without_pretokenization() {

#[test]
fn test_encoding_equivalence_with_pretokenization() {
for (_, bpe, tiktoken, huggingface) in TOKENIZERS.iter() {
for (name, bpe, tiktoken, huggingface) in [&TOKENIZERS[0]] {
let text = create_test_string(&bpe.bpe, 20000);
let inputs = (0..N)
.map(|_| select_test_bytes(text.as_bytes(), 100))
@@ -64,12 +65,12 @@ fn test_encoding_equivalence_with_pretokenization() {
let huggingface_text = huggingface.decode(&huggingface_out, true).unwrap();
if tiktoken_text != huggingface_text {
panic!(
"huggingface tokens and text differ: {:?} != {:?}",
"{name}: huggingface tokens and text differ: {:?} != {:?}",
huggingface_text, tiktoken_text
);
} else {
panic!(
"huggingface tokens differ: {:?} != {:?}",
"{name}: huggingface tokens differ: {:?} != {:?}",
huggingface_out, tiktoken_out2
);
}
@@ -78,13 +79,44 @@ fn test_encoding_equivalence_with_pretokenization() {
let text = bpe.decode(&out).unwrap();
if tiktoken_text != text {
panic!(
"bpe tokens and text differ: {:?} != {:?}",
"{name}: bpe tokens and text differ: {:?} != {:?}",
text, tiktoken_text
);
} else {
panic!("bpe tokens differ: {:?} != {:?}", out, tiktoken_out2);
panic!(
"{name}: bpe tokens differ: {:?} != {:?}",
out, tiktoken_out2
);
}
}
}
}
}

#[test]
fn test_pretokenization() {
let fast = cl100k_base().pre.as_ref().unwrap();

let slow_pat = [
"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
"[^\\r\\n\\p{L}\\p{N}]?\\p{L}+",
"\\p{N}{1,3}",
" ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*",
"(?:\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+)",
]
.join("|");
let slow = Pretokenizer::from_pat(&slow_pat).unwrap();

let text = create_test_string(&cl100k_base().bpe, 20000);
let inputs = (0..N)
.map(|_| select_test_bytes(text.as_bytes(), 100))
.chain(std::iter::once(
"You should see the Greek word 'kosme': \"κόσμε\"".as_bytes(),
));
for input in inputs {
let text = std::str::from_utf8(input).unwrap();
let slow_out: Vec<_> = slow.split(text).collect();
let fast_out: Vec<_> = fast.split(text).collect();
assert_eq!(slow_out, fast_out);
}
}
40 changes: 39 additions & 1 deletion crates/bpe/benchmarks/performance.rs
@@ -3,6 +3,7 @@ use std::time::Duration;
use bpe::appendable_encoder::AppendableEncoder;
use bpe::interval_encoding::IntervalEncoding;
use bpe_benchmarks::*;
use bpe_openai::{cl100k_base, Pretokenizer};
use bpe_tests::create_test_bytes;
use criterion::{
criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration,
@@ -228,12 +229,49 @@ fn worstcase_comparison_benchmark(c: &mut Criterion) {
}
}

fn pretok_benchmark(c: &mut Criterion) {
let fast = cl100k_base().pre.as_ref().unwrap();

let slow_pat = [
"(?i:'s|'t|'re|'ve|'m|'ll|'d)",
"[^\\r\\n\\p{L}\\p{N}]?\\p{L}+",
"\\p{N}{1,3}",
" ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*",
"(?:\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+)",
]
.join("|");
let slow = Pretokenizer::from_pat(&slow_pat).unwrap();

let text = create_test_string(&cl100k_base().bpe, 20000);
let input = text.as_bytes();

let mut group = c.benchmark_group("pretok-cl100k");
for bytes in [10, 100, 1000, 5000, 10000, 25000, 50000, 75000, 100000] {
group.throughput(criterion::Throughput::Bytes(bytes as u64));
group.bench_with_input(BenchmarkId::new("fast", bytes), &bytes, |b, bytes| {
b.iter_batched(
|| std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(),
|text| fast.split(text).count(),
criterion::BatchSize::SmallInput,
)
});
group.bench_with_input(BenchmarkId::new("slow", bytes), &bytes, |b, bytes| {
b.iter_batched(
|| std::str::from_utf8(select_test_bytes(input, *bytes)).unwrap(),
|text| slow.split(text).count(),
criterion::BatchSize::SmallInput,
)
});
}
group.finish();
}

criterion_group!(
name = benches;
config = Criterion::default()
.warm_up_time(Duration::from_millis(500))
.measurement_time(Duration::from_millis(4000))
.nresamples(1000);
targets = counting_benchmark, encoding_benchmark, appending_benchmark, comparison_benchmark, worstcase_comparison_benchmark
targets = counting_benchmark, encoding_benchmark, appending_benchmark, comparison_benchmark, worstcase_comparison_benchmark, pretok_benchmark
);
criterion_main!(benches);
2 changes: 1 addition & 1 deletion crates/bpe/src/byte_pair_encoding.rs
@@ -15,7 +15,7 @@ use crate::bitfield::BitField;
/// Representation of the byte pair dictionary.
/// This struct provides various conversions.
/// We put all of them into a single struct so that they can be reused by different implementations.
#[derive(Serialize, Deserialize)]
#[derive(Clone, Serialize, Deserialize)]
pub struct BytePairEncoding {
/// All the decoded tokens concatenated into
all_tokens: Vec<u8>,
