Skip to content

Commit

Permalink
Improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
torymur committed Nov 19, 2024
1 parent 04e2491 commit 2e748aa
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 11 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ panic = 'abort'
[package.metadata.scripts]
build-python-extension = "python setup.py build_rust --inplace --debug"
build-python-extension-release = "python setup.py build_rust --inplace --release"

[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tarpaulin_include)'] }
9 changes: 9 additions & 0 deletions src/vocabulary/locator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ impl EosTokenLocation {

let params = parameters.clone().unwrap_or_default();

// These validation checks are a literal adaptation of the corresponding logic from HF.
// Here `project` is a model name, which, if invalid, is expected to fail much earlier.
// Validating it this way is therefore somewhat redundant, but does no harm.
Self::validate(project)?;
Self::validate(&params.revision)?;

Expand Down Expand Up @@ -213,4 +216,10 @@ mod tests {
let token_id = bad_file.lookup(model, &tokenizer, &None);
assert!(token_id.is_none());
}

#[test]
fn validate_config_input() {
    // A model name ending in `*` is expected to be rejected by `validate`.
    let outcome = EosTokenLocation::validate("bad_model_name*");
    assert!(outcome.is_err());
}
}
49 changes: 47 additions & 2 deletions src/vocabulary/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,41 @@ mod tests {
assert_eq!(vocabulary["0"], &[3]);
}

#[test]
fn new_empty_vocabulary() {
    // Constructed without an EOS token id: neither an id nor any tokens are expected.
    let empty = Vocabulary::new(None);
    assert!(empty.tokens.is_empty());
    assert!(empty.eos_token_id.is_none());
}

#[test]
fn new_empty_vocabulary_from_hashmap() {
    // Converting an empty `HashMap` must yield a vocabulary with no tokens and no EOS id.
    let converted = Vocabulary::from(HashMap::new());
    assert!(converted.tokens.is_empty());
    assert!(converted.eos_token_id.is_none());
}

#[test]
fn new_vocabulary_from_iterator() {
    // A single (token, ids) pair is enough to produce a non-empty vocabulary;
    // no EOS token id is supplied on this construction path.
    let pairs: Vec<(Token, Vec<TokenId>)> = vec![("abc".to_string(), vec![1])];
    let built = Vocabulary::from_iter(pairs);
    assert!(!built.tokens.is_empty());
    assert!(built.eos_token_id.is_none());
}

#[test]
fn pretrained_from_gpt2() {
let model = "openai-community/gpt2";
let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed");

let v_eos = vocabulary.eos_token_id.expect("No eos token in vocabulary");
let v_eos = vocabulary.eos_token_id;
assert!(v_eos.is_some());

let v_eos = v_eos.unwrap();
assert_eq!(v_eos, 50256);
assert_eq!(
tokenizer.id_to_token(v_eos).expect("Token not found"),
Expand Down Expand Up @@ -278,7 +306,10 @@ mod tests {
let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed");

let v_eos = vocabulary.eos_token_id.expect("No eos token in vocabulary");
let v_eos = vocabulary.eos_token_id;
assert!(v_eos.is_some());

let v_eos = v_eos.unwrap();
assert_eq!(v_eos, 2);
assert_eq!(
tokenizer.id_to_token(v_eos).expect("Token not found"),
Expand Down Expand Up @@ -358,4 +389,18 @@ mod tests {
}
}
}

#[test]
fn other_normalizers_being_kept() {
    use tokenizers::normalizers::BertNormalizer;

    // Load a pretrained tokenizer and replace its normalizer with a Bert one.
    let mut tokenizer = Tokenizer::from_pretrained("hf-internal-testing/llama-tokenizer", None)
        .expect("Tokenizer failed");
    let bert = NormalizerWrapper::BertNormalizer(BertNormalizer::default());
    tokenizer.with_normalizer(Some(bert));

    Vocabulary::filter_prepend_normalizers(&mut tokenizer);

    // After filtering, the tokenizer must still report a normalizer.
    let kept = tokenizer.get_normalizer();
    assert!(kept.is_some());
}
}
65 changes: 56 additions & 9 deletions src/vocabulary/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,6 @@ pub(crate) enum TokenProcessorLevel {
ByteFallback(Mods),
}

impl std::fmt::Display for TokenProcessorLevel {
    /// Writes a human-readable label for the level. For `ByteFallback` the label
    /// also carries the `Debug` rendering of the attached mods.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Self::ByteFallback(mods) => {
                write!(f, "Byte Fallback Level with mods: {:?}", mods)
            }
            // Plain literal: no formatting arguments needed.
            Self::Byte => f.write_str("Byte Level"),
        }
    }
}

/// Modifications to be applied by `TokenProcessor` of `ByteFallback` level.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct Mods {
Expand Down Expand Up @@ -223,6 +214,7 @@ impl TokenProcessor {

/// Since all fields of HF's `Replace` are private with no getters, it needs to be unpacked
/// into local `ReplaceDecoder` structure.
#[cfg(not(tarpaulin_include))]
fn unpack_decoder(decoder: &Replace) -> Result<ReplaceDecoder> {
match serde_json::to_value(decoder) {
Err(_) => Err(Error::DecoderUnpackingFailed),
Expand Down Expand Up @@ -352,4 +344,59 @@ mod tests {
assert_eq!(e, Error::ByteFallbackProcessorFailed)
}
}

#[test]
fn only_get_spacechar_replacement() {
    // Single-character pattern, but the replacement content is "-" rather than a space,
    // so no space replacement is expected.
    let decoder = ReplaceDecoder {
        content: String::from("-"),
        pattern: ReplacePattern::String(String::from("_")),
    };
    let replacement = decoder.space_replacement();
    assert!(replacement.is_none());
}

#[test]
fn only_one_pattern_char_for_spacechar_replacement() {
    // The content is a space, but the pattern has two characters,
    // so no space replacement is expected.
    let decoder = ReplaceDecoder {
        content: String::from(" "),
        pattern: ReplacePattern::String(String::from("_*")),
    };
    let replacement = decoder.space_replacement();
    assert!(replacement.is_none());
}

#[test]
fn tokenizer_without_decoders_is_unsupported() {
    use tokenizers::models::bpe::BPE;

    // A tokenizer built from a default BPE model, with no decoder set on it.
    let bare = Tokenizer::new(BPE::default());
    let outcome = TokenProcessor::new(&bare);

    // Construction must fail, and with exactly the "unsupported" error.
    assert!(outcome.is_err());
    assert_eq!(outcome.err(), Some(Error::UnsupportedByTokenProcessor));
}

#[test]
fn tokenizer_without_supported_decoders_in_sequence_is_unsupported() {
    use tokenizers::decoders::sequence::Sequence;
    use tokenizers::decoders::wordpiece::WordPiece;
    use tokenizers::models::bpe::BPE;

    // Decoder sequence whose only member is a WordPiece decoder.
    let wordpiece_only = DecoderWrapper::Sequence(Sequence::new(vec![
        DecoderWrapper::WordPiece(WordPiece::default()),
    ]));
    let mut tokenizer = Tokenizer::new(BPE::default());
    tokenizer.with_decoder(Some(wordpiece_only));

    // Construction must fail, and with exactly the "unsupported" error.
    let outcome = TokenProcessor::new(&tokenizer);
    assert!(outcome.is_err());
    assert_eq!(outcome.err(), Some(Error::UnsupportedByTokenProcessor));
}
}

0 comments on commit 2e748aa

Please sign in to comment.