Skip to content

Commit

Permalink
Generate serialized data in build script
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikvanantwerpen committed Oct 4, 2024
1 parent 1c2506d commit bae9f01
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 110 deletions.
24 changes: 24 additions & 0 deletions crates/bpe-openai/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[package]
name = "bpe-openai"
version = "0.0.1"
edition = "2021"
description = "Prebuilt fast byte-pair encoders for OpenAI."
repository = "https://github.com/github/rust-gems"
license = "MIT"
keywords = ["tokenizer", "algorithm", "encoding", "bpe"]
categories = ["algorithms", "data-structures", "encoding", "science"]

[lib]
crate-type = ["lib", "staticlib"]
bench = false

[dependencies]
bpe = { version = "0.0.1", path = "../bpe" }
rmp-serde = "1"
serde = { version = "1" }

[build-dependencies]
bpe = { version = "0.0.1", path = "../bpe", features = ["tiktoken-rs"] }
rmp-serde = "1"
tiktoken-rs = { version = "0.5" }
serde = { version = "1" }
32 changes: 32 additions & 0 deletions crates/bpe-openai/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use std::env;
use std::fs::File;
use std::path::PathBuf;

use bpe::byte_pair_encoding::BytePairEncoding;
use serde::Serialize;
use tiktoken_rs::CoreBPE;

/// Hash factor shared by both dictionaries. A suitable factor prevents hash
/// collisions in the token hash map; per the `bpe` crate this can be found
/// with `find_hash_factor_for_tiktoken` — this value works for both token sets.
const HASH_FACTOR: u64 = 17846336922010275747;

fn main() {
    // Serialize the cl100k_base dictionary (100256 tokens).
    serialize_tokens(
        "cl100k",
        &tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"),
        100256,
        HASH_FACTOR,
    );
    // Serialize the o200k_base dictionary (199998 tokens).
    serialize_tokens(
        "o200k",
        &tiktoken_rs::o200k_base().expect("tiktoken initialization must not fail!"),
        199998,
        HASH_FACTOR,
    );
    // Only rerun the build script when it changes itself; its output depends
    // solely on this file and the crate dependencies.
    println!("cargo::rerun-if-changed=build.rs");
}

/// Builds a `BytePairEncoding` from the given tiktoken dictionary and writes it,
/// MessagePack-serialized, to `$OUT_DIR/bpe_{name}.dict` so the library crate can
/// embed it via `include_bytes!`.
///
/// Panics with a descriptive message on any failure — acceptable in a build
/// script, where a panic fails the build.
fn serialize_tokens(name: &str, bpe: &CoreBPE, num_tokens: usize, hash_factor: u64) {
    let mut path = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is set by cargo"));
    path.push(format!("bpe_{name}.dict"));
    let file = File::create(&path)
        .unwrap_or_else(|e| panic!("failed to create {}: {e}", path.display()));
    let mut serializer = rmp_serde::Serializer::new(file);
    let bpe = BytePairEncoding::from_tiktoken(bpe, num_tokens, Some(hash_factor));
    bpe.serialize(&mut serializer)
        .unwrap_or_else(|e| panic!("failed to serialize BPE dictionary {name}: {e}"));
}
38 changes: 38 additions & 0 deletions crates/bpe-openai/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use std::sync::LazyLock;

use bpe::byte_pair_encoding::BytePairEncoding;

/// Byte-pair encoder for the cl100k dictionary, lazily deserialized from the
/// MessagePack data emitted by the build script into `OUT_DIR`.
static BPE_CL100K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_cl100k.dict"));
    // The embedded data is generated by our own build script, so failure to
    // deserialize indicates a broken build rather than a runtime condition.
    rmp_serde::from_slice(bytes).expect("valid serialized cl100k BPE data")
});

/// Byte-pair encoder for the o200k dictionary, lazily deserialized from the
/// MessagePack data emitted by the build script into `OUT_DIR`.
static BPE_O200K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
    let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/bpe_o200k.dict"));
    // The embedded data is generated by our own build script, so failure to
    // deserialize indicates a broken build rather than a runtime condition.
    rmp_serde::from_slice(bytes).expect("valid serialized o200k BPE data")
});

pub use bpe::*;

pub fn cl100k() -> &'static BytePairEncoding {
&BPE_CL100K
}

pub fn o200k() -> &'static BytePairEncoding {
&BPE_O200K
}

#[cfg(test)]
mod tests {
    use super::*;

    // Smoke tests: each forces lazy deserialization of an embedded
    // dictionary; a failure here means the build-script output is broken.

    #[test]
    fn can_load_cl100k() {
        let _ = cl100k().count(b"");
    }

    #[test]
    fn can_load_o200k() {
        let _ = o200k().count(b"");
    }
}
33 changes: 20 additions & 13 deletions crates/bpe/benches/performance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,28 @@ use criterion::{
use rand::{thread_rng, Rng};
use tiktoken_rs::CoreBPE;

static TOKENIZERS: LazyLock<[(&'static str, &'static BytePairEncoding, CoreBPE); 2]> =
LazyLock::new(|| {
[
(
"cl100k",
BytePairEncoding::cl100k(),
tiktoken_rs::cl100k_base().unwrap(),
static TOKENIZERS: LazyLock<[(&'static str, BytePairEncoding, CoreBPE); 2]> = LazyLock::new(|| {
[
(
"cl100k",
BytePairEncoding::from_tiktoken(
&tiktoken_rs::cl100k_base_singleton().lock(),
100256,
Some(17846336922010275747),
),
(
"o200k",
BytePairEncoding::o200k(),
tiktoken_rs::o200k_base().unwrap(),
tiktoken_rs::cl100k_base().unwrap(),
),
(
"o200k",
BytePairEncoding::from_tiktoken(
&tiktoken_rs::o200k_base_singleton().lock(),
199998,
Some(17846336922010275747),
),
]
});
tiktoken_rs::o200k_base().unwrap(),
),
]
});

fn counting_benchmark(c: &mut Criterion) {
for (name, bpe, _) in TOKENIZERS.iter() {
Expand Down
4 changes: 2 additions & 2 deletions crates/bpe/src/appendable_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ impl<'a> AppendableEncoder<'a> {

#[cfg(test)]
mod tests {
use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};

use super::AppendableEncoder;

#[test]
fn test_appendable_encoder() {
let bpe = BytePairEncoding::cl100k();
let bpe = &BPE_CL100K;
let mut enc = AppendableEncoder::new(bpe);
let input_string = create_test_bytes(bpe, 100);
for (i, c) in input_string.iter().enumerate() {
Expand Down
115 changes: 24 additions & 91 deletions crates/bpe/src/byte_pair_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::hash::{Hash, Hasher};
use std::ops::Range;
use std::sync::LazyLock;

use aneubeck_daachorse::{DoubleArrayAhoCorasick, DoubleArrayAhoCorasickBuilder};
use fnv::{FnvHashMap, FnvHasher};
Expand All @@ -12,19 +11,26 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::backtrack_encoder::BacktrackEncoder;
use crate::bitfield::BitField;
use crate::byte_pair_encoding::data::TokenDict;

static BPE_CL100K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
let bytes = include_bytes!("data/bpe_cl100k.dict");
let dict: TokenDict = rmp_serde::from_slice(bytes).expect("");
dict.into_bpe()
});
#[cfg(test)]
pub(crate) static BPE_CL100K: std::sync::LazyLock<BytePairEncoding> =
std::sync::LazyLock::new(|| {
BytePairEncoding::from_tiktoken(
&tiktoken_rs::cl100k_base_singleton().lock(),
100256,
Some(17846336922010275747),
)
});

static BPE_O200K: LazyLock<BytePairEncoding> = LazyLock::new(|| {
let bytes = include_bytes!("data/bpe_o200k.dict");
let dict: TokenDict = rmp_serde::from_slice(bytes).expect("");
dict.into_bpe()
});
#[cfg(test)]
pub(crate) static BPE_O200K: std::sync::LazyLock<BytePairEncoding> =
std::sync::LazyLock::new(|| {
BytePairEncoding::from_tiktoken(
&tiktoken_rs::o200k_base_singleton().lock(),
199998,
Some(17846336922010275747),
)
});

/// Representation of the byte pair dictionary.
/// This struct provides various conversions.
Expand Down Expand Up @@ -215,14 +221,6 @@ fn find_token_by_bytes(
}

impl BytePairEncoding {
pub fn cl100k() -> &'static Self {
&BPE_CL100K
}

pub fn o200k() -> &'static Self {
&BPE_O200K
}

/// Construct a BytePairEncoding instance from a tiktoken dictionary.
/// A suitable hash factor may be necessary to prevent hash collisions,
/// which can by found using [`find_hash_factor_for_tiktoken`].
Expand Down Expand Up @@ -572,7 +570,7 @@ mod tests {
use itertools::Itertools;
use tiktoken_rs::{cl100k_base_singleton, o200k_base_singleton};

use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K, BPE_O200K};

#[test]
fn test_correctness_cl100k() {
Expand All @@ -585,9 +583,9 @@ mod tests {
])
.unwrap();
let time = Instant::now();
let bpe = BytePairEncoding::o200k();
let bpe = &BPE_CL100K;
println!("{:?}", time.elapsed());
let encoded1 = o200k_base_singleton()
let encoded1 = cl100k_base_singleton()
.lock()
.encode_ordinary(test_string)
.into_iter()
Expand All @@ -612,9 +610,9 @@ mod tests {
])
.unwrap();
let time = Instant::now();
let bpe = BytePairEncoding::cl100k();
let bpe = &BPE_O200K;
println!("{:?}", time.elapsed());
let encoded1 = cl100k_base_singleton()
let encoded1 = o200k_base_singleton()
.lock()
.encode_ordinary(test_string)
.into_iter()
Expand All @@ -630,7 +628,7 @@ mod tests {

#[test]
fn test_bpe_equivalence() {
let bpe = BytePairEncoding::cl100k();
let bpe = &BPE_CL100K;
for tokens in [10, 1000, 10000] {
for _ in 0..5 {
let test_input = create_test_bytes(bpe, tokens);
Expand All @@ -641,68 +639,3 @@ mod tests {
}
}
}

mod data {
use serde::{Deserialize, Serialize};

use crate::byte_pair_encoding::BytePairEncoding;

#[derive(Serialize, Deserialize)]
pub(crate) struct TokenDict {
tokens: Vec<Vec<u8>>,
hash_factor: u64,
}

impl TokenDict {
pub(crate) fn into_bpe(self) -> BytePairEncoding {
BytePairEncoding::from_dictionary(self.tokens, Some(self.hash_factor))
}
}

#[test]
fn update_token_dicts() {
serialize_tokens(
"cl100k",
&tiktoken_rs::cl100k_base().expect("tiktoken initialization must not fail!"),
100256,
17846336922010275747,
);
serialize_tokens(
"o200k",
&tiktoken_rs::o200k_base().expect("tiktoken initialization must not fail!"),
199998,
17846336922010275747,
);
}

#[cfg(test)]
#[track_caller]
fn serialize_tokens(
name: &str,
bpe: &tiktoken_rs::CoreBPE,
num_tokens: usize,
hash_factor: u64,
) {
use std::fs::File;
use std::path::PathBuf;

use itertools::Itertools;
use serde::Serialize;

let path = PathBuf::from(file!());
let dir = path.parent().unwrap();
let data_file = dir.join(format!("data/bpe_{name}.dict"));
let current_dir = std::env::current_dir().unwrap();
let abs_path = current_dir.parent().unwrap().parent().unwrap();
let file = File::create(abs_path.join(data_file)).unwrap();
let mut serializer = rmp_serde::Serializer::new(file);
let tokens = (0..num_tokens)
.map(|i| bpe._decode_native(&[i]))
.collect_vec();
let dict = TokenDict {
tokens,
hash_factor,
};
dict.serialize(&mut serializer).unwrap();
}
}
Binary file removed crates/bpe/src/data/bpe_cl100k.dict
Binary file not shown.
Binary file removed crates/bpe/src/data/bpe_o200k.dict
Binary file not shown.
4 changes: 2 additions & 2 deletions crates/bpe/src/interval_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ impl<'a> IntervalEncoding<'a> {
mod tests {
use rand::{thread_rng, Rng};

use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};

use super::IntervalEncoding;

#[test]
fn test_interval_count() {
let bpe = BytePairEncoding::cl100k();
let bpe = &BPE_CL100K;
let text = create_test_bytes(bpe, 10000);
let intervals = IntervalEncoding::new(bpe, &text);
for _ in 0..1000 {
Expand Down
4 changes: 2 additions & 2 deletions crates/bpe/src/prependable_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ impl<'a> PrependableEncoder<'a> {

#[cfg(test)]
mod tests {
use crate::byte_pair_encoding::{create_test_bytes, BytePairEncoding};
use crate::byte_pair_encoding::{create_test_bytes, BPE_CL100K};

use super::PrependableEncoder;

#[test]
fn test_prependable_encoder() {
let bpe = BytePairEncoding::cl100k();
let bpe = &BPE_CL100K;
let mut enc = PrependableEncoder::new(bpe);
let input_string = create_test_bytes(bpe, 100);
for (i, c) in input_string.iter().enumerate().rev() {
Expand Down

0 comments on commit bae9f01

Please sign in to comment.