diff --git a/src/vocabulary/locator.rs b/src/vocabulary/locator.rs
index 4b417789..782b621a 100644
--- a/src/vocabulary/locator.rs
+++ b/src/vocabulary/locator.rs
@@ -89,15 +89,28 @@ struct EosTokenLocation {
     location: EosTokenField,
 }
 
-/// Locates eos token id by searching in defined common locations.
-pub(crate) fn locate_eos_token_id(
-    model: &str,
-    tokenizer: &Tokenizer,
-    parameters: &Option<FromPretrainedParameters>,
-) -> Option<u32> {
-    COMMON_LOCATIONS
-        .iter()
-        .find_map(|location| location.lookup(model, tokenizer, parameters))
+/// Locates eos token id.
+pub(crate) trait Locator {
+    fn locate_eos_token_id(
+        model: &str,
+        tokenizer: &Tokenizer,
+        parameters: &Option<FromPretrainedParameters>,
+    ) -> Option<u32>;
+}
+
+/// Locates eos token id by searching in defined common locations in hugging face.
+pub(crate) struct HFLocator;
+
+impl Locator for HFLocator {
+    fn locate_eos_token_id(
+        model: &str,
+        tokenizer: &Tokenizer,
+        parameters: &Option<FromPretrainedParameters>,
+    ) -> Option<u32> {
+        COMMON_LOCATIONS
+            .iter()
+            .find_map(|location| location.lookup(model, tokenizer, parameters))
+    }
 }
 
 impl EosTokenLocation {
@@ -186,8 +199,8 @@ mod tests {
             ("hf-internal-testing/llama-tokenizer", 2, "</s>"),
         ] {
             let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
-            let located =
-                locate_eos_token_id(model, &tokenizer, &None).expect("Token id is not located");
+            let located = HFLocator::locate_eos_token_id(model, &tokenizer, &None)
+                .expect("Token id is not located");
 
             assert_eq!(located, *expected_token_id);
             assert_eq!(
diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index e3e24e51..eca4232d 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -6,6 +6,7 @@ use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
 
 use crate::prelude::*;
 use crate::{Error, Result};
+use locator::{HFLocator, Locator};
 use processor::TokenProcessor;
 
 mod locator;
@@ -44,6 +45,15 @@ impl Vocabulary {
     pub fn from_pretrained(
         model: &str,
         parameters: Option<FromPretrainedParameters>,
+    ) -> Result<Self> {
+        Self::from_pretrained_with_locator::<HFLocator>(model, parameters)
+    }
+
+    #[doc(hidden)]
+    #[inline(always)]
+    fn from_pretrained_with_locator<L: Locator>(
+        model: &str,
+        parameters: Option<FromPretrainedParameters>,
     ) -> Result<Self> {
         let mut tokenizer =
             Tokenizer::from_pretrained(model, parameters.clone()).map_err(|_| {
@@ -53,7 +63,7 @@ impl Vocabulary {
         })?;
         Self::filter_prepend_normalizers(&mut tokenizer);
 
-        let eos_token_id = locator::locate_eos_token_id(model, &tokenizer, &parameters);
+        let eos_token_id = L::locate_eos_token_id(model, &tokenizer, &parameters);
         let Some(eos_token_id) = eos_token_id else {
             return Err(Error::UnableToLocateEosTokenId {
                 model: model.to_string(),
@@ -358,6 +368,28 @@ mod tests {
         }
     }
 
+    struct NoneLocator;
+    impl Locator for NoneLocator {
+        fn locate_eos_token_id(
+            _model: &str,
+            _tokenizer: &Tokenizer,
+            _parameters: &Option<FromPretrainedParameters>,
+        ) -> Option<u32> {
+            None
+        }
+    }
+
+    #[test]
+    fn unable_to_locate_eos_token_id_error() {
+        let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
+        let vocabulary = Vocabulary::from_pretrained_with_locator::<NoneLocator>(model, None);
+
+        assert!(vocabulary.is_err());
+        if let Err(Error::UnableToLocateEosTokenId { model }) = vocabulary {
+            assert_eq!(model, model.to_string());
+        }
+    }
+
     #[test]
     fn prepend_normalizers_filtered_out() {
         use tokenizers::normalizers::{Prepend, Sequence};