switch to regex::Regex #22

Merged · 1 commit · Nov 9, 2024
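This PR switches the word-splitting pattern in src/corebpe.rs from fancy_regex::Regex to regex::Regex, keeping fancy_regex only for the special-token pattern. The regex crate does not support look-around, so alternatives such as \s+(?!\S) are dropped from the split patterns in src/openai_public.rs, and the encoder now stitches the whitespace "gaps" left between matches back into pieces itself. A minimal sketch of the API difference that drives the mechanical changes below (not part of the diff; the helper name is made up):

fn split_pieces(pattern: &str, text: &str) -> Vec<String> {
    // regex::Regex::find_iter yields Match values directly, so there is no
    // `.unwrap()` per match; fancy_regex::Regex::find_iter yields Result<Match, Error>.
    let re = regex::Regex::new(pattern).expect("pattern must not use look-around");
    re.find_iter(text).map(|m| m.as_str().to_owned()).collect()
}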
112 changes: 86 additions & 26 deletions src/corebpe.rs
@@ -1,4 +1,5 @@
use fancy_regex::Regex;
use fancy_regex::Regex as FancyRegex;
use regex::Regex;
use rustc_hash::FxHashMap as HashMap;
use rustc_hash::FxHashSet as HashSet;
use thread_local::ThreadLocal;
@@ -133,9 +134,9 @@ pub struct CoreBPE {
decoder: HashMap<Rank, &'static [u8]>,
special_tokens_decoder: HashMap<Rank, Vec<u8>>,
regex: Regex,
special_regex: Regex,
special_regex: FancyRegex,
regex_tls: ThreadLocal<Regex>,
special_regex_tls: ThreadLocal<Regex>,
special_regex_tls: ThreadLocal<FancyRegex>,
sorted_token_bytes: Vec<&'static [u8]>,
}

@@ -144,7 +145,7 @@ impl CoreBPE {
self.regex_tls.get_or(|| self.regex.clone())
}

fn _get_tl_special_regex(&self) -> &Regex {
fn _get_tl_special_regex(&self) -> &FancyRegex {
self.special_regex_tls.get_or(|| self.special_regex.clone())
}

@@ -161,24 +162,85 @@ impl CoreBPE {
ret
}

fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
fn _encode_ordinary_native_impl(&self, text: &str, ret: &mut Vec<Rank>) -> usize {
// This is the core of the encoding logic; the other functions in here
// just make things complicated :-)
let regex = self._get_tl_regex();
let mut ret = vec![];
let mut last_end = 0;
let mut last_piece_token_len = 0;
let mut piece: &[u8] = &[];
for mat in regex.find_iter(text) {
let piece = mat.unwrap().as_str().as_bytes();
piece = mat.as_str().as_bytes();
let start = mat.start();
let end = mat.end();

// If there is a whitespace gap between this piece and the previous piece, add its tokens
if last_end < start {
// If the current piece starts with whitespace, the whole gap is one new piece
if mat
.as_str()
.chars()
.next()
.map_or(false, |c| c.is_whitespace())
{
let wpiece = text[last_end..start].as_bytes();
match self.encoder.get(wpiece) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(wpiece, &self.encoder)),
}
// otherwise the last char of the gap makes a piece, and the rest (if any) makes another piece
} else {
let last_char_size = &text[last_end..start]
.chars()
.next_back()
.unwrap()
.len_utf8();
// Example for gpt-4o: for text "= 6", "=" and "6" are matches, " " is the gap,
// so the gap makes just one piece
if last_char_size < &(start - last_end) {
let wpiece1 = text[last_end..start - last_char_size].as_bytes();
match self.encoder.get(wpiece1) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(wpiece1, &self.encoder)),
}
}
let wpiece2 = text[start - last_char_size..start].as_bytes();
match self.encoder.get(wpiece2) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(wpiece2, &self.encoder)),
}
}
}
last_end = end;

// Now add piece tokens
match self.encoder.get(piece) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
}
}
ret
// Whitespace gap at the end of the text
if last_end < text.len() {
piece = text[last_end..text.len()].as_bytes();
match self.encoder.get(piece) {
Some(token) => ret.push(*token),
None => ret.extend(&byte_pair_encode(piece, &self.encoder)),
}
}

if !piece.is_empty() {
last_piece_token_len = match self.encoder.get(piece) {
Some(token) => 1,
None => byte_pair_encode(piece, &self.encoder).len(),
};
};

last_piece_token_len
}

fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<Rank>, usize) {
let special_regex = self._get_tl_special_regex();
let regex = self._get_tl_regex();

let mut ret = vec![];

let mut start = 0;
@@ -201,17 +263,10 @@ impl CoreBPE {
}
let end = next_special.map_or(text.len(), |m| m.start());

// Okay, here we go, compare this logic to _encode_ordinary_native
for mat in regex.find_iter(&text[start..end]) {
let piece = mat.unwrap().as_str().as_bytes();
if let Some(token) = self.encoder.get(piece) {
last_piece_token_len = 1;
ret.push(*token);
continue;
}
let tokens = byte_pair_encode(piece, &self.encoder);
last_piece_token_len = tokens.len();
ret.extend(&tokens);
if end > start {
// The split regex is fetched inside _encode_ordinary_native_impl on each call
// rather than created here and passed in, but that seems harmless.
last_piece_token_len =
self._encode_ordinary_native_impl(&text[start..end], &mut ret);
}

match next_special {
@@ -271,6 +326,13 @@ impl CoreBPE {
(tokens, last_piece_token_len)
}

fn _encode_ordinary_native(&self, text: &str) -> Vec<Rank> {
// This wrapper is for callers that do not pass in their own `ret` buffer.
let mut ret = vec![];
self._encode_ordinary_native_impl(text, &mut ret);
ret
}

fn _encode_unstable_native(
&self,
text: &str,
@@ -302,7 +364,7 @@ impl CoreBPE {
// Separating this from the loop below helps with performance in a common case.
let mut point = self
.sorted_token_bytes
.partition_point(|x| *x < unstable_bytes.as_slice());
.partition_point(|x| &x[..] < unstable_bytes.as_slice());
while point < self.sorted_token_bytes.len()
&& self.sorted_token_bytes[point].starts_with(&unstable_bytes)
{
@@ -318,9 +380,7 @@ impl CoreBPE {
for i in 1..unstable_bytes.len() {
let prefix = &unstable_bytes[..i];
let suffix = &unstable_bytes[i..];
let mut point = self
.sorted_token_bytes
.partition_point(|x| *x < suffix);
let mut point = self.sorted_token_bytes.partition_point(|x| &x[..] < suffix);
// TODO: Perf optimisation if suffix starts with " "?
while point < self.sorted_token_bytes.len()
&& self.sorted_token_bytes[point].starts_with(suffix)
@@ -393,15 +453,15 @@ impl CoreBPE {
encoder: HashMap<Vec<u8>, Rank>,
special_tokens_encoder: HashMap<String, Rank>,
pattern: &str,
) -> Result<Self, fancy_regex::Error> {
) -> Result<Self, regex::Error> {
let regex = Regex::new(pattern)?;

let special_regex = {
let parts = special_tokens_encoder
.keys()
.map(|s| fancy_regex::escape(s))
.collect::<Vec<_>>();
Regex::new(&parts.join("|"))?
FancyRegex::new(&parts.join("|")).unwrap()
};

// Use unsafe to extend the lifetime of references to the encoder's keys
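To make the gap handling in _encode_ordinary_native_impl concrete, here is a small standalone sketch (not from the PR; the pattern is an assumed subset of the real o200k split pattern) showing how plain whitespace between matches is left as a gap that the encoder must encode separately:

fn show_gaps() {
    // Assumed subset of the split pattern: punctuation with an optional leading
    // space, or runs of 1-3 digits (digits never take a leading space).
    let re = regex::Regex::new(r" ?[^\s\p{L}\p{N}]+|\p{N}{1,3}").unwrap();
    let text = "= 6";
    let mut last_end = 0;
    for m in re.find_iter(text) {
        if m.start() > last_end {
            // For "= 6" this prints the single-space gap between "=" and "6".
            println!("gap:   {:?}", &text[last_end..m.start()]);
        }
        println!("piece: {:?}", m.as_str());
        last_end = m.end();
    }
}

In the PR itself those gaps are looked up in the encoder (or byte-pair encoded) rather than printed, and splitting a multi-character gap into "everything but the last whitespace char" plus "the last whitespace char" appears to mirror what the removed \s+(?!\S) alternative used to produce.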
13 changes: 0 additions & 13 deletions src/encoding.rs
@@ -16,8 +16,6 @@ include!(concat!(env!("OUT_DIR"), "/odht_gen.rs"));
pub struct Encoding {
/// The name of the encoding.
pub name: String,
/// The regular expression pattern used to split text into pieces.
pat_str: String,
/// The maximum length of the keys in `mergeable_ranks`.
mergeable_ranks_max_key_len: usize,
/// All prefixes of the mergeable ranks. May or may not be tokens themselves!
@@ -117,7 +115,6 @@ impl Encoding {

Ok(Self {
name: name.to_string(),
pat_str: pat_str.to_string(),
mergeable_ranks_max_key_len,
prefixes_of_mergeable_ranks,
special_tokens,
@@ -468,16 +465,6 @@ impl Encoding {
self.core_bpe.encode_single_piece(text_or_bytes)
}

/// Encodes a string into tokens, but do regex splitting in Rust.
fn _encode_only_native_bpe(&self, text: &str) -> Vec<Rank> {
let re = Regex::new(&self.pat_str).unwrap();
let mut ret = Vec::new();
for piece in re.find_iter(text) {
ret.extend(self.core_bpe.encode_single_piece(piece.as_str().as_bytes()));
}
ret
}

/// Encodes bytes into tokens.
fn _encode_bytes(&self, text: &[u8]) -> Vec<Rank> {
self.core_bpe._encode_bytes(text)
8 changes: 3 additions & 5 deletions src/openai_public.rs
@@ -29,7 +29,7 @@ impl EncodingFactory {
// The pattern in the original GPT-2 release is:
// r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
// This is equivalent, but executes faster:
const LEGACY_SPLITTER_REGEX: &str = r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s";
const LEGACY_SPLITTER_REGEX: &str = r"'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+";

pub fn gpt2() -> Result<Encoding, EncodingFactoryError> {
// todo!
@@ -114,7 +114,7 @@ impl EncodingFactory {
special_tokens.shrink_to_fit();
// use faster version from tiktoken upstream https://github.com/openai/tiktoken/pull/258/files#r1487668172
// const PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
const PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s";
const PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]";
Encoding::new(
"cl100k_base",
PATTERN,
@@ -142,8 +142,6 @@ impl EncodingFactory {
r"\p{N}{1,3}",
r" ?[^\s\p{L}\p{N}]+[\r\n/]*",
r"\s*[\r\n]+",
r"\s+(?!\S)",
r"\s+",
].join("|");

Encoding::new("o200k_base", pat_str, mergeable_ranks, special_tokens, None)
@@ -204,7 +202,7 @@ impl EncodingFactory {
special_tokens.into_iter().enumerate().map(|(i, token)| (token, (num_base_tokens + i) as Rank)).collect();
special_tokens_map.shrink_to_fit();

let pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
let pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+";

let vocab_size = num_base_tokens + special_tokens_map.len();
Encoding::new(
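As a quick check of why the patterns in src/openai_public.rs were simplified (this snippet is not in the PR): the regex crate rejects look-around, and it also does not support the possessive quantifiers (++, ?+) used in the old faster patterns, so the fancy_regex-only alternatives cannot compile with regex::Regex while the new look-around-free patterns can.

fn patterns_compile() {
    // Look-ahead is unsupported by the regex crate, so the legacy alternative fails to parse.
    assert!(regex::Regex::new(r"\s+(?!\S)|\s+").is_err());
    // The simplified legacy splitter from this PR compiles fine.
    assert!(regex::Regex::new(r"'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+").is_ok());
}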