From a9d7a5b5b6c6fdaa40dddd763b381799a3b1f338 Mon Sep 17 00:00:00 2001
From: Gautier Dagan
Date: Fri, 23 Aug 2024 11:33:21 +0100
Subject: [PATCH] perf: filter out length 1 byte sentences during
 pre-tokenization step

---
 Cargo.toml           |  2 +-
 benchmarks/README.md |  4 ++--
 src/lib.rs           | 22 +++++++++++++++++-----
 3 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index c26211a..81f4a03 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "bpeasy"
-version = "0.1.2"
+version = "0.1.3"
 edition = "2021"
 
 [lib]
diff --git a/benchmarks/README.md b/benchmarks/README.md
index 8984b78..8eabaae 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -4,8 +4,8 @@ Using varying vocab sizes from (5k:100k)
 
 | Library/Operation | Time (seconds) | Standard Deviation |
 |----------------------------|---------------------------------|--------------------------------|
-| HuggingFace Train | 0.7369 | ±1.55 |
-| `bpeasy` Train | 0.6528 | ±0.386 |
+| HuggingFace Train | 0.8165 | ±0.62 |
+| `bpeasy` Train | 0.68815 | ±0.41 |
 | HuggingFace Encode | 0.6247 | ±0.051 |
 | `bpeasy` Encode (uses `tiktoken`) | 0.2679 | ±0.035 |
 
diff --git a/src/lib.rs b/src/lib.rs
index a56c78c..1d3efa6 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -146,6 +146,7 @@ fn pretokenize<'a>(text: &'a str, regex: &Regex) -> Vec<&'a str> {
 fn pretokenize_strings(strings: Vec<&str>, pattern: &str) -> (Vec<Sentence>, Vec<u64>) {
     let regex: Regex = Regex::new(pattern).expect("Invalid regex pattern");
 
+    // Tokenize strings in parallel
     let (tokens, counts): (Vec<&str>, Vec<u64>) = strings
         .par_iter()
         .flat_map(|&text| pretokenize(text, &regex))
@@ -168,8 +169,15 @@ fn pretokenize_strings(strings: Vec<&str>, pattern: &str) -> (Vec<Sentence>, Vec<u64>) {
         .into_iter()
         .unzip();
 
-    let sentences: Vec<Sentence> = tokens.into_iter().map(Sentence::from_str).collect();
-    (sentences, counts)
+    // Convert tokens to sentences and filter sentences and counts to remove single byte sentences
+    let (filtered_sentences, filtered_counts): (Vec<Sentence>, Vec<u64>) = tokens
+        .into_iter()
+        .map(Sentence::from_str)
+        .zip(counts.into_iter())
+        .filter(|(sentence, _)| sentence.symbols.len() > 1)
+        .unzip();
+
+    (filtered_sentences, filtered_counts)
 }
 
 fn initialize_vocab_bytes(vocab_size: usize) -> (HashMap<Vec<u8>, u32>, Vec<Vec<u8>>) {
@@ -412,11 +420,15 @@ fn bpeasy(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
 mod tests {
     #[test]
     fn test_all() {
-        let text: &str = "\tYou hear £ £ £ here";
+        let text: &str = "\tYou hear a £ £ £ here";
         let pattern = r"([^\s]+)|(\s+)";
-        let compiled_regex = fancy_regex::Regex::new(pattern).expect("Invalid regex pattern");
+        let compiled_regex: fancy_regex::Regex =
+            fancy_regex::Regex::new(pattern).expect("Invalid regex pattern");
         let pretokenized_sentences = crate::pretokenize(text, &compiled_regex);
-        println!("{:?}", pretokenized_sentences);
+        assert_eq!(
+            pretokenized_sentences,
+            vec!["\t", "You", " ", "hear", " ", "a", " ", "£", " ", "£", " ", "£", " ", "here"]
+        );
 
         let text_2: &str = "You hear £ £ £ here";
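
For reference, a minimal standalone sketch of the filtering step introduced above, assuming a simplified `Sentence` that holds one symbol per input byte (the real `Sentence` in src/lib.rs carries richer per-symbol state). The idea is that a sentence with a single symbol can never contribute a byte pair, so it is dead weight during merge counting and can be dropped before training:

// Minimal sketch, not the patched code: this `Sentence` is a stand-in
// that stores one u8 symbol per input byte.
struct Sentence {
    symbols: Vec<u8>,
}

impl Sentence {
    fn from_str(s: &str) -> Self {
        Sentence {
            symbols: s.bytes().collect(),
        }
    }
}

fn main() {
    // Pretokenized tokens and their occurrence counts, index-aligned,
    // as produced by the deduplication step earlier in pretokenize_strings.
    let tokens = vec!["\t", "You", " ", "hear", " ", "a"];
    let counts: Vec<u64> = vec![1, 1, 4, 1, 4, 1];

    // Drop single-symbol sentences while keeping the two vectors aligned:
    // zip pairs each sentence with its count, filter drops both halves of
    // a pair together, and unzip splits the survivors back apart.
    let (sentences, counts): (Vec<Sentence>, Vec<u64>) = tokens
        .into_iter()
        .map(Sentence::from_str)
        .zip(counts)
        .filter(|(sentence, _)| sentence.symbols.len() > 1)
        .unzip();

    assert_eq!(sentences.len(), 2); // only "You" and "hear" survive
    assert_eq!(counts, vec![1, 1]);
}

Because the filter removes sentence and count as a unit, the pruning changes only how much data each merge iteration scans, not which merges can be learned.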