switch to regex::Regex (#22)
tmm1 authored Nov 9, 2024
1 parent a52c83f commit 8ed2e82
Showing 3 changed files with 89 additions and 44 deletions.
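
The substance of this change: the main splitter pattern now compiles with the plain regex crate, which supports neither look-around nor possessive quantifiers, so the whitespace branches built on \s+(?!\S) are dropped from the patterns and the whitespace runs they used to capture are re-emitted by new gap-handling code in _encode_ordinary_native_impl. The special-token matcher stays on fancy_regex. Below is a minimal sketch (not part of this commit) of the compile-time difference between the two engines:

// Sketch, assuming both crates are available: the regex crate rejects the
// look-ahead used by the old splitter patterns, while fancy_regex accepts it.
#[test]
fn lookaround_support() {
    let lookahead = r"\s+(?!\S)";
    assert!(regex::Regex::new(lookahead).is_err()); // look-around is unsupported
    assert!(fancy_regex::Regex::new(lookahead).is_ok());
}
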
112 changes: 86 additions & 26 deletions src/corebpe.rs
@@ -1,4 +1,5 @@
use fancy_regex::Regex;
use fancy_regex::Regex as FancyRegex;
use regex::Regex;
use rustc_hash::FxHashMap as HashMap;
use rustc_hash::FxHashSet as HashSet;
use thread_local::ThreadLocal;
@@ -133,9 +134,9 @@ pub struct CoreBPE {
decoder: HashMap<Rank, &'static [u8]>,
special_tokens_decoder: HashMap<Rank, Vec<u8>>,
regex: Regex,
special_regex: Regex,
special_regex: FancyRegex,
regex_tls: ThreadLocal<Regex>,
special_regex_tls: ThreadLocal<Regex>,
special_regex_tls: ThreadLocal<FancyRegex>,
sorted_token_bytes: Vec<&'static [u8]>,
}

@@ -144,7 +145,7 @@ impl CoreBPE {
self.regex_tls.get_or(|| self.regex.clone())
}

fn _get_tl_special_regex(&self) -> &Regex {
fn _get_tl_special_regex(&self) -> &FancyRegex {
self.special_regex_tls.get_or(|| self.special_regex.clone())
}

@@ -161,24 +162,85 @@ impl CoreBPE {
ret
}

fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
fn _encode_ordinary_native_impl(&self, text: &str, ret: &mut Vec<Rank>) -> usize {
// This is the core of the encoding logic; the other functions in here
// just make things complicated :-)
let regex = self._get_tl_regex();
let mut ret = vec![];
let mut last_end = 0;
let mut last_piece_token_len = 0;
let mut piece: &[u8] = &[];
for mat in regex.find_iter(text) {
let piece = mat.unwrap().as_str().as_bytes();
piece = mat.as_str().as_bytes();
let start = mat.start();
let end = mat.end();

// If there is a whitespace gap between this piece and the previous piece, add its tokens
if last_end < start {
// If current piece starts with a whitespace, the whole gap is one new piece
if mat
.as_str()
.chars()
.next()
.map_or(false, |c| c.is_whitespace())
{
let wpiece = text[last_end..start].as_bytes();
match self.encoder.get(wpiece) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(wpiece, &self.encoder)),
}
// otherwise the last char of gap makes a piece, and the rest (if any) makes another piece
} else {
let last_char_size = &text[last_end..start]
.chars()
.next_back()
.unwrap()
.len_utf8();
// Example for gpt-4o: for the text "= 6", "=" and "6" are matches and " " is the gap,
// so the gap makes just one piece
if last_char_size < &(start - last_end) {
let wpiece1 = text[last_end..start - last_char_size].as_bytes();
match self.encoder.get(wpiece1) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(wpiece1, &self.encoder)),
}
}
let wpiece2 = text[start - last_char_size..start].as_bytes();
match self.encoder.get(wpiece2) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(wpiece2, &self.encoder)),
}
}
}
last_end = end;

// Now add piece tokens
match self.encoder.get(piece) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
}
}
ret
// Handle a whitespace gap at the end of the text
if last_end < text.len() {
piece = text[last_end..text.len()].as_bytes();
match self.encoder.get(piece) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
}
}

if !piece.is_empty() {
last_piece_token_len = match self.encoder.get(piece) {
Some(token) => 1,
None => byte_pair_encode(piece, &self.encoder).len(),
};
};

last_piece_token_len
}

fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
let special_regex = self._get_tl_special_regex();
let regex = self._get_tl_regex();

let mut ret = vec![];

let mut start = 0;
@@ -201,17 +263,10 @@ impl CoreBPE {
}
let end = next_special.map_or(text.len(), |m| m.start());

// Okay, here we go, compare this logic to _encode_ordinary_native
for mat in regex.find_iter(&text[start..end]) {
let piece = mat.unwrap().as_str().as_bytes();
if let Some(token) = self.encoder.get(piece) {
last_piece_token_len = 1;
ret.push(*token);
continue;
}
let tokens = byte_pair_encode(piece, &self.encoder);
last_piece_token_len = tokens.len();
ret.extend(&tokens);
if end > start {
// The regex is not fetched here and passed in; _encode_ordinary_native_impl looks it up itself, which seems harmless.
last_piece_token_len =
self._encode_ordinary_native_impl(&text[start..end], &mut ret);
}

match next_special {
@@ -271,6 +326,13 @@ impl CoreBPE {
(tokens, last_piece_token_len)
}

fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
// This wrapper is needed for callers that do not pass their own ret buffer.
let mut ret = vec![];
self._encode_ordinary_native_impl(text, &mut ret);
ret
}

fn _encode_unstable_native(
&self,
text: &str,
@@ -302,7 +364,7 @@ impl CoreBPE {
// Separating this from the loop below helps with performance in a common case.
let mut point = self
.sorted_token_bytes
.partition_point(|x| *x < unstable_bytes.as_slice());
.partition_point(|x| &x[..] < unstable_bytes.as_slice());
while point < self.sorted_token_bytes.len()
&& self.sorted_token_bytes[point].starts_with(&unstable_bytes)
{
@@ -318,9 +380,7 @@ impl CoreBPE {
for i in 1..unstable_bytes.len() {
let prefix = &unstable_bytes[..i];
let suffix = &unstable_bytes[i..];
let mut point = self
.sorted_token_bytes
.partition_point(|x| *x < suffix);
let mut point = self.sorted_token_bytes.partition_point(|x| &x[..] < suffix);
// TODO: Perf optimisation if suffix starts with " "?
while point < self.sorted_token_bytes.len()
&& self.sorted_token_bytes[point].starts_with(suffix)
@@ -393,15 +453,15 @@ impl CoreBPE {
encoder: HashMap<Vec<u8>, Rank>,
special_tokens_encoder: HashMap<String, Rank>,
pattern: &str,
) -> Result<Self, fancy_regex::Error> {
) -> Result<Self, regex::Error> {
let regex = Regex::new(pattern)?;

let special_regex = {
let parts = special_tokens_encoder
.keys()
.map(|s| fancy_regex::escape(s))
.collect::<Vec<_>>();
Regex::new(&parts.join("|"))?
FancyRegex::new(&parts.join("|")).unwrap()
};

// Use unsafe to extend the lifetime of references to the encoder's keys
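
The gap-handling branch added above exists because, without the \s+(?!\S) and \s+ branches, whitespace between matches is no longer consumed by the splitter itself. A rough illustration of the "= 6" case from the comments, using a simplified o200k-style pattern (an assumption for illustration, not code from this commit): the number branch has no optional leading space, so the space ends up in a gap that the new code re-encodes as its own piece.

// Sketch: with a simplified o200k-style pattern (apostrophe branch omitted),
// "=" and "6" are the only matches for "= 6"; the " " between them is a gap
// that _encode_ordinary_native_impl turns back into its own piece.
#[test]
fn whitespace_gap_example() {
    let pat = regex::Regex::new(
        r"[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+",
    )
    .unwrap();
    let pieces: Vec<&str> = pat.find_iter("= 6").map(|m| m.as_str()).collect();
    assert_eq!(pieces, vec!["=", "6"]); // the space is unmatched and becomes a gap
}
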
13 changes: 0 additions & 13 deletions src/encoding.rs
@@ -16,8 +16,6 @@ include!(concat!(env!("OUT_DIR"), "/odht_gen.rs"));
pub struct Encoding {
/// The name of the encoding.
pub name: String,
/// The regular expression pattern used to split text into pieces.
pat_str: String,
/// The maximum length of the keys in `mergeable_ranks`.
mergeable_ranks_max_key_len: usize,
/// All prefixes of the mergeable ranks. May or may not be tokens themselves!
@@ -117,7 +115,6 @@ impl Encoding {

Ok(Self {
name: name.to_string(),
pat_str: pat_str.to_string(),
mergeable_ranks_max_key_len,
prefixes_of_mergeable_ranks,
special_tokens,
@@ -468,16 +465,6 @@ impl Encoding {
self.core_bpe.encode_single_piece(text_or_bytes)
}

/// Encodes a string into tokens, but do regex splitting in Rust.
fn _encode_only_native_bpe(&self, text: &str) -> Vec<Rank> {
let re = Regex::new(&self.pat_str).unwrap();
let mut ret = Vec::new();
for piece in re.find_iter(text) {
ret.extend(self.core_bpe.encode_single_piece(piece.as_str().as_bytes()));
}
ret
}

/// Encodes bytes into tokens.
fn _encode_bytes(&self, text: &[u8]) -> Vec<Rank> {
self.core_bpe._encode_bytes(text)
8 changes: 3 additions & 5 deletions src/openai_public.rs
@@ -29,7 +29,7 @@ impl EncodingFactory {
// The pattern in the original GPT-2 release is:
// r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
// This is equivalent, but executes faster:
const LEGACY_SPLITTER_REGEX: &str = r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s";
const LEGACY_SPLITTER_REGEX: &str = r"'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+";

pub fn gpt2() -> Result<Encoding, EncodingFactoryError> {
// todo!
@@ -114,7 +114,7 @@ impl EncodingFactory {
special_tokens.shrink_to_fit();
// use faster version from tiktoken upstream https://github.com/openai/tiktoken/pull/258/files#r1487668172
// const PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
const PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s";
const PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]";
Encoding::new(
"cl100k_base",
PATTERN,
@@ -142,8 +142,6 @@ impl EncodingFactory {
r"\p{N}{1,3}",
r" ?[^\s\p{L}\p{N}]+[\r\n/]*",
r"\s*[\r\n]+",
r"\s+(?!\S)",
r"\s+",
].join("|");

Encoding::new("o200k_base", pat_str, mergeable_ranks, special_tokens, None)
@@ -204,7 +202,7 @@ impl EncodingFactory {
special_tokens.into_iter().enumerate().map(|(i, token)| (token, (num_base_tokens + i) as Rank)).collect();
special_tokens_map.shrink_to_fit();

let pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
let pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+";

let vocab_size = num_base_tokens + special_tokens_map.len();
Encoding::new(
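
The pattern edits in this file follow from the engine switch: the regex crate supports neither possessive quantifiers (++, ?+, *+) nor the \s+(?!\S) look-ahead used previously, so those constructs are removed and trailing-whitespace handling moves into the gap logic in corebpe.rs. A hedged sanity check (not part of the commit) over the cl100k patterns shown above:

// Sketch: the old cl100k pattern only compiles under fancy_regex; the
// rewritten one also compiles under the plain regex crate.
#[test]
fn cl100k_pattern_check() {
    let old = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s";
    let new = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]";
    assert!(regex::Regex::new(old).is_err()); // look-ahead and possessive quantifiers are rejected
    assert!(regex::Regex::new(new).is_ok());
    assert!(fancy_regex::Regex::new(old).is_ok()); // the previous engine accepted it
}
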
