Move tests into separate package
hendrikvanantwerpen committed Oct 14, 2024
1 parent 94500bf commit c29a180
Showing 14 changed files with 208 additions and 269 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -3,6 +3,7 @@
members = [
"crates/*",
"crates/bpe/benchmarks",
"crates/bpe/tests",
]
resolver = "2"

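Note: cargo expands the crates/* glob only one level deep, so nested packages such as crates/bpe/benchmarks and the new crates/bpe/tests must be listed as explicit workspace members.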
2 changes: 1 addition & 1 deletion crates/bpe-openai/Cargo.toml
@@ -20,7 +20,7 @@ rmp-serde = "1"
serde = { version = "1" }

[dev-dependencies]
- tiktoken-rs = { version = "0.6" }
+ tiktoken-rs = "0.6"

[build-dependencies]
base64 = "0.22.1"
1 change: 0 additions & 1 deletion crates/bpe-openai/src/lib.rs
@@ -142,7 +142,6 @@ mod tests {
.lock()
.encode_ordinary(text)
.into_iter()
- .map(|i| i as u32)
.collect();

let without_splitting = BPE_CL100K_BASE.bpe.encode_via_backtracking(input);
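The dropped cast goes together with the tiktoken-rs 0.5 to 0.6 bump elsewhere in this commit: under 0.5, encode_ordinary returned usize ids (hence the cast), while 0.6 yields u32 ranks directly. A minimal sketch of the calling code under that assumption:

// Sketch, assuming the tiktoken-rs 0.6 API where ranks are u32.
use tiktoken_rs::cl100k_base;

fn encode_ids(text: &str) -> Vec<u32> {
    let bpe = cl100k_base().expect("loading the cl100k dictionary failed");
    // Under 0.5 this needed `.into_iter().map(|i| i as u32).collect()`.
    bpe.encode_ordinary(text)
}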
5 changes: 2 additions & 3 deletions crates/bpe/Cargo.toml
@@ -14,7 +14,6 @@ bench = false

[features]
rand = ["dep:rand"]
- tiktoken-rs = ["dep:tiktoken-rs"]

[dependencies]
aneubeck-daachorse = "1.1.1"
@@ -23,7 +22,7 @@ itertools = "0.12"
rand = { version = "0.8", optional = true }
rmp-serde = "1"
serde = { version = "1", features = ["derive"] }
- tiktoken-rs = { version = "0.5", optional = true }

[dev-dependencies]
bpe = { path = ".", features = ["rand", "tiktoken-rs"] }
bpe = { path = "." }
tiktoken-rs = "0.6"
5 changes: 3 additions & 2 deletions crates/bpe/benchmarks/Cargo.toml
@@ -18,9 +18,10 @@ path = "equivalence.rs"
test = true

[dependencies]
bpe = { path = "../../bpe", features = ["rand", "tiktoken-rs"] }
bpe = { path = "../../bpe" }
bpe-openai = { path = "../../bpe-openai" }
+ bpe-tests = { path = "../tests" }
criterion = "0.5"
rand = "0.8"
tiktoken-rs = "0.5"
tiktoken-rs = "0.6"
tokenizers = { version = "0.20", features = ["http"] }
8 changes: 4 additions & 4 deletions crates/bpe/benchmarks/equivalence.rs
@@ -16,7 +16,7 @@ fn test_encoding_equivalence_without_pretokenization() {
for input in inputs {
let text = std::str::from_utf8(input).unwrap();
let out = bpe.bpe.encode_via_backtracking(input);
- let huggingface_out: Vec<_> = huggingface
+ let huggingface_out = huggingface
.encode_fast(text, false)
.unwrap()
.get_ids()
@@ -52,10 +52,10 @@ fn test_encoding_equivalence_with_pretokenization() {
for input in inputs {
let text = std::str::from_utf8(input).unwrap();
let out = bpe.encode(text);
- let tiktoken_out: Vec<_> = tiktoken.encode_ordinary(text);
- let tiktoken_out2: Vec<_> = tiktoken_out.iter().map(|i| *i as u32).collect();
+ let tiktoken_out = tiktoken.encode_ordinary(text);
+ let tiktoken_out2 = tiktoken_out.to_vec();
let tiktoken_text = tiktoken.decode(tiktoken_out.clone()).unwrap();
- let huggingface_out: Vec<_> = huggingface
+ let huggingface_out = huggingface
.encode_fast(text, false)
.unwrap()
.get_ids()
2 changes: 1 addition & 1 deletion crates/bpe/benchmarks/performance.rs
@@ -1,9 +1,9 @@
use std::time::Duration;

use bpe::appendable_encoder::AppendableEncoder;
- use bpe::byte_pair_encoding::create_test_bytes;
use bpe::interval_encoding::IntervalEncoding;
use bpe_benchmarks::*;
+ use bpe_tests::create_test_bytes;
use criterion::{
criterion_group, criterion_main, AxisScale, BenchmarkId, Criterion, PlotConfiguration,
};
18 changes: 0 additions & 18 deletions crates/bpe/src/appendable_encoder.rs
@@ -87,21 +87,3 @@ impl<'a> AppendableEncoder<'a> {
self.states.is_empty()
}
}

- #[cfg(test)]
- mod tests {
- use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};
-
- use super::AppendableEncoder;
-
- #[test]
- fn test_appendable_encoder() {
- let bpe = &BPE_CL100K;
- let mut enc = AppendableEncoder::new(bpe);
- let input_string = create_test_bytes(bpe, 100);
- for (i, c) in input_string.iter().enumerate() {
- assert_eq!(enc.token_count(), bpe.count(&input_string[0..i]));
- enc.push(*c);
- }
- }
- }
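The relocated copy of this test presumably lives in the new bpe-tests package, whose diff was not loaded below. A rough sketch of what the move implies; a hypothetical reconstruction, not the actual moved code:

use bpe::appendable_encoder::AppendableEncoder;
// Assumptions: bpe-openai exposes the cl100k tokenizer as BPE_CL100K_BASE (the
// name its own tests use above), and create_test_bytes moved into this crate
// from bpe::byte_pair_encoding.
use bpe_openai::BPE_CL100K_BASE;

#[test]
fn test_appendable_encoder() {
    let bpe = &BPE_CL100K_BASE.bpe;
    let mut enc = AppendableEncoder::new(bpe);
    let input_string = create_test_bytes(bpe, 100);
    for (i, c) in input_string.iter().enumerate() {
        // Invariant: the incremental token count matches a from-scratch
        // encoding of the prefix consumed so far.
        assert_eq!(enc.token_count(), bpe.count(&input_string[0..i]));
        enc.push(*c);
    }
}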
135 changes: 0 additions & 135 deletions crates/bpe/src/byte_pair_encoding.rs
@@ -12,26 +12,6 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::backtrack_encoder::BacktrackEncoder;
use crate::bitfield::BitField;

- #[cfg(test)]
- pub(crate) static BPE_CL100K: std::sync::LazyLock<BytePairEncoding> =
- std::sync::LazyLock::new(|| {
- BytePairEncoding::from_tiktoken(
- &tiktoken_rs::cl100k_base_singleton().lock(),
- 100256,
- Some(17846336922010275747),
- )
- });
-
- #[cfg(test)]
- pub(crate) static BPE_O200K: std::sync::LazyLock<BytePairEncoding> =
- std::sync::LazyLock::new(|| {
- BytePairEncoding::from_tiktoken(
- &tiktoken_rs::o200k_base_singleton().lock(),
- 199998,
- Some(17846336922010275747),
- )
- });
-
/// Representation of the byte pair dictionary.
/// This struct provides various conversions.
/// We put all of them into a single struct so that they can be reused by different implementations.
@@ -175,13 +155,6 @@ fn hash_bytes(bytes: &[u8], factor: u64) -> u32 {
((hasher.finish().wrapping_mul(factor)) >> 32) as u32
}

- /// Find a suitable hash factor for the given tiktoken dictionary that prevents collisions
- /// when constructing a [`BytePairEncoding`] from those tokens.
- #[cfg(all(feature = "tiktoken-rs", feature = "rand"))]
- pub fn find_hash_factor_for_tiktoken(bpe: &tiktoken_rs::CoreBPE, len: usize) -> u64 {
- find_hash_factor_for_dictionary((0..len).map(|i| bpe._decode_native(&[i])))
- }
-
/// Find a suitable hash factor for a set of given tokens that prevents collisions when
/// constructing a [`BytePairEncoding`] from those tokens.
#[cfg(feature = "rand")]
@@ -221,24 +194,6 @@ fn find_token_by_bytes(
}

impl BytePairEncoding {
- /// Construct a BytePairEncoding instance from a tiktoken dictionary.
- /// A suitable hash factor may be necessary to prevent hash collisions,
- /// which can by found using [`find_hash_factor_for_tiktoken`].
- ///
- /// The recommended approach is to store the serialized value and reuse that,
- /// to prevent repeating the cost of computing the hash factor and encoding.
- #[cfg(feature = "tiktoken-rs")]
- pub fn from_tiktoken(
- tiktoken_bpe: &tiktoken_rs::CoreBPE,
- num_tokens: usize,
- hash_factor: Option<u64>,
- ) -> Self {
- Self::from_dictionary(
- (0..num_tokens).map(|i| tiktoken_bpe._decode_native(&[i])),
- hash_factor,
- )
- }
-
/// Construct a BytePairEncoding instance from an iterator that enumerates all tokens.
/// A suitable hash factor may be necessary to prevent hash collisions, which can be
/// found using [`find_hash_factor_for_dictionary`].
@@ -549,93 +504,3 @@ impl BytePairEncoding {
encoded
}
}

#[cfg(feature = "rand")]
pub fn create_test_bytes(bpe: &BytePairEncoding, tokens: usize) -> Vec<u8> {
use rand::{thread_rng, Rng};
let mut text = vec![];
for _ in 0..tokens {
let i = thread_rng().gen_range(0..bpe.num_tokens());
let s = bpe.token_bytes(i as u32);
text.extend_from_slice(s);
}
text
}

#[cfg(test)]
mod tests {

use std::time::Instant;

use itertools::Itertools;
use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton};

use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K, BPE_O200K};

#[test]
fn test_correctness_cl100k() {
// This is quite a challenging test case...
let test_string = std::str::from_utf8(&[
125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
102, 102, 101, 110, 100,
])
.unwrap();
let time = Instant::now();
let bpe = &BPE_CL100K;
println!("{:?}", time.elapsed());
let encoded1 = cl100k_base_singleton()
.lock()
.encode_ordinary(test_string)
.into_iter()
.map(|t| t as u32)
.collect_vec();
let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
assert_eq!(encoded1, encoded2);
let encoded3 = bpe.encode_via_table(test_string.as_bytes());
assert_eq!(encoded1, encoded3);
let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
assert_eq!(encoded1, encoded4);
}

#[test]
fn test_correctness_o200k() {
// This is quite a challenging test case...
let test_string = std::str::from_utf8(&[
125, 34, 10, 10, 46, 109, 107, 100, 105, 114, 115, 32, 102, 100, 115, 32, 97, 100, 105,
112, 105, 115, 105, 99, 105, 110, 103, 105, 116, 121, 69, 110, 103, 105, 110, 101, 32,
69, 67, 105, 114, 105, 101, 32, 111, 112, 116, 105, 109, 97, 108, 95, 68, 65, 32, 111,
102, 102, 101, 110, 100,
])
.unwrap();
let time = Instant::now();
let bpe = &BPE_O200K;
println!("{:?}", time.elapsed());
let encoded1 = o200k_base_singleton()
.lock()
.encode_ordinary(test_string)
.into_iter()
.map(|t| t as u32)
.collect_vec();
let encoded2 = bpe.encode_via_backtracking(test_string.as_bytes());
assert_eq!(encoded1, encoded2);
let encoded3 = bpe.encode_via_table(test_string.as_bytes());
assert_eq!(encoded1, encoded3);
let encoded4 = bpe.encode_via_bitfield(test_string.as_bytes());
assert_eq!(encoded1, encoded4);
}

#[test]
fn test_bpe_equivalence() {
let bpe = &BPE_CL100K;
for tokens in [10, 1000, 10000] {
for _ in 0..5 {
let test_input = create_test_bytes(bpe, tokens);
let encoded1 = bpe.encode_via_backtracking(&test_input);
let encoded2 = bpe.encode_via_bitfield(&test_input);
assert_eq!(encoded1, encoded2, "{} {}", encoded1.len(), encoded2.len());
}
}
}
}
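With from_tiktoken, find_hash_factor_for_tiktoken, and the test-only statics gone, from_dictionary is the library's single construction path and the tiktoken glue moves to callers. A sketch of rebuilding the removed cl100k static outside the crate, assuming tiktoken-rs 0.6 still exposes _decode_native (as 0.5 did) with u32 ranks:

use bpe::byte_pair_encoding::BytePairEncoding;

fn bpe_cl100k() -> BytePairEncoding {
    let tiktoken = tiktoken_rs::cl100k_base_singleton();
    let tiktoken = tiktoken.lock();
    // Token count and hash factor are taken from the removed BPE_CL100K static.
    BytePairEncoding::from_dictionary(
        (0..100256u32).map(|i| tiktoken._decode_native(&[i])),
        Some(17846336922010275747),
    )
}

As the removed doc comment recommends, the result should be computed once, serialized, and reused rather than paying the hash-factor and construction cost at every startup.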
25 changes: 0 additions & 25 deletions crates/bpe/src/interval_encoding.rs
@@ -81,28 +81,3 @@ impl<'a> IntervalEncoding<'a> {
encoder.count()
}
}

- #[cfg(test)]
- mod tests {
- use rand::{thread_rng, Rng};
-
- use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};
-
- use super::IntervalEncoding;
-
- #[test]
- fn test_interval_count() {
- let bpe = &BPE_CL100K;
- let text = create_test_bytes(bpe, 10000);
- let intervals = IntervalEncoding::new(bpe, &text);
- for _ in 0..1000 {
- let start = thread_rng().gen_range(0..text.len());
- let end = thread_rng().gen_range(0..text.len());
- let range = start.min(end)..start.max(end);
- assert_eq!(
- intervals.count(range.clone()),
- bpe.encode_via_backtracking(&text[range]).len()
- );
- }
- }
- }
61 changes: 0 additions & 61 deletions crates/bpe/src/lib.rs
@@ -4,64 +4,3 @@ mod bitfield;
pub mod byte_pair_encoding;
pub mod interval_encoding;
pub mod prependable_encoder;

- #[cfg(test)]
- mod tests {
- use itertools::Itertools;
-
- use crate::byte_pair_encoding::BytePairEncoding;
-
- /// This test produces the output for the encoding example in the README.
- #[test]
- fn readme_example() {
- let tokens = ["a", "b", "c", "ab", "cb", "ac"].map(|t| t.as_bytes().to_vec());
- let bpe = BytePairEncoding::from_dictionary(tokens, None);
- let text = "abacb";
- let prefixes = (1..=text.len()).map(|end| &text[..end]).collect_vec();
- let all_prefix_tokens = prefixes
- .iter()
- .map(|prefix| {
- bpe.encode_via_backtracking(prefix.as_bytes())
- .into_iter()
- .map(|t| unsafe { String::from_utf8_unchecked(bpe.decode_tokens(&[t])) })
- .collect_vec()
- })
- .collect_vec();
- let last_prefix_tokens = all_prefix_tokens
- .iter()
- .map(|tokens| tokens.last().unwrap())
- .collect_vec();
-
- println!("All tokens for each prefix of `{text}`:\n");
- for (prefix, tokens) in prefixes.iter().zip(&all_prefix_tokens) {
- println!(
- "- `{prefix}` {}> `{}`",
- "-".repeat(text.len() + 2 - prefix.len()),
- tokens.join(" ")
- );
- }
- println!();
-
- println!("Last token for each prefix of `{text}`:\n");
- for (prefix, token) in prefixes.iter().zip(&last_prefix_tokens) {
- println!(
- "- `{prefix}` {}> `{token}`",
- "-".repeat(text.len() + 2 - prefix.len()),
- );
- }
- println!();
-
- println!("Tokenization of `{text}`:\n");
- let mut remaining = text.len();
- while remaining > 0 {
- let prefix = &text[..remaining];
- let token = last_prefix_tokens[remaining - 1];
- println!(
- "- `{prefix}` {}> `{token}`",
- "-".repeat(text.len() + 2 - prefix.len()),
- );
- remaining -= token.len();
- }
- println!("- `<empty>`");
- }
- }
18 changes: 0 additions & 18 deletions crates/bpe/src/prependable_encoder.rs
@@ -87,21 +87,3 @@ impl<'a> PrependableEncoder<'a> {
self.states.is_empty()
}
}

- #[cfg(test)]
- mod tests {
- use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};
-
- use super::PrependableEncoder;
-
- #[test]
- fn test_prependable_encoder() {
- let bpe = &BPE_CL100K;
- let mut enc = PrependableEncoder::new(bpe);
- let input_string = create_test_bytes(bpe, 100);
- for (i, c) in input_string.iter().enumerate().rev() {
- enc.push(*c);
- assert_eq!(enc.token_count(), bpe.count(&input_string[i..]));
- }
- }
- }
10 changes: 10 additions & 0 deletions crates/bpe/tests/Cargo.toml
@@ -0,0 +1,10 @@
+ [package]
+ name = "bpe-tests"
+ edition = "2021"
+
+ [dependencies]
+ bpe = { path = "../../bpe" }
+ bpe-openai = { path = "../../bpe-openai" }
+ itertools = "0.13"
+ rand = "0.8"
+ tiktoken-rs = "0.6"
(Diff for the remaining changed file, presumably the relocated test sources, not shown.)
