Skip to content

Commit

Permalink
use updated regexes from upstream (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmm1 authored Oct 17, 2024
1 parent a4b3165 commit 995c7a9
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/openai_public.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ const IM_SEP: &str = "<|im_sep|>";
#[derive(Clone, Debug, Copy)]
pub struct EncodingFactory {}
impl EncodingFactory {
// Split pattern passed to `Encoding::new` for the legacy `r50k_base` and `p50k_base` encodings.
//
// The pattern in the original GPT-2 release is:
// r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
// This is equivalent, but executes faster (note the possessive `++` quantifiers
// and the merged contraction alternatives):
const LEGACY_SPLITTER_REGEX: &str = r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s";

pub fn gpt2() -> Result<Encoding, EncodingFactoryError> {
// todo!
// vocab_bpe_file: sha256 = 1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5
Expand All @@ -45,7 +50,7 @@ impl EncodingFactory {
special_tokens.shrink_to_fit();
Encoding::new(
"r50k_base",
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+",
EncodingFactory::LEGACY_SPLITTER_REGEX,
mergeable_ranks,
special_tokens,
Some(50257),
Expand All @@ -64,7 +69,7 @@ impl EncodingFactory {
special_tokens.shrink_to_fit();
Encoding::new(
"p50k_base",
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+",
EncodingFactory::LEGACY_SPLITTER_REGEX,
mergeable_ranks,
special_tokens,
Some(50281),
Expand Down Expand Up @@ -107,9 +112,12 @@ impl EncodingFactory {
.map_err(|_| EncodingFactoryError::FailedToLoadEncoding)?;
let mut special_tokens: HashMap<String, Rank> = special_tokens.iter().cloned().collect();
special_tokens.shrink_to_fit();
// Use the faster pattern from upstream tiktoken; see
// https://github.com/openai/tiktoken/pull/258/files#r1487668172
// The previous pattern is kept below, commented out, for reference:
// const PATTERN: &str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
const PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s";
Encoding::new(
"cl100k_base",
r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
PATTERN,
mergeable_ranks,
special_tokens,
None,
Expand Down

0 comments on commit 995c7a9

Please sign in to comment.