Separate tokenizers errors, test supported pretrained models
torymur committed Nov 18, 2024
1 parent 5e0177a commit 50c4225
Showing 3 changed files with 110 additions and 62 deletions.
21 changes: 19 additions & 2 deletions src/error.rs
@@ -1,11 +1,28 @@
 use thiserror::Error;
 
+#[derive(Error, Debug)]
+pub struct TokenizersError(pub tokenizers::Error);
+
+impl PartialEq for TokenizersError {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.to_string() == other.0.to_string()
+    }
+}
+
+impl std::fmt::Display for TokenizersError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
 #[derive(Error, Debug, PartialEq)]
 pub enum Error {
     #[error("The vocabulary does not allow us to build a sequence that matches the input")]
     IndexError,
     #[error("Unable to create tokenizer for {model}")]
     UnableToCreateTokenizer { model: String },
+    #[error(transparent)]
+    TokenizersError(#[from] TokenizersError),
+    #[error("Unsupported tokenizer for {model}: {reason}, feel free to open an issue: https://github.com/dottxt-ai/outlines-core/issues")]
+    UnsupportedTokenizer { model: String, reason: String },
     #[error("Unable to locate EOS token for {model}")]
     UnableToLocateEosTokenId { model: String },
     #[error("Tokenizer is not supported by token processor")]
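Why the newtype: `tokenizers::Error` is effectively a boxed trait object (`Box<dyn std::error::Error + Send + Sync>`), so it implements neither `PartialEq` nor `Clone`; wrapping it and comparing rendered messages is what lets the enum keep `derive(PartialEq)`. A minimal sketch (not part of the commit) of how the variant is used; the import paths are assumptions based on this crate's layout:

use crate::{error::TokenizersError, Error};
use tokenizers::Tokenizer;

// The same conversion the new `from_pretrained` body performs; thanks to
// `#[from]`, a wrapped `tokenizers::Error` can also be lifted with `?`.
fn load(model: &str) -> Result<Tokenizer, Error> {
    Tokenizer::from_pretrained(model, None)
        .map_err(|e| Error::TokenizersError(TokenizersError(e)))
}

#[test]
fn equality_is_message_based() {
    // Two independently boxed errors with identical text compare equal,
    // because `PartialEq` above compares `Display` output, not identity.
    let a = TokenizersError("oops".into());
    let b = TokenizersError("oops".into());
    assert_eq!(a, b);
}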
121 changes: 76 additions & 45 deletions src/vocabulary/mod.rs
@@ -3,7 +3,7 @@ use std::collections::HashMap;
 use tokenizers::normalizers::Sequence;
 use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
 
-use crate::prelude::*;
+use crate::{error, prelude::*};
 use crate::{Error, Result};
 
 use locator::{HFLocator, Locator};
@@ -55,22 +55,44 @@ impl Vocabulary {
         model: &str,
         parameters: Option<FromPretrainedParameters>,
     ) -> Result<Self> {
-        let mut tokenizer =
-            Tokenizer::from_pretrained(model, parameters.clone()).map_err(|_| {
-                Error::UnableToCreateTokenizer {
-                    model: model.to_string(),
-                }
-            })?;
+        let mut tokenizer = Tokenizer::from_pretrained(model, parameters.clone())
+            .map_err(|e| Error::TokenizersError(error::TokenizersError(e)))?;
         Self::filter_prepend_normalizers(&mut tokenizer);
 
         // Locate eos_token_id in defined locations.
         let eos_token_id = L::locate_eos_token_id(model, &tokenizer, &parameters);
         let Some(eos_token_id) = eos_token_id else {
-            return Err(Error::UnableToLocateEosTokenId {
+            return Err(Error::UnsupportedTokenizer {
                 model: model.to_string(),
+                reason: "EOS token id".to_string(),
             });
         };
 
-        Vocabulary::try_from((tokenizer, eos_token_id))
+        // Start building the vocabulary from eos_token_id and added tokens.
+        let mut vocabulary = Vocabulary::new(Some(eos_token_id));
+        for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() {
+            if !added_token.special {
+                vocabulary = vocabulary.insert(added_token.content.clone(), *id);
+            }
+        }
+
+        // Process each vocabulary token according to the tokenizer's level.
+        let Ok(processor) = TokenProcessor::new(&tokenizer) else {
+            return Err(Error::UnsupportedTokenizer {
+                model: model.to_string(),
+                reason: "Token processor".to_string(),
+            });
+        };
+        for (token, token_id) in tokenizer.get_vocab(false) {
+            let token_bytes = processor.process(token)?;
+            // TODO: lossy is temp:
+            // - in python it was handled by byte_symbol function
+            // - interface needs to be redefined to treat Token type as bytes: Vec<u8>
+            let processed_token = String::from_utf8_lossy(&token_bytes);
+            vocabulary = vocabulary.insert(processed_token, token_id);
+        }
+
+        Ok(vocabulary)
     }
 
     /// Per provided token returns vector of `TokenId`s if available in the vocabulary.
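A note on the `TODO` above: `String::from_utf8_lossy` is not injective, so two distinct byte-level tokens can collapse into the same key once they contain invalid UTF-8, which is exactly why the interface wants `Token` to become `Vec<u8>`. A standalone illustration (not part of the commit):

fn main() {
    // Both sequences are invalid UTF-8 and differ only in their last byte...
    let a = [0xE2, 0x28, 0xA1];
    let b = [0xE2, 0x28, 0xA2];
    // ...yet both render as "\u{FFFD}(\u{FFFD}" after lossy conversion,
    // so as vocabulary keys they would collide.
    assert_eq!(String::from_utf8_lossy(&a), String::from_utf8_lossy(&b));
}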
@@ -114,33 +136,6 @@ impl Vocabulary {
     }
 }
 
-impl TryFrom<(Tokenizer, u32)> for Vocabulary {
-    type Error = Error;
-
-    fn try_from(value: (Tokenizer, u32)) -> Result<Self> {
-        let (tokenizer, eos_token_id) = value;
-
-        let mut vocabulary = Vocabulary::new(Some(eos_token_id));
-        for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() {
-            if !added_token.special {
-                vocabulary = vocabulary.insert(added_token.content.clone(), *id);
-            }
-        }
-
-        let processor = TokenProcessor::new(&tokenizer)?;
-        for (token, token_id) in tokenizer.get_vocab(false) {
-            let token_bytes = processor.process(token)?;
-            // TODO: lossy is temp:
-            // - in python in was handled by byte_symbol function
-            // - interface needs to be redefined to treat Token type as bytes: Vec<u8>
-            let processed_token = String::from_utf8_lossy(&token_bytes);
-            vocabulary = vocabulary.insert(processed_token, token_id);
-        }
-
-        Ok(vocabulary)
-    }
-}
-
 impl Vocabulary {
     /// Inserts a token to the vocabulary with the specified identifier.
     pub fn insert(mut self, token: impl Into<Token>, id: TokenId) -> Vocabulary {
@@ -278,13 +273,42 @@ mod tests {
         assert!(!vocabulary.tokens.is_empty());
     }
 
+    #[test]
+    fn supported_pretrained_models() {
+        // Support is expected for these:
+        for model in [
+            // GPT 2
+            "openai-community/gpt2",
+            // Llama 2
+            "hf-internal-testing/Llama-2-7B-GPTQ",
+            // Llama 3
+            // OpenCoder: shares llama tokenizers
+            "hf-internal-testing/llama-3-8b-internal",
+            // Qwen
+            "Qwen/Qwen2-7B-Instruct",
+            // Salamandra
+            "BSC-LT/salamandra-2b",
+        ] {
+            let vocabulary = Vocabulary::from_pretrained(model, None);
+            match vocabulary {
+                Ok(v) => {
+                    assert!(v.eos_token_id().is_some());
+                    assert_eq!(v.eos_token_id, v.eos_token_id());
+                    assert!(!v.tokens.is_empty());
+                }
+                Err(_) => unreachable!(),
+            }
+        }
+    }
+
     #[test]
     fn pretrained_from_gpt2() {
         let model = "openai-community/gpt2";
         let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
         let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed");
 
         let v_eos = vocabulary.eos_token_id;
+        assert_eq!(v_eos, vocabulary.eos_token_id());
         assert!(v_eos.is_some());
 
         let v_eos = v_eos.unwrap();
@@ -317,6 +341,7 @@ mod tests {
         let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed");
 
         let v_eos = vocabulary.eos_token_id;
+        assert_eq!(v_eos, vocabulary.eos_token_id());
         assert!(v_eos.is_some());
 
         let v_eos = v_eos.unwrap();
@@ -351,9 +376,12 @@ mod tests {
         let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
         let vocabulary = Vocabulary::from_pretrained(model, None);
 
-        assert!(vocabulary.is_err());
-        if let Err(e) = vocabulary {
-            assert_eq!(e, Error::UnsupportedByTokenProcessor)
+        match vocabulary {
+            Err(Error::UnsupportedTokenizer { model, reason }) => {
+                assert_eq!(model, model.to_string());
+                assert_eq!(&reason, "Token processor");
+            }
+            _ => unreachable!(),
         }
     }

@@ -362,9 +390,9 @@
         let model = "hf-internal-testing/some-non-existent-model";
         let vocabulary = Vocabulary::from_pretrained(model, None);
 
-        assert!(vocabulary.is_err());
-        if let Err(Error::UnableToCreateTokenizer { model }) = vocabulary {
-            assert_eq!(model, model.to_string());
+        match vocabulary {
+            Err(Error::TokenizersError(e)) => assert!(!e.to_string().is_empty()),
+            _ => unreachable!(),
         }
     }

@@ -384,9 +412,12 @@
         let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
         let vocabulary = Vocabulary::from_pretrained_with_locator::<NoneLocator>(model, None);
 
-        assert!(vocabulary.is_err());
-        if let Err(Error::UnableToLocateEosTokenId { model }) = vocabulary {
-            assert_eq!(model, model.to_string());
+        match vocabulary {
+            Err(Error::UnsupportedTokenizer { model, reason }) => {
+                assert_eq!(model, model.to_string());
+                assert_eq!(&reason, "EOS token id");
+            }
+            _ => unreachable!(),
         }
     }

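From a caller's perspective, the reworked error surface separates underlying tokenizers failures from structural limitations. A hedged sketch of consuming it (the crate-root imports and model name are assumptions, not from the diff):

use outlines_core::{prelude::*, Error};

fn describe(model: &str) {
    match Vocabulary::from_pretrained(model, None) {
        Ok(v) => println!("eos token id: {:?}", v.eos_token_id()),
        // Underlying `tokenizers` failures (download, parsing) pass through...
        Err(Error::TokenizersError(e)) => eprintln!("tokenizers: {e}"),
        // ...while tokenizers that load but cannot be supported are reported
        // with the model name and the missing capability.
        Err(Error::UnsupportedTokenizer { model, reason }) => {
            eprintln!("{model} unsupported: {reason}")
        }
        Err(other) => eprintln!("{other}"),
    }
}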
30 changes: 15 additions & 15 deletions src/vocabulary/processor.rs
@@ -311,9 +311,9 @@ mod tests {
         let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
 
         let result = TokenProcessor::new(&tokenizer);
-        assert!(result.is_err());
-        if let Err(e) = result {
-            assert_eq!(e, Error::UnsupportedByTokenProcessor)
+        match result {
+            Err(Error::UnsupportedByTokenProcessor) => {}
+            _ => unreachable!(),
         }
     }

@@ -325,9 +325,9 @@
 
         for token in ["π’œπ’·π’Έπ’Ÿπ“”", "πŸ¦„πŸŒˆπŸŒπŸ”₯πŸŽ‰", "δΊ¬δΈœθ΄­η‰©"] {
             let result = processor.process(token.to_string());
-            assert!(result.is_err());
-            if let Err(e) = result {
-                assert_eq!(e, Error::ByteProcessorFailed)
+            match result {
+                Err(Error::ByteProcessorFailed) => {}
+                _ => unreachable!(),
             }
         }
     }
@@ -339,9 +339,9 @@
         let processor = TokenProcessor::new(&tokenizer).expect("Processor failed");
 
         let result = processor.process("<0x6y>".to_string());
-        assert!(result.is_err());
-        if let Err(e) = result {
-            assert_eq!(e, Error::ByteFallbackProcessorFailed)
+        match result {
+            Err(Error::ByteFallbackProcessorFailed) => {}
+            _ => unreachable!(),
         }
     }

@@ -375,9 +375,9 @@
 
         let tokenizer = Tokenizer::new(BPE::default());
         let result = TokenProcessor::new(&tokenizer);
-        assert!(result.is_err());
-        if let Err(e) = result {
-            assert_eq!(e, Error::UnsupportedByTokenProcessor)
+        match result {
+            Err(Error::UnsupportedByTokenProcessor) => {}
+            _ => unreachable!(),
         }
     }

@@ -394,9 +394,9 @@
         tokenizer.with_decoder(Some(decoder_sequence));
 
         let result = TokenProcessor::new(&tokenizer);
-        assert!(result.is_err());
-        if let Err(e) = result {
-            assert_eq!(e, Error::UnsupportedByTokenProcessor)
+        match result {
+            Err(Error::UnsupportedByTokenProcessor) => {}
+            _ => unreachable!(),
         }
     }
 }
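On the `<0x6y>` case above: byte-fallback tokenizers spell raw bytes as `<0xNN>` tokens with two hex digits, and `y` is not a hex digit. A hypothetical mini-parser (not the crate's actual implementation) showing the failure mode the test exercises:

// "<0x6A>" -> Some(0x6A); "<0x6y>" -> None, the analogue of
// `Error::ByteFallbackProcessorFailed` in the tests above.
fn parse_byte_fallback(token: &str) -> Option<u8> {
    let hex = token.strip_prefix("<0x")?.strip_suffix('>')?;
    u8::from_str_radix(hex, 16).ok()
}

fn main() {
    assert_eq!(parse_byte_fallback("<0x6A>"), Some(0x6A));
    assert_eq!(parse_byte_fallback("<0x6y>"), None);
}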
