diff --git a/.gitignore b/.gitignore
index aeb0f0eb..9dbb7b2a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,7 +11,7 @@ benchmarks/results
 # Remove doc build folders
 .cache/
 build/
-
+rust-coverage/
 target/
 *.so
 *.pyd
diff --git a/Cargo.toml b/Cargo.toml
index 0e83d020..94eab3a0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,12 +7,17 @@ license = "Apache-2.0"
 repository = "https://github.com/dottxt-ai/outlines-core"
 
 [dependencies]
+once_cell = "1.20"
 anyhow = "1.0.86"
-thiserror = "1.0"
+thiserror = "2.0"
 pyo3 = { version = "0.22.0", features = ["extension-module"], optional = true }
 regex = "1.10.6"
 serde-pyobject = "0.4.0"
-serde_json = { version = "1.0.125", features = ["preserve_order"] }
+serde_json = { version = "1.0", features = ["preserve_order"] }
+serde = {version = "1.0", features = ["derive"]}
+# Fragile dependencies, minor updates often break the code
+hf-hub = "=0.3.2"
+tokenizers = { version = "=0.20.3", features = ["http"] }
 
 [features]
 python-bindings = ["pyo3"]
@@ -31,3 +36,6 @@ panic = 'abort'
 [package.metadata.scripts]
 build-python-extension = "python setup.py build_rust --inplace --debug"
 build-python-extension-release = "python setup.py build_rust --inplace --release"
+
+[lints.rust]
+unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tarpaulin_include)'] }
diff --git a/Makefile b/Makefile
index f3293068..6cd0637d 100644
--- a/Makefile
+++ b/Makefile
@@ -1,8 +1,9 @@
 # Optional target to test/benchmark.
 TARGET ?=
+TARPAULIN_INSTALLED := $(shell command -v cargo-tarpaulin > /dev/null && echo 1 || echo 0)
 
 .ONESHELL:
-.PHONY: venv setup install install-release build-extension-debug build-extension-release watch-extension watch-extension-release pcc test test-rust test-python bench pybench doc dist clean check-clean-git
+.PHONY: venv setup install install-release build-extension-debug build-extension-release watch-extension watch-extension-release pcc test test-rust test-python bench pybench doc dist clean check-clean-git check-tarpaulin test-rust-cov
 .SILENT:
 
 # Create a fresh virtual environment with the latest pip.
@@ -59,6 +60,26 @@ test-python: build-extension-debug
 	--cov=outlines_core \
 	--cov-report=term-missing:skip-covered
 
+# Check if tarpaulin needs to be installed first.
+check-tarpaulin:
+ifeq ($(TARPAULIN_INSTALLED), 0)
+	@echo "cargo-tarpaulin is not found, installing..."
+	cargo install cargo-tarpaulin
+else
+	@echo "cargo-tarpaulin is already installed"
+endif
+
+# Run rust tests with coverage report.
+test-rust-cov: check-tarpaulin
+	RUSTFLAGS="-C instrument-coverage" cargo tarpaulin \
+	--out=Lcov \
+	--output-dir=rust-coverage \
+	--engine=llvm \
+	--exclude-files=src/python_bindings/* \
+	--no-dead-code \
+	--workspace \
+	--verbose
+
 # Run rust benchmarks.
bench:
ifeq ($(TARGET),)
diff --git a/pyproject.toml b/pyproject.toml
index 3d1d0d47..57090988 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,6 +80,7 @@ filterwarnings = [
     "error",
     "ignore::pydantic.warnings.PydanticDeprecatedSince20",
     "ignore::UserWarning",
+    "ignore::DeprecationWarning",
 ]
 addopts = [
     "--import-mode=importlib"
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 00000000..f589731c
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,41 @@
+use thiserror::Error;
+
+pub type Result<T, E = Error> = std::result::Result<T, E>;
+
+#[derive(Error, Debug)]
+#[error("{0}")]
+pub struct TokenizersError(pub tokenizers::Error);
+
+impl PartialEq for TokenizersError {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.to_string() == other.0.to_string()
+    }
+}
+
+#[derive(Error, Debug, PartialEq)]
+pub enum Error {
+    #[error("The vocabulary does not allow us to build a sequence that matches the input")]
+    IndexError,
+    #[error(transparent)]
+    TokenizersError(#[from] TokenizersError),
+    #[error("Unsupported tokenizer for {model}: {reason}, please open an issue with the full error message: https://github.com/dottxt-ai/outlines-core/issues")]
+    UnsupportedTokenizer { model: String, reason: String },
+    #[error("Unable to locate EOS token for {model}")]
+    UnableToLocateEosTokenId { model: String },
+    #[error("Tokenizer is not supported by token processor")]
+    UnsupportedByTokenProcessor,
+    #[error("Decoder unpacking failed for token processor")]
+    DecoderUnpackingFailed,
+    #[error("Token processing failed for byte level processor")]
+    ByteProcessorFailed,
+    #[error("Token processing failed for byte fallback level processor")]
+    ByteFallbackProcessorFailed,
+}
+
+#[cfg(feature = "python-bindings")]
+impl From<Error> for pyo3::PyErr {
+    fn from(e: Error) -> Self {
+        use pyo3::{exceptions::PyValueError, PyErr};
+        PyErr::new::<PyValueError, _>(e.to_string())
+    }
+}
diff --git a/src/index.rs b/src/index.rs
index 587cd76a..cc1187e8 100644
--- a/src/index.rs
+++ b/src/index.rs
@@ -2,10 +2,9 @@ use crate::prelude::{State, TransitionKey};
 use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
 use crate::vocabulary::Vocabulary;
+use crate::{Error, Result};
 use std::collections::{HashMap, HashSet};
 
-pub type Result<T, E = crate::Error> = std::result::Result<T, E>;
-
 #[derive(Debug)]
 pub struct FSMInfo {
     pub(crate) initial: State,
@@ -101,7 +100,7 @@ impl Index {
             eos_token_id,
         })
     } else {
-        Err(crate::Error::IndexError)
+        Err(Error::IndexError)
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 71787e2e..6155b718 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+pub mod error;
 pub mod index;
 pub mod json_schema;
 pub mod prelude;
@@ -5,21 +6,7 @@ pub mod primitives;
 pub mod regex;
 pub mod vocabulary;
 
-#[cfg(feature = "python-bindings")]
-mod python_bindings;
-
-use thiserror::Error;
-
-#[derive(Error, Debug)]
-pub enum Error {
-    #[error("The vocabulary does not allow us to build a sequence that matches the input")]
-    IndexError,
-}
+pub use error::{Error, Result};
 
 #[cfg(feature = "python-bindings")]
-impl From<Error> for pyo3::PyErr {
-    fn from(e: Error) -> Self {
-        use pyo3::{exceptions::PyValueError, PyErr};
-        PyErr::new::<PyValueError, _>(e.to_string())
-    }
-}
+mod python_bindings;
diff --git a/src/prelude.rs b/src/prelude.rs
index e196e474..d42516b9 100644
--- a/src/prelude.rs
+++ b/src/prelude.rs
@@ -2,9 +2,3 @@ pub use super::{
     primitives::{State, Token, TokenId, TransitionKey},
     vocabulary::Vocabulary,
 };
-
-pub(crate) use std::{
-    collections::{HashMap, HashSet},
-    fmt::{self, Display},
-    ops::Deref,
-};
diff --git a/src/regex.rs b/src/regex.rs
index
a41bf862..b5658191 100644 --- a/src/regex.rs +++ b/src/regex.rs @@ -1,4 +1,5 @@ use crate::prelude::*; +use std::collections::{HashMap, HashSet}; pub fn walk_fsm( fsm_transitions: &HashMap<(State, TransitionKey), State>, diff --git a/src/vocabulary.rs b/src/vocabulary.rs deleted file mode 100644 index f03df8f7..00000000 --- a/src/vocabulary.rs +++ /dev/null @@ -1,133 +0,0 @@ -use crate::prelude::*; - -/// Vocabulary of an LLM. -/// -/// ## Examples -/// -/// ```rust -/// # use outlines_core::prelude::*; -/// # -/// let vocabulary = Vocabulary::new() -/// .insert("blah", 0) -/// .insert("1a", 1) -/// .insert("2", 2) -/// .insert("0", 3); -/// ``` -#[derive(Clone, Debug, Default)] -pub struct Vocabulary(pub(crate) HashMap>); - -impl Vocabulary { - /// Creates an empty vocabulary. - pub fn new() -> Vocabulary { - Vocabulary::default() - } -} - -impl Vocabulary { - /// Inserts a token to the vocabulary with the specified identifier. - pub fn insert(mut self, token: impl Into, id: TokenId) -> Vocabulary { - self.insert_in_place(token, id); - self - } - - /// Extends the vocabulary with tokens and their identifiers. - pub fn extend, I: IntoIterator>( - mut self, - tokens_and_ids: impl IntoIterator, - ) -> Vocabulary { - self.extend_in_place(tokens_and_ids); - self - } -} - -impl Vocabulary { - /// Inserts a token to the vocabulary with the specified identifier, in place. - pub fn insert_in_place(&mut self, token: impl Into, id: TokenId) { - let token = token.into(); - self.0.entry(token).or_default().push(id); - } - - /// Extends the vocabulary with tokens and their identifiers, in place. - pub fn extend_in_place, I: IntoIterator>( - &mut self, - tokens_and_ids: impl IntoIterator, - ) { - for (token, ids) in tokens_and_ids.into_iter() { - let token = token.into(); - self.0.entry(token).or_default().extend(ids); - } - } -} - -impl Deref for Vocabulary { - type Target = HashMap>; - - fn deref(&self) -> &HashMap> { - &self.0 - } -} - -impl Display for Vocabulary { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - for (index, (token, token_ids)) in self.iter().enumerate() { - if index != (self.len() - 1) { - writeln!(f, "{:?} -> {:?}", token, token_ids)?; - } else { - write!(f, "{:?} -> {:?}", token, token_ids)?; - } - } - Ok(()) - } -} - -impl From>> for Vocabulary { - fn from(map: HashMap>) -> Vocabulary { - Vocabulary(map) - } -} - -impl FromIterator<(T, I)> for Vocabulary -where - T: Into, - I: IntoIterator, -{ - fn from_iter>(tokens_and_ids: A) -> Self { - Vocabulary::new().extend(tokens_and_ids) - } -} - -#[cfg(test)] -mod tests { - use crate::prelude::*; - - #[test] - fn insert() { - let vocabulary = Vocabulary::new() - .insert("blah", 0) - .insert("1a", 1) - .insert("2", 2) - .insert("0", 3); - - assert_eq!(vocabulary.len(), 4); - assert_eq!(vocabulary["blah"], &[0]); - assert_eq!(vocabulary["1a"], &[1]); - assert_eq!(vocabulary["2"], &[2]); - assert_eq!(vocabulary["0"], &[3]); - } - - #[test] - fn extend() { - let vocabulary = Vocabulary::new().extend([ - ("blah", vec![0]), - ("1a", vec![1]), - ("2", vec![2]), - ("0", vec![3]), - ]); - - assert_eq!(vocabulary.len(), 4); - assert_eq!(vocabulary["blah"], &[0]); - assert_eq!(vocabulary["1a"], &[1]); - assert_eq!(vocabulary["2"], &[2]); - assert_eq!(vocabulary["0"], &[3]); - } -} diff --git a/src/vocabulary/locator.rs b/src/vocabulary/locator.rs new file mode 100644 index 00000000..d3f8bcfc --- /dev/null +++ b/src/vocabulary/locator.rs @@ -0,0 +1,242 @@ +use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; +use serde::{Deserialize, 
Serialize}; +use tokenizers::{FromPretrainedParameters, Tokenizer}; + +use crate::primitives::*; + +/// Mapping of characters to bytes for GPT-2 like tokenizers. +/// List of common eos token locations appearing on hugging face hub, ordered by priority. +const COMMON_LOCATIONS: &[EosTokenLocation] = &[ + // Most projects have `generation_config.json` that looks like: + // { + // ... + // "eos_token_id": 50256, + // ... + // } + // So it's the first place we look for the eos token id. + // + // For example: + // - https://huggingface.co/openai-community/gpt2/blob/main/generation_config.json + EosTokenLocation { + file: "generation_config.json", + location: EosTokenField::Id, + }, + // The ones that don't have `generation_config.json` usually have `tokenizer_config.json`: + // { + // ... + // "eos_token": "<|endoftext|>", + // ... + // } + // Once we have the eos token content, we can get its id from the tokenizer. + // + // For example: + // - https://huggingface.co/microsoft/phi-2/blob/main/tokenizer_config.json + EosTokenLocation { + file: "tokenizer_config.json", + location: EosTokenField::Value, + }, + // Sometimes `tokenizer_config.json` can have the following format as well: + // { + // "eos_token": { + // ... + // "content": "", + // ... + // }, + // } + // Once we have the eos token content, we can get its id from the tokenizer. + // + // For example: + // - https://huggingface.co/hf-internal-testing/llama-tokenizer/blob/main/tokenizer_config.json + EosTokenLocation { + file: "tokenizer_config.json", + location: EosTokenField::Object, + }, +]; + +/// `Id` kind of `EosTokenField`, when `eos_token_id` provided as an id. +#[derive(Debug, Serialize, Deserialize)] +struct Id { + eos_token_id: u64, +} + +/// `Value` kind of `EosTokenField`, when `eos_token` provided as a text, so that its id +/// will be fetched from the tokenizer. +#[derive(Debug, Serialize, Deserialize)] +struct Value { + eos_token: String, +} + +/// `Object` kind of `EosTokenField`, when `eos_token` provided as a `Content`. +#[derive(Debug, Serialize, Deserialize)] +struct Object { + eos_token: Content, +} + +/// `eos_token` provided in a `Content`. +#[derive(Debug, Serialize, Deserialize)] +struct Content { + content: String, +} + +/// Specifies in which part in config's json to check for eos token id. +enum EosTokenField { + Id, + Value, + Object, +} + +/// Defines location of the end of sentence token id in the config file. +struct EosTokenLocation { + file: &'static str, + location: EosTokenField, +} + +/// Locates eos token id. +pub(crate) trait Locator { + /// Locates eos token id in defined locations by `Locator`. + fn locate_eos_token_id( + model: &str, + tokenizer: &Tokenizer, + parameters: &Option, + ) -> Option; +} + +/// Locates eos token id by searching in defined common locations in hugging face. +pub(crate) struct HFLocator; + +impl Locator for HFLocator { + /// Locates eos token id in defined locations. + fn locate_eos_token_id( + model: &str, + tokenizer: &Tokenizer, + parameters: &Option, + ) -> Option { + COMMON_LOCATIONS + .iter() + .find_map(|location| location.lookup(model, tokenizer, parameters)) + } +} + +impl EosTokenLocation { + /// Finds eos token within defined location in a related config file. 
+ fn lookup( + &self, + model: &str, + tokenizer: &Tokenizer, + parameters: &Option, + ) -> Option { + let file_path = Self::download_config(model, self.file, parameters).ok()?; + let file = std::fs::File::open(file_path).ok()?; + + match self.location { + EosTokenField::Id => { + let config: Id = serde_json::from_reader(file).ok()?; + u32::try_from(config.eos_token_id).ok() + } + EosTokenField::Value => { + let config: Value = serde_json::from_reader(file).ok()?; + tokenizer.token_to_id(&config.eos_token) + } + EosTokenField::Object => { + let config: Object = serde_json::from_reader(file).ok()?; + tokenizer.token_to_id(&config.eos_token.content) + } + } + } + + /// Downloads related config file from Hugging Face Hub. + fn download_config( + project: &str, + file: &str, + parameters: &Option, + ) -> tokenizers::Result { + // Adapted from + // https://github.com/huggingface/tokenizers/blob/9b77c054ef4297c7057fa8db875368c7c02f1bfc/tokenizers/src/utils/from_pretrained.rs#L26 + + let params = parameters.clone().unwrap_or_default(); + + // Validation checks are coming as a literal adaptation logic from HF. + // In this case project is a model name, which if invalid expected to fail much earlier. + // So it seems a bit redundant to validate it this way, but no harm in doing so too. + Self::validate(project)?; + Self::validate(¶ms.revision)?; + + let repo = Repo::with_revision(project.to_string(), RepoType::Model, params.revision); + let api = ApiBuilder::new() + .with_token(params.token) + .build()? + .repo(repo); + + Ok(api.get(file)?) + } + + fn validate(input: &str) -> tokenizers::Result<()> { + let valid_chars = ['-', '_', '.', '/']; + + if !input + .chars() + .all(|c: char| c.is_alphanumeric() || valid_chars.contains(&c)) + { + return Err(format!( + "Input {input} contains invalid characters, expected only alphanumeric or {}", + valid_chars + .iter() + .map(|x| format!("'{}'", x)) + .collect::>() + .join(", ") + ) + .into()); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn common_locations() { + for (model, expected_token_id, expected_token) in &[ + ("openai-community/gpt2", 50256, "<|endoftext|>"), + ("microsoft/phi-2", 50256, "<|endoftext|>"), + ("hf-internal-testing/llama-tokenizer", 2, ""), + ] { + let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); + let located = HFLocator::locate_eos_token_id(model, &tokenizer, &None) + .expect("Token id is not located"); + + assert_eq!(located, *expected_token_id); + assert_eq!( + tokenizer.id_to_token(located).expect("Token is not found"), + expected_token.to_string() + ); + } + } + + #[test] + fn bad_location() { + let bad_location = EosTokenLocation { + file: "tokenizer_config.json", + location: EosTokenField::Id, + }; + let model = "microsoft/phi-2"; + let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); + + let token_id = bad_location.lookup(model, &tokenizer, &None); + assert!(token_id.is_none()); + + let bad_file = EosTokenLocation { + file: "generation_config.json", + location: EosTokenField::Value, + }; + let token_id = bad_file.lookup(model, &tokenizer, &None); + assert!(token_id.is_none()); + } + + #[test] + fn validate_config_input() { + let input = "bad_model_name*"; + assert!(EosTokenLocation::validate(input).is_err()); + } +} diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs new file mode 100644 index 00000000..719c9040 --- /dev/null +++ b/src/vocabulary/mod.rs @@ -0,0 +1,469 @@ +use std::collections::HashMap; + +use 
tokenizers::normalizers::Sequence;
+use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
+
+use crate::{error, prelude::*};
+use crate::{Error, Result};
+
+use locator::{HFLocator, Locator};
+use processor::TokenProcessor;
+
+mod locator;
+mod processor;
+
+/// Vocabulary of an LLM.
+///
+/// ## Examples
+///
+/// ```rust
+/// # use outlines_core::prelude::*;
+/// #
+/// let vocabulary = Vocabulary::new(None)
+///     .insert("blah", 0)
+///     .insert("1a", 1)
+///     .insert("2", 2)
+///     .insert("0", 3);
+/// ```
+#[derive(Clone, Debug, Default)]
+pub struct Vocabulary {
+    // TODO: Option is temp for back compatibility
+    eos_token_id: Option<TokenId>,
+    tokens: HashMap<Token, Vec<TokenId>>,
+}
+
+impl Vocabulary {
+    /// Creates an empty vocabulary.
+    pub fn new(eos_token_id: Option<TokenId>) -> Self {
+        Self {
+            eos_token_id,
+            tokens: HashMap::new(),
+        }
+    }
+
+    /// Creates the vocabulary of a pre-trained model from the Hugging Face Hub.
+    pub fn from_pretrained(
+        model: &str,
+        parameters: Option<FromPretrainedParameters>,
+    ) -> Result<Self> {
+        Self::from_pretrained_with_locator::<HFLocator>(model, parameters)
+    }
+
+    #[doc(hidden)]
+    #[inline(always)]
+    fn from_pretrained_with_locator<L: Locator>(
+        model: &str,
+        parameters: Option<FromPretrainedParameters>,
+    ) -> Result<Self> {
+        let mut tokenizer = Tokenizer::from_pretrained(model, parameters.clone())
+            .map_err(|e| Error::TokenizersError(error::TokenizersError(e)))?;
+        Self::filter_prepend_normalizers(&mut tokenizer);
+
+        // Locate eos_token_id in defined locations.
+        let eos_token_id = L::locate_eos_token_id(model, &tokenizer, &parameters);
+        let Some(eos_token_id) = eos_token_id else {
+            return Err(Error::UnsupportedTokenizer {
+                model: model.to_string(),
+                reason: "EOS token id".to_string(),
+            });
+        };
+
+        // Start building the vocabulary from eos_token_id and added tokens.
+        let mut vocabulary = Vocabulary::new(Some(eos_token_id));
+        for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() {
+            if !added_token.special {
+                vocabulary = vocabulary.insert(added_token.content.clone(), *id);
+            }
+        }
+
+        // Process each vocabulary token according to the tokenizer's level.
+        let Ok(processor) = TokenProcessor::new(&tokenizer) else {
+            return Err(Error::UnsupportedTokenizer {
+                model: model.to_string(),
+                reason: "Token processor".to_string(),
+            });
+        };
+        for (token, token_id) in tokenizer.get_vocab(false) {
+            let token_bytes = processor.process(token)?;
+            // TODO: lossy is temp:
+            // - in python it was handled by byte_symbol function
+            // - interface needs to be redefined to treat Token type as bytes: Vec<u8>
+            let processed_token = String::from_utf8_lossy(&token_bytes);
+            vocabulary = vocabulary.insert(processed_token, token_id);
+        }
+
+        Ok(vocabulary)
+    }
+
+    /// Per provided token returns vector of `TokenId`s if available in the vocabulary.
+    pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
+        self.tokens.get(token)
+    }
+
+    /// Gets the identifier of the special end of the sentence token.
+    pub fn eos_token_id(&self) -> Option<TokenId> {
+        self.eos_token_id
+    }
+
+    /// Filters out `Prepend` kind of tokenizer's normalizers.
+    fn filter_prepend_normalizers(tokenizer: &mut Tokenizer) {
+        // Main concern is prepend normalizers, for example https://github.com/google/sentencepiece
+        // In `sentencepiece` tokenizer, `▁` is used to denote spaces in the source text,
+        // e.g. `Hello World.` could be tokenized as: [Hello] [▁Wor] [ld] [.]
+        //
+        // We don't want to deal with the special characters, so we remove `Prepend` normalizers.
+        if let Some(normalizer) = tokenizer.get_normalizer() {
+            match normalizer {
+                NormalizerWrapper::Sequence(normalization_sequence) => {
+                    let new_sequence = Sequence::new(
+                        normalization_sequence
+                            .get_normalizers()
+                            .iter()
+                            .filter_map(|normalizer| match normalizer {
+                                NormalizerWrapper::Prepend(_) => None,
+                                _ => Some(normalizer.clone()),
+                            })
+                            .collect(),
+                    );
+                    tokenizer.with_normalizer(new_sequence.into());
+                }
+                NormalizerWrapper::Prepend(_) => {
+                    tokenizer.with_normalizer(None::<NormalizerWrapper>);
+                }
+                _ => {}
+            }
+        }
+    }
+}
+
+impl Vocabulary {
+    /// Inserts a token to the vocabulary with the specified identifier.
+    pub fn insert(mut self, token: impl Into<Token>, id: TokenId) -> Vocabulary {
+        self.insert_in_place(token, id);
+        self
+    }
+
+    /// Extends the vocabulary with tokens and their identifiers.
+    pub fn extend<T: Into<Token>, I: IntoIterator<Item = TokenId>>(
+        mut self,
+        tokens_and_ids: impl IntoIterator<Item = (T, I)>,
+    ) -> Vocabulary {
+        self.extend_in_place(tokens_and_ids);
+        self
+    }
+}
+
+impl Vocabulary {
+    /// Inserts a token to the vocabulary with the specified identifier, in place.
+    pub fn insert_in_place(&mut self, token: impl Into<Token>, id: TokenId) {
+        // TODO: return error if eos token id is inserted
+        let token = token.into();
+        self.tokens.entry(token).or_default().push(id);
+    }
+
+    /// Extends the vocabulary with tokens and their identifiers, in place.
+    pub fn extend_in_place<T: Into<Token>, I: IntoIterator<Item = TokenId>>(
+        &mut self,
+        tokens_and_ids: impl IntoIterator<Item = (T, I)>,
+    ) {
+        for (token, ids) in tokens_and_ids.into_iter() {
+            let token = token.into();
+            self.tokens.entry(token).or_default().extend(ids);
+        }
+    }
+}
+
+impl std::ops::Deref for Vocabulary {
+    type Target = HashMap<Token, Vec<TokenId>>;
+
+    fn deref(&self) -> &HashMap<Token, Vec<TokenId>> {
+        &self.tokens
+    }
+}
+
+impl std::fmt::Display for Vocabulary {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        for (index, (token, token_ids)) in self.iter().enumerate() {
+            if index != (self.len() - 1) {
+                writeln!(f, "{:?} -> {:?}", token, token_ids)?;
+            } else {
+                write!(f, "{:?} -> {:?}", token, token_ids)?;
+            }
+        }
+        Ok(())
+    }
+}
+
+impl From<HashMap<Token, Vec<TokenId>>> for Vocabulary {
+    fn from(tokens: HashMap<Token, Vec<TokenId>>) -> Vocabulary {
+        Vocabulary {
+            eos_token_id: None,
+            tokens,
+        }
+    }
+}
+
+impl<T, I> FromIterator<(T, I)> for Vocabulary
+where
+    T: Into<Token>,
+    I: IntoIterator<Item = TokenId>,
+{
+    fn from_iter<A: IntoIterator<Item = (T, I)>>(tokens_and_ids: A) -> Self {
+        Vocabulary::new(None).extend(tokens_and_ids)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn insert() {
+        let vocabulary = Vocabulary::new(None)
+            .insert("blah", 0)
+            .insert("1a", 1)
+            .insert("2", 2)
+            .insert("0", 3);
+
+        assert_eq!(vocabulary.len(), 4);
+        assert_eq!(vocabulary["blah"], &[0]);
+        assert_eq!(vocabulary["1a"], &[1]);
+        assert_eq!(vocabulary["2"], &[2]);
+        assert_eq!(vocabulary["0"], &[3]);
+    }
+
+    #[test]
+    fn extend() {
+        let vocabulary = Vocabulary::new(None).extend([
+            ("blah", vec![0]),
+            ("1a", vec![1]),
+            ("2", vec![2]),
+            ("0", vec![3]),
+        ]);
+
+        assert_eq!(vocabulary.len(), 4);
+        assert_eq!(vocabulary["blah"], &[0]);
+        assert_eq!(vocabulary["1a"], &[1]);
+        assert_eq!(vocabulary["2"], &[2]);
+        assert_eq!(vocabulary["0"], &[3]);
+    }
+
+    #[test]
+    fn new_empty_vocabulary() {
+        let vocabulary = Vocabulary::new(None);
+        assert!(vocabulary.eos_token_id.is_none());
+        assert!(vocabulary.tokens.is_empty());
+    }
+
+    #[test]
+    fn new_empty_vocabulary_from_hashmap() {
+        let map = HashMap::new();
+        let vocabulary = Vocabulary::from(map);
+        assert!(vocabulary.eos_token_id.is_none());
+        assert!(vocabulary.tokens.is_empty());
+    }
+
+    #[test]
+    fn new_vocabulary_from_iterator() {
let token: Token = "abc".to_string(); + let id: Vec = vec![1]; + let it = vec![(token, id)]; + let vocabulary = Vocabulary::from_iter(it); + assert!(vocabulary.eos_token_id.is_none()); + assert!(!vocabulary.tokens.is_empty()); + } + + #[test] + fn supported_pretrained_models() { + // Support is expected for these: + for model in [ + // GPT 2 + "openai-community/gpt2", + // Llama 2 + "hf-internal-testing/Llama-2-7B-GPTQ", + // Llama 3 + // OpenCoder: shares llama tokenizers + "hf-internal-testing/llama-3-8b-internal", + // Qwen + "Qwen/Qwen2-7B-Instruct", + // Salamandra + "BSC-LT/salamandra-2b", + ] { + let vocabulary = Vocabulary::from_pretrained(model, None); + match vocabulary { + Ok(v) => { + assert!(v.eos_token_id().is_some()); + assert_eq!(v.eos_token_id, v.eos_token_id()); + assert!(!v.tokens.is_empty()); + } + Err(_) => unreachable!(), + } + } + } + + #[test] + fn pretrained_from_gpt2() { + let model = "openai-community/gpt2"; + let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); + let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed"); + + let v_eos = vocabulary.eos_token_id; + assert_eq!(v_eos, vocabulary.eos_token_id()); + assert!(v_eos.is_some()); + + let v_eos = v_eos.unwrap(); + assert_eq!(v_eos, 50256); + assert_eq!( + tokenizer.id_to_token(v_eos).expect("Token not found"), + "<|endoftext|>" + ); + + let token = "Ġal"; + assert!(vocabulary.token_to_ids(token).is_none()); + assert!(tokenizer.token_to_id(token).is_some()); + + for (v_token, t_token_expected) in [("abc", "abc"), (" O", "ĠO")] { + let v_ids = vocabulary.token_to_ids(v_token); + assert!(v_ids.is_some()); + for v_id in v_ids.unwrap() { + let t_token = tokenizer + .id_to_token(*v_id) + .expect("Token id not found in tokenizer"); + assert_eq!(&t_token, t_token_expected); + } + } + } + + #[test] + fn pretrained_from_llama() { + let model = "hf-internal-testing/llama-tokenizer"; + let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); + let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed"); + + let v_eos = vocabulary.eos_token_id; + assert_eq!(v_eos, vocabulary.eos_token_id()); + assert!(v_eos.is_some()); + + let v_eos = v_eos.unwrap(); + assert_eq!(v_eos, 2); + assert_eq!( + tokenizer.id_to_token(v_eos).expect("Token not found"), + "" + ); + + for (v_token, t_token_expected) in [ + ("abc", "abc"), + (" al", "▁al"), + (" O", "▁O"), + (" ", "▁▁▁"), + // TODO: won't pass since first we need to change token's type to bytes + // ("<0xFF>", "ÿ"), + // ("<0x20>", "▁"), + ] { + let v_ids = vocabulary.token_to_ids(v_token); + assert!(v_ids.is_some()); + for v_id in v_ids.unwrap() { + let t_token = tokenizer + .id_to_token(*v_id) + .expect("Token id not found in tokenizer"); + assert_eq!(&t_token, t_token_expected); + } + } + } + + #[test] + fn token_processor_error() { + let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"; + let vocabulary = Vocabulary::from_pretrained(model, None); + + match vocabulary { + Err(Error::UnsupportedTokenizer { model, reason }) => { + assert_eq!(model, model.to_string()); + assert_eq!(&reason, "Token processor"); + } + _ => unreachable!(), + } + } + + #[test] + fn tokenizer_error() { + let model = "hf-internal-testing/some-non-existent-model"; + let vocabulary = Vocabulary::from_pretrained(model, None); + + match vocabulary { + Err(Error::TokenizersError(e)) => assert!(!e.to_string().is_empty()), + _ => unreachable!(), + } + } + + struct NoneLocator; + impl 
Locator for NoneLocator { + fn locate_eos_token_id( + _model: &str, + _tokenizer: &Tokenizer, + _parameters: &Option, + ) -> Option { + None + } + } + + #[test] + fn unable_to_locate_eos_token_id_error() { + let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"; + let vocabulary = Vocabulary::from_pretrained_with_locator::(model, None); + + match vocabulary { + Err(Error::UnsupportedTokenizer { model, reason }) => { + assert_eq!(model, model.to_string()); + assert_eq!(&reason, "EOS token id"); + } + _ => unreachable!(), + } + } + + #[test] + fn prepend_normalizers_filtered_out() { + use tokenizers::normalizers::{Prepend, Sequence}; + + let prepend = Prepend::new("_".to_string()); + let prepend_normalizer = NormalizerWrapper::Prepend(prepend); + let sequence = Sequence::new(vec![prepend_normalizer.clone()]); + let sequence_normalizer = NormalizerWrapper::Sequence(sequence); + + let model = "hf-internal-testing/llama-tokenizer"; + let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); + + for normalizer in [prepend_normalizer, sequence_normalizer] { + let mut normalized_t = tokenizer.clone(); + normalized_t.with_normalizer(Some(normalizer)); + Vocabulary::filter_prepend_normalizers(&mut normalized_t); + if let Some(n) = normalized_t.get_normalizer() { + match n { + NormalizerWrapper::Sequence(seq) => { + for n in seq.get_normalizers() { + if let NormalizerWrapper::Prepend(_) = n { + unreachable!() + } + } + } + NormalizerWrapper::Prepend(_) => unreachable!(), + _ => {} + } + } + } + } + + #[test] + fn other_normalizers_being_kept() { + use tokenizers::normalizers::BertNormalizer; + + let model = "hf-internal-testing/llama-tokenizer"; + let normalizer = NormalizerWrapper::BertNormalizer(BertNormalizer::default()); + let mut tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); + tokenizer.with_normalizer(Some(normalizer)); + + Vocabulary::filter_prepend_normalizers(&mut tokenizer); + + assert!(tokenizer.get_normalizer().is_some()); + } +} diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs new file mode 100644 index 00000000..7426f249 --- /dev/null +++ b/src/vocabulary/processor.rs @@ -0,0 +1,406 @@ +use std::collections::HashMap; + +use once_cell::sync::Lazy; +use serde::Deserialize; +use tokenizers::normalizers::Replace; +use tokenizers::{DecoderWrapper, Tokenizer}; + +use crate::{Error, Result}; + +/// GPT2-like tokenizers have multibyte tokens that can have a mix of full and incomplete +/// UTF-8 characters, for example, byte ` \xf0` can be one token. These tokenizers map each +/// byte to a valid UTF-8 character, `TokenProcessor` of `ByteFallback` level will be used +/// to map back these type of characters into bytes, based on `CHAR_MAP`. +/// +/// "ĠO" = [U+0120, U+004F] should be interpreted as [0x20, 0x4F] = " O" +/// or +/// "Ġal" = [U+0120, U+0061, U+006C] should be interpreted as [0x20, 0x61, 0x6C] = " al" +/// +/// We'll use the following the mapping for this transition: +/// --- +/// 'Ā' == '\u{0100}' -> 0x00 == 0 +/// 'ā' == '\u{0101}' -> 0x01 == 1 +/// 'Ă' == '\u{0102}' -> 0x02 == 2 +/// ... +/// 'Ğ' == '\u{011E}' -> 0x1E == 30 +/// 'ğ' == '\u{011F}' -> 0x1F == 31 +/// 'Ġ' == '\u{0120}' -> 0x20 == 32 +/// --- +/// '!' == '\u{0021}' -> 0x21 == 33 +/// '"' == '\u{0022}' -> 0x22 == 34 +/// '#' == '\u{0023}' -> 0x23 == 35 +/// ... 
+/// '|' == '\u{007C}' -> 0x7C == 124 +/// '}' == '\u{007D}' -> 0x7D == 125 +/// '~' == '\u{007E}' -> 0x7E == 126 +/// --- +/// 'ġ' == '\u{0121}' -> 0x7F == 127 +/// 'Ģ' == '\u{0122}' -> 0x80 == 128 +/// 'ģ' == '\u{0123}' -> 0x81 == 129 +/// ... +/// 'ŀ' == '\u{0140}' -> 0x9E == 158 +/// 'Ł' == '\u{0141}' -> 0x9F == 159 +/// 'ł' == '\u{0142}' -> 0xA0 == 160 +/// --- +/// '¡' == '\u{00A1}' -> 0xA1 == 161 +/// '¢' == '\u{00A2}' -> 0xA2 == 162 +/// '£' == '\u{00A3}' -> 0xA3 == 163 +/// ... +/// 'ª' == '\u{00AA}' -> 0xAA == 170 +/// '«' == '\u{00AB}' -> 0xAB == 171 +/// '¬' == '\u{00AC}' -> 0xAC == 172 +/// --- +/// 'Ń' == '\u{0143}' -> 0xAD == 173 +/// --- +/// '®' == '\u{00AE}' -> 0xAE == 174 +/// '¯' == '\u{00AF}' -> 0xAF == 175 +/// '°' == '\u{00B0}' -> 0xB0 == 176 +/// ... +/// 'ý' == '\u{00FD}' -> 0xFD == 253 +/// 'þ' == '\u{00FE}' -> 0xFE == 254 +/// 'ÿ' == '\u{00FF}' -> 0xFF == 255 +/// --- +static CHAR_MAP: Lazy> = Lazy::new(|| { + let mut char_map = HashMap::with_capacity(256); + let mut key = 0x100u32; + for byte in 0..=255u8 { + let char = byte as char; + if matches!( + char, '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}', + ) { + char_map.insert(char, byte); + } else if let Some(ch) = char::from_u32(key) { + char_map.insert(ch, byte); + key += 1; + } + } + char_map +}); + +/// Recognizes different tokenizer's levels. +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum TokenProcessorLevel { + /// Matches byte level tokenizer (e.g., gpt2). + Byte, + /// Matches byte fallback tokenizer (e.g., llama), which have <0x__> tokens for + /// all __ >= 0x80 to represent incomplete UTF-8 sequences. + ByteFallback(Mods), +} + +/// Modifications to be applied by `TokenProcessor`of `ByteFallback` level. +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct Mods { + spacechar: char, +} + +impl Default for Mods { + /// Default string modification to be applied by `TokenProcessor` of `ByteFallback` level. + fn default() -> Self { + Self { spacechar: ' ' } + } +} + +impl Mods { + /// Apply default modifications to each token. + fn apply_default(&self, token: String) -> String { + let to = Self::default().spacechar.to_string(); + token.replace(self.spacechar, &to) + } +} + +/// Local structure to be deserialized into from HF's `ReplaceDecoder` in order to get a replace pattern. +#[derive(Debug, Deserialize)] +struct ReplaceDecoder { + content: String, + pattern: ReplacePattern, +} + +impl ReplaceDecoder { + fn space_replacement(&self) -> Option { + if self.content != " " { + return None; + } + match &self.pattern { + ReplacePattern::String(pattern) => { + let mut chars = pattern.chars(); + let char = chars.next(); + if let Some(replacement) = char { + if chars.next().is_none() { + return Some(replacement); + } + } + None + } + } + } +} + +#[derive(Debug, Deserialize)] +enum ReplacePattern { + String(String), +} + +/// Token processor to adjust tokens according to the tokenizer's level. +#[derive(Debug)] +pub(crate) struct TokenProcessor { + level: TokenProcessorLevel, +} + +impl TokenProcessor { + /// Create new `TokenProcessor` with the level defined based on tokenizer's decoders. 
+ pub(crate) fn new(tokenizer: &Tokenizer) -> Result { + match tokenizer.get_decoder() { + None => Err(Error::UnsupportedByTokenProcessor), + Some(decoder) => match decoder { + DecoderWrapper::ByteLevel(_) => Ok(Self { + level: TokenProcessorLevel::Byte, + }), + DecoderWrapper::Sequence(decoding_sequence) => { + let mut is_byte_fallback = false; + let mut spacechar = ' '; + + for decoder in decoding_sequence.get_decoders() { + match decoder { + DecoderWrapper::ByteFallback(_) => { + is_byte_fallback = true; + } + DecoderWrapper::Replace(replace) => { + // `Replace` decoder would replace a pattern in the output with something else, + // which we need to know. + let decoder = Self::unpack_decoder(replace)?; + if let Some(replacement) = decoder.space_replacement() { + spacechar = replacement; + } + } + _ => {} + } + } + + if is_byte_fallback { + Ok(Self { + level: TokenProcessorLevel::ByteFallback(Mods { spacechar }), + }) + } else { + Err(Error::UnsupportedByTokenProcessor) + } + } + _ => Err(Error::UnsupportedByTokenProcessor), + }, + } + } + + /// Operates on each token based on the level of `TokenProcessor`. + pub(crate) fn process(&self, token: String) -> Result> { + match &self.level { + TokenProcessorLevel::Byte => token + .chars() + .map(|char| { + CHAR_MAP + .get(&char) + .copied() + .ok_or(Error::ByteProcessorFailed) + }) + .collect(), + TokenProcessorLevel::ByteFallback(mods) => { + // If the token is of form `<0x__>`: + if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') { + // Get to a single byte specified in the __ part and parse it in base 16 to a byte. + match u8::from_str_radix(&token[3..5], 16) { + Ok(byte) => Ok([byte].to_vec()), + Err(_) => Err(Error::ByteFallbackProcessorFailed), + } + } else { + Ok(mods.apply_default(token).as_bytes().to_vec()) + } + } + } + } + + /// Since all fields of HF's `Replace` are private with no getters, it needs to be unpacked + /// into local `ReplaceDecoder` structure. 
+ #[cfg(not(tarpaulin_include))] + fn unpack_decoder(decoder: &Replace) -> Result { + match serde_json::to_value(decoder) { + Err(_) => Err(Error::DecoderUnpackingFailed), + Ok(value) => match serde_json::from_value(value) { + Ok(d) => Ok(d), + Err(_) => Err(Error::DecoderUnpackingFailed), + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn byte_level_processor() { + let model = "openai-community/gpt2"; + let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); + let processor = TokenProcessor::new(&tokenizer).expect("Processor failed"); + + assert_eq!(processor.level, TokenProcessorLevel::Byte); + + for (ch, byte) in [ + ('Ā', 0x00), + ('ā', 0x01), + ('Ă', 0x02), + ('Ğ', 0x1E), + ('ğ', 0x1F), + ('Ġ', 0x20), + ('!', 0x21), + ('"', 0x22), + ('#', 0x23), + ('|', 0x7C), + ('}', 0x7D), + ('~', 0x7E), + ('ġ', 0x7F), + ('Ģ', 0x80), + ('ģ', 0x81), + ('ŀ', 0x9E), + ('Ł', 0x9F), + ('ł', 0xA0), + ('¡', 0xA1), + ('¢', 0xA2), + ('£', 0xA3), + ('ª', 0xAA), + ('«', 0xAB), + ('¬', 0xAC), + ('Ń', 0xAD), + ('®', 0xAE), + ('¯', 0xAF), + ('°', 0xB0), + ('ý', 0xFD), + ('þ', 0xFE), + ('ÿ', 0xFF), + ] { + let processed = processor.process(ch.to_string()).expect("Not processed"); + assert_eq!(processed, [byte]); + } + } + + #[test] + fn byte_fallback_level_processor() { + let model = "hf-internal-testing/llama-tokenizer"; + let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); + let processor = TokenProcessor::new(&tokenizer).expect("Processor failed"); + let spacechar = '▁'; + let mods = Mods { spacechar }; + + assert_eq!(processor.level, TokenProcessorLevel::ByteFallback(mods)); + + for (input, expected) in [ + ("abc", vec![0x61, 0x62, 0x63]), + ("<0x61>", vec![0x61]), + ("<0x61>a", vec![0x3C, 0x30, 0x78, 0x36, 0x31, 0x3E, 0x61]), + (&spacechar.to_string(), vec![0x20]), + ( + &format!("{}{}abc", spacechar, spacechar), + vec![0x20, 0x20, 0x61, 0x62, 0x63], + ), + ( + &format!("{}{}{}", spacechar, spacechar, spacechar), + vec![0x20, 0x20, 0x20], + ), + ] { + let processed = processor.process(input.to_string()).expect("Not processed"); + assert_eq!(processed, expected); + } + } + + #[test] + fn unsupported_tokenizer_error() { + let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"; + let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); + + let result = TokenProcessor::new(&tokenizer); + match result { + Err(Error::UnsupportedByTokenProcessor) => {} + _ => unreachable!(), + } + } + + #[test] + fn byte_processor_error() { + let model = "openai-community/gpt2"; + let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); + let processor = TokenProcessor::new(&tokenizer).expect("Processor failed"); + + for token in ["𝒜𝒷𝒸𝒟𝓔", "🦄🌈🌍🔥🎉", "京东购物"] { + let result = processor.process(token.to_string()); + match result { + Err(Error::ByteProcessorFailed) => {} + _ => unreachable!(), + } + } + } + + #[test] + fn byte_fallback_processor_error() { + let model = "hf-internal-testing/llama-tokenizer"; + let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); + let processor = TokenProcessor::new(&tokenizer).expect("Processor failed"); + + let result = processor.process("<0x6y>".to_string()); + match result { + Err(Error::ByteFallbackProcessorFailed) => {} + _ => unreachable!(), + } + } + + #[test] + fn only_get_spacechar_replacement() { + let one_char = "_".to_string(); + let pattern = ReplacePattern::String(one_char); + let not_spacechar = 
"-".to_string(); + let decoder = ReplaceDecoder { + content: not_spacechar, + pattern, + }; + assert!(decoder.space_replacement().is_none()); + } + + #[test] + fn only_one_pattern_char_for_spacechar_replacement() { + let two_chars = "_*".to_string(); + let pattern = ReplacePattern::String(two_chars); + let spacechar = " ".to_string(); + let decoder = ReplaceDecoder { + content: spacechar, + pattern, + }; + assert!(decoder.space_replacement().is_none()); + } + + #[test] + fn tokenizer_without_decoders_is_unsupported() { + use tokenizers::models::bpe::BPE; + + let tokenizer = Tokenizer::new(BPE::default()); + let result = TokenProcessor::new(&tokenizer); + match result { + Err(Error::UnsupportedByTokenProcessor) => {} + _ => unreachable!(), + } + } + + #[test] + fn tokenizer_without_supported_decoders_in_sequence_is_unsupported() { + use tokenizers::decoders::sequence::Sequence; + use tokenizers::decoders::wordpiece::WordPiece; + use tokenizers::models::bpe::BPE; + + let mut tokenizer = Tokenizer::new(BPE::default()); + let decoder = WordPiece::default(); + let sequence = Sequence::new(vec![DecoderWrapper::WordPiece(decoder)]); + let decoder_sequence = DecoderWrapper::Sequence(sequence); + tokenizer.with_decoder(Some(decoder_sequence)); + + let result = TokenProcessor::new(&tokenizer); + match result { + Err(Error::UnsupportedByTokenProcessor) => {} + _ => unreachable!(), + } + } +}