diff --git a/src/error.rs b/src/error.rs index ff977a7..3eec176 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,11 +1,28 @@ use thiserror::Error; +#[derive(Error, Debug)] +pub struct TokenizersError(pub tokenizers::Error); + +impl PartialEq for TokenizersError { + fn eq(&self, other: &Self) -> bool { + self.0.to_string() == other.0.to_string() + } +} + +impl std::fmt::Display for TokenizersError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + #[derive(Error, Debug, PartialEq)] pub enum Error { #[error("The vocabulary does not allow us to build a sequence that matches the input")] IndexError, - #[error("Unable to create tokenizer for {model}")] - UnableToCreateTokenizer { model: String }, + #[error(transparent)] + TokenizersError(#[from] TokenizersError), + #[error("Unsupported tokenizer for {model}: {reason}, feel free to open an issue: https://github.com/dottxt-ai/outlines-core/issues")] + UnsupportedTokenizer { model: String, reason: String }, #[error("Unable to locate EOS token for {model}")] UnableToLocateEosTokenId { model: String }, #[error("Tokenizer is not supported by token processor")] diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index eca4232..b4bcc4b 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use tokenizers::normalizers::Sequence; use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer}; -use crate::prelude::*; +use crate::{error, prelude::*}; use crate::{Error, Result}; use locator::{HFLocator, Locator}; @@ -55,22 +55,44 @@ impl Vocabulary { model: &str, parameters: Option<FromPretrainedParameters>, ) -> Result<Self> { - let mut tokenizer = - Tokenizer::from_pretrained(model, parameters.clone()).map_err(|_| { - Error::UnableToCreateTokenizer { - model: model.to_string(), - } - })?; + let mut tokenizer = Tokenizer::from_pretrained(model, parameters.clone()) + .map_err(|e| Error::TokenizersError(error::TokenizersError(e)))?; 
Self::filter_prepend_normalizers(&mut tokenizer); + // Locate eos_token_id in defined locations. let eos_token_id = L::locate_eos_token_id(model, &tokenizer, &parameters); let Some(eos_token_id) = eos_token_id else { - return Err(Error::UnableToLocateEosTokenId { + return Err(Error::UnsupportedTokenizer { model: model.to_string(), + reason: "EOS token id".to_string(), }); }; - Vocabulary::try_from((tokenizer, eos_token_id)) + // Start building the vocabulary from eos_token_id and added tokens. + let mut vocabulary = Vocabulary::new(Some(eos_token_id)); + for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() { + if !added_token.special { + vocabulary = vocabulary.insert(added_token.content.clone(), *id); + } + } + + // Process each vocabulary token according to the tokenizer's level. + let Ok(processor) = TokenProcessor::new(&tokenizer) else { + return Err(Error::UnsupportedTokenizer { + model: model.to_string(), + reason: "Token processor".to_string(), + }); + }; + for (token, token_id) in tokenizer.get_vocab(false) { + let token_bytes = processor.process(token)?; + // TODO: lossy is temp: + // - in python it was handled by byte_symbol function + // - interface needs to be redefined to treat Token type as bytes: Vec<u8> + let processed_token = String::from_utf8_lossy(&token_bytes); + vocabulary = vocabulary.insert(processed_token, token_id); + } + + Ok(vocabulary) } /// Per provided token returns vector of `TokenId`s if available in the vocabulary. 
@@ -114,33 +136,6 @@ impl Vocabulary { } } -impl TryFrom<(Tokenizer, u32)> for Vocabulary { - type Error = Error; - - fn try_from(value: (Tokenizer, u32)) -> Result { - let (tokenizer, eos_token_id) = value; - - let mut vocabulary = Vocabulary::new(Some(eos_token_id)); - for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() { - if !added_token.special { - vocabulary = vocabulary.insert(added_token.content.clone(), *id); - } - } - - let processor = TokenProcessor::new(&tokenizer)?; - for (token, token_id) in tokenizer.get_vocab(false) { - let token_bytes = processor.process(token)?; - // TODO: lossy is temp: - // - in python in was handled by byte_symbol function - // - interface needs to be redefined to treat Token type as bytes: Vec - let processed_token = String::from_utf8_lossy(&token_bytes); - vocabulary = vocabulary.insert(processed_token, token_id); - } - - Ok(vocabulary) - } -} - impl Vocabulary { /// Inserts a token to the vocabulary with the specified identifier. 
pub fn insert(mut self, token: impl Into, id: TokenId) -> Vocabulary { @@ -278,6 +273,27 @@ mod tests { assert!(!vocabulary.tokens.is_empty()); } + #[test] + fn supported_pretrained_models() { + // Support is expected for these: + for model in [ + // GPT 2 + "openai-community/gpt2", + // Llama 2 + "hf-internal-testing/Llama-2-7B-GPTQ", + // Llama 3 + // OpenCoder: shares llama tokenizers + "hf-internal-testing/llama-3-8b-internal", + // Qwen + "Qwen/Qwen2-7B-Instruct", + // Salamandra + "BSC-LT/salamandra-2b", + ] { + let vocabulary = Vocabulary::from_pretrained(model, None); + assert!(vocabulary.is_ok()) + } + } + #[test] fn pretrained_from_gpt2() { let model = "openai-community/gpt2"; @@ -351,9 +367,12 @@ mod tests { let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"; let vocabulary = Vocabulary::from_pretrained(model, None); - assert!(vocabulary.is_err()); - if let Err(e) = vocabulary { - assert_eq!(e, Error::UnsupportedByTokenProcessor) + match vocabulary { + Err(Error::UnsupportedTokenizer { model, reason }) => { + assert_eq!(model, model.to_string()); + assert_eq!(&reason, "Token processor"); + } + _ => unreachable!(), } } @@ -362,9 +381,9 @@ mod tests { let model = "hf-internal-testing/some-non-existent-model"; let vocabulary = Vocabulary::from_pretrained(model, None); - assert!(vocabulary.is_err()); - if let Err(Error::UnableToCreateTokenizer { model }) = vocabulary { - assert_eq!(model, model.to_string()); + match vocabulary { + Err(Error::TokenizersError(e)) => assert!(!e.to_string().is_empty()), + _ => unreachable!(), } } @@ -384,9 +403,12 @@ mod tests { let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"; let vocabulary = Vocabulary::from_pretrained_with_locator::(model, None); - assert!(vocabulary.is_err()); - if let Err(Error::UnableToLocateEosTokenId { model }) = vocabulary { - assert_eq!(model, model.to_string()); + match vocabulary { + Err(Error::UnsupportedTokenizer { model, reason }) => { + assert_eq!(model, 
model.to_string()); + assert_eq!(&reason, "EOS token id"); + } + _ => unreachable!(), } } diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs index 74d700d..55b6cde 100644 --- a/src/vocabulary/processor.rs +++ b/src/vocabulary/processor.rs @@ -311,9 +311,9 @@ mod tests { let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); let result = TokenProcessor::new(&tokenizer); - assert!(result.is_err()); - if let Err(e) = result { - assert_eq!(e, Error::UnsupportedByTokenProcessor) + match result { + Err(Error::UnsupportedByTokenProcessor) => {} + _ => unreachable!(), } } @@ -325,9 +325,9 @@ mod tests { for token in ["𝒜𝒷𝒸𝒟𝓔", "🦄🌈🌍🔥🎉", "京东购物"] { let result = processor.process(token.to_string()); - assert!(result.is_err()); - if let Err(e) = result { - assert_eq!(e, Error::ByteProcessorFailed) + match result { + Err(Error::ByteProcessorFailed) => {} + _ => unreachable!(), } } } @@ -339,9 +339,9 @@ mod tests { let processor = TokenProcessor::new(&tokenizer).expect("Processor failed"); let result = processor.process("<0x6y>".to_string()); - assert!(result.is_err()); - if let Err(e) = result { - assert_eq!(e, Error::ByteFallbackProcessorFailed) + match result { + Err(Error::ByteFallbackProcessorFailed) => {} + _ => unreachable!(), } } @@ -375,9 +375,9 @@ mod tests { let tokenizer = Tokenizer::new(BPE::default()); let result = TokenProcessor::new(&tokenizer); - assert!(result.is_err()); - if let Err(e) = result { - assert_eq!(e, Error::UnsupportedByTokenProcessor) + match result { + Err(Error::UnsupportedByTokenProcessor) => {} + _ => unreachable!(), } } @@ -394,9 +394,9 @@ mod tests { tokenizer.with_decoder(Some(decoder_sequence)); let result = TokenProcessor::new(&tokenizer); - assert!(result.is_err()); - if let Err(e) = result { - assert_eq!(e, Error::UnsupportedByTokenProcessor) + match result { + Err(Error::UnsupportedByTokenProcessor) => {} + _ => unreachable!(), } } }