Skip to content

Commit

Permalink
Locator as a trait
Browse files Browse the repository at this point in the history
  • Loading branch information
torymur committed Nov 14, 2024
1 parent 2420742 commit 5e0177a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 12 deletions.
35 changes: 24 additions & 11 deletions src/vocabulary/locator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,28 @@ struct EosTokenLocation {
location: EosTokenField,
}

/// Locates eos token id by searching in defined common locations.
pub(crate) fn locate_eos_token_id(
model: &str,
tokenizer: &Tokenizer,
parameters: &Option<FromPretrainedParameters>,
) -> Option<TokenId> {
COMMON_LOCATIONS
.iter()
.find_map(|location| location.lookup(model, tokenizer, parameters))
/// Locates eos token id.
pub(crate) trait Locator {
fn locate_eos_token_id(
model: &str,
tokenizer: &Tokenizer,
parameters: &Option<FromPretrainedParameters>,
) -> Option<TokenId>;
}

/// Locates eos token id by searching in defined common locations in hugging face.
pub(crate) struct HFLocator;

impl Locator for HFLocator {
fn locate_eos_token_id(
model: &str,
tokenizer: &Tokenizer,
parameters: &Option<FromPretrainedParameters>,
) -> Option<TokenId> {
COMMON_LOCATIONS
.iter()
.find_map(|location| location.lookup(model, tokenizer, parameters))
}
}

impl EosTokenLocation {
Expand Down Expand Up @@ -186,8 +199,8 @@ mod tests {
("hf-internal-testing/llama-tokenizer", 2, "</s>"),
] {
let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
let located =
locate_eos_token_id(model, &tokenizer, &None).expect("Token id is not located");
let located = HFLocator::locate_eos_token_id(model, &tokenizer, &None)
.expect("Token id is not located");

assert_eq!(located, *expected_token_id);
assert_eq!(
Expand Down
34 changes: 33 additions & 1 deletion src/vocabulary/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
use crate::prelude::*;
use crate::{Error, Result};

use locator::{HFLocator, Locator};
use processor::TokenProcessor;

mod locator;
Expand Down Expand Up @@ -44,6 +45,15 @@ impl Vocabulary {
pub fn from_pretrained(
model: &str,
parameters: Option<FromPretrainedParameters>,
) -> Result<Self> {
Self::from_pretrained_with_locator::<HFLocator>(model, parameters)
}

#[doc(hidden)]
#[inline(always)]
fn from_pretrained_with_locator<L: Locator>(
model: &str,
parameters: Option<FromPretrainedParameters>,
) -> Result<Self> {
let mut tokenizer =
Tokenizer::from_pretrained(model, parameters.clone()).map_err(|_| {
Expand All @@ -53,7 +63,7 @@ impl Vocabulary {
})?;
Self::filter_prepend_normalizers(&mut tokenizer);

let eos_token_id = locator::locate_eos_token_id(model, &tokenizer, &parameters);
let eos_token_id = L::locate_eos_token_id(model, &tokenizer, &parameters);
let Some(eos_token_id) = eos_token_id else {
return Err(Error::UnableToLocateEosTokenId {
model: model.to_string(),
Expand Down Expand Up @@ -358,6 +368,28 @@ mod tests {
}
}

struct NoneLocator;
impl Locator for NoneLocator {
fn locate_eos_token_id(
_model: &str,
_tokenizer: &Tokenizer,
_parameters: &Option<FromPretrainedParameters>,
) -> Option<TokenId> {
None
}
}

#[test]
fn unable_to_locate_eos_token_id_error() {
let model = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM";
let vocabulary = Vocabulary::from_pretrained_with_locator::<NoneLocator>(model, None);

assert!(vocabulary.is_err());
if let Err(Error::UnableToLocateEosTokenId { model }) = vocabulary {
assert_eq!(model, model.to_string());
}
}

#[test]
fn prepend_normalizers_filtered_out() {
use tokenizers::normalizers::{Prepend, Sequence};
Expand Down

0 comments on commit 5e0177a

Please sign in to comment.