From 995c7a9b8290e4f7ca9656c13e146856c8bbe071 Mon Sep 17 00:00:00 2001 From: Aman Karmani Date: Wed, 16 Oct 2024 19:05:17 -0700 Subject: [PATCH] use updated regexes from upstream (#19) via https://github.com/openai/tiktoken/pull/258/files --- src/openai_public.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/openai_public.rs b/src/openai_public.rs index 4f3f389..f1d5bc3 100644 --- a/src/openai_public.rs +++ b/src/openai_public.rs @@ -26,6 +26,11 @@ const IM_SEP: &str = "<|im_sep|>"; #[derive(Clone, Debug, Copy)] pub struct EncodingFactory {} impl EncodingFactory { + // The pattern in the original GPT-2 release is: + // r"""'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" + // This is equivalent, but executes faster: + const LEGACY_SPLITTER_REGEX: &str = r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s"; + pub fn gpt2() -> Result { // todo! // vocab_bpe_file: sha256 = 1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5 @@ -45,7 +50,7 @@ impl EncodingFactory { special_tokens.shrink_to_fit(); Encoding::new( "r50k_base", - r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+", + EncodingFactory::LEGACY_SPLITTER_REGEX, mergeable_ranks, special_tokens, Some(50257), @@ -64,7 +69,7 @@ impl EncodingFactory { special_tokens.shrink_to_fit(); Encoding::new( "p50k_base", - r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+", + EncodingFactory::LEGACY_SPLITTER_REGEX, mergeable_ranks, special_tokens, Some(50281), @@ -107,9 +112,12 @@ impl EncodingFactory { .map_err(|_| EncodingFactoryError::FailedToLoadEncoding)?; let mut special_tokens: HashMap = special_tokens.iter().cloned().collect(); special_tokens.shrink_to_fit(); + // use faster version from tiktoken upstream https://github.com/openai/tiktoken/pull/258/files#r1487668172 + // const PATTERN: &str = 
r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"; + const PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s"; Encoding::new( "cl100k_base", - r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+", + PATTERN, mergeable_ranks, special_tokens, None,