From 2420742cfe73db3764352f26942332920906af11 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Wed, 13 Nov 2024 19:20:09 +0000 Subject: [PATCH] Improve test coverage --- Cargo.toml | 3 ++ src/vocabulary/locator.rs | 9 +++++ src/vocabulary/mod.rs | 49 ++++++++++++++++++++++++++-- src/vocabulary/processor.rs | 65 ++++++++++++++++++++++++++++++++----- 4 files changed, 115 insertions(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 947620f1..94eab3a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,3 +36,6 @@ panic = 'abort' [package.metadata.scripts] build-python-extension = "python setup.py build_rust --inplace --debug" build-python-extension-release = "python setup.py build_rust --inplace --release" + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tarpaulin_include)'] } diff --git a/src/vocabulary/locator.rs b/src/vocabulary/locator.rs index 31a3548c..4b417789 100644 --- a/src/vocabulary/locator.rs +++ b/src/vocabulary/locator.rs @@ -138,6 +138,9 @@ impl EosTokenLocation { let params = parameters.clone().unwrap_or_default(); + // Validation checks are coming as a literal adaptation logic from HF. + // In this case project is a model name, which if invalid expected to fail much earlier. + // So it seems a bit redundant to validate it this way, but no harm in doing so too. Self::validate(project)?; Self::validate(¶ms.revision)?; @@ -213,4 +216,10 @@ mod tests { let token_id = bad_file.lookup(model, &tokenizer, &None); assert!(token_id.is_none()); } + + #[test] + fn validate_config_input() { + let input = "bad_model_name*"; + assert!(EosTokenLocation::validate(input).is_err()); + } } diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index 7048731b..e3e24e51 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -243,13 +243,41 @@ mod tests { assert_eq!(vocabulary["0"], &[3]); } + #[test] + fn new_empty_vocabulary() { + let vocabulary = Vocabulary::new(None); + assert!(vocabulary.eos_token_id.is_none()); + assert!(vocabulary.tokens.is_empty()); + } + + #[test] + fn new_empty_vocabulary_from_hashmap() { + let map = HashMap::new(); + let vocabulary = Vocabulary::from(map); + assert!(vocabulary.eos_token_id.is_none()); + assert!(vocabulary.tokens.is_empty()); + } + + #[test] + fn new_vocabulary_from_iterator() { + let token: Token = "abc".to_string(); + let id: Vec = vec![1]; + let it = vec![(token, id)]; + let vocabulary = Vocabulary::from_iter(it); + assert!(vocabulary.eos_token_id.is_none()); + assert!(!vocabulary.tokens.is_empty()); + } + #[test] fn pretrained_from_gpt2() { let model = "openai-community/gpt2"; let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed"); - let v_eos = vocabulary.eos_token_id.expect("No eos token in vocabulary"); + let v_eos = vocabulary.eos_token_id; + assert!(v_eos.is_some()); + + let v_eos = v_eos.unwrap(); assert_eq!(v_eos, 50256); assert_eq!( tokenizer.id_to_token(v_eos).expect("Token not found"), @@ -278,7 +306,10 @@ mod tests { let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); let vocabulary = Vocabulary::from_pretrained(model, None).expect("Vocabulary failed"); - let v_eos = vocabulary.eos_token_id.expect("No eos token in vocabulary"); + let v_eos = vocabulary.eos_token_id; + assert!(v_eos.is_some()); + + let v_eos = v_eos.unwrap(); assert_eq!(v_eos, 2); assert_eq!( tokenizer.id_to_token(v_eos).expect("Token not found"), @@ -358,4 +389,18 @@ mod tests { } } } + + #[test] + fn other_normalizers_being_kept() { + use tokenizers::normalizers::BertNormalizer; + + let model = "hf-internal-testing/llama-tokenizer"; + let normalizer = NormalizerWrapper::BertNormalizer(BertNormalizer::default()); + let mut tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed"); + tokenizer.with_normalizer(Some(normalizer)); + + Vocabulary::filter_prepend_normalizers(&mut tokenizer); + + assert!(tokenizer.get_normalizer().is_some()); + } } diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs index cec32f52..74d700d2 100644 --- a/src/vocabulary/processor.rs +++ b/src/vocabulary/processor.rs @@ -93,15 +93,6 @@ pub(crate) enum TokenProcessorLevel { ByteFallback(Mods), } -impl std::fmt::Display for TokenProcessorLevel { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Self::Byte => write!(f, "Byte Level"), - Self::ByteFallback(mods) => write!(f, "Byte Fallback Level with mods: {:?}", mods), - } - } -} - /// Modifications to be applied by `TokenProcessor`of `ByteFallback` level. #[derive(Debug, Clone, PartialEq)] pub(crate) struct Mods { @@ -223,6 +214,7 @@ impl TokenProcessor { /// Since all fields of HF's `Replace` are private with no getters, it needs to be unpacked /// into local `ReplaceDecoder` structure. + #[cfg(not(tarpaulin_include))] fn unpack_decoder(decoder: &Replace) -> Result { match serde_json::to_value(decoder) { Err(_) => Err(Error::DecoderUnpackingFailed), @@ -352,4 +344,59 @@ mod tests { assert_eq!(e, Error::ByteFallbackProcessorFailed) } } + + #[test] + fn only_get_spacechar_replacement() { + let one_char = "_".to_string(); + let pattern = ReplacePattern::String(one_char); + let not_spacechar = "-".to_string(); + let decoder = ReplaceDecoder { + content: not_spacechar, + pattern, + }; + assert!(decoder.space_replacement().is_none()); + } + + #[test] + fn only_one_pattern_char_for_spacechar_replacement() { + let two_chars = "_*".to_string(); + let pattern = ReplacePattern::String(two_chars); + let spacechar = " ".to_string(); + let decoder = ReplaceDecoder { + content: spacechar, + pattern, + }; + assert!(decoder.space_replacement().is_none()); + } + + #[test] + fn tokenizer_without_decoders_is_unsupported() { + use tokenizers::models::bpe::BPE; + + let tokenizer = Tokenizer::new(BPE::default()); + let result = TokenProcessor::new(&tokenizer); + assert!(result.is_err()); + if let Err(e) = result { + assert_eq!(e, Error::UnsupportedByTokenProcessor) + } + } + + #[test] + fn tokenizer_without_supported_decoders_in_sequence_is_unsupported() { + use tokenizers::decoders::sequence::Sequence; + use tokenizers::decoders::wordpiece::WordPiece; + use tokenizers::models::bpe::BPE; + + let mut tokenizer = Tokenizer::new(BPE::default()); + let decoder = WordPiece::default(); + let sequence = Sequence::new(vec![DecoderWrapper::WordPiece(decoder)]); + let decoder_sequence = DecoderWrapper::Sequence(sequence); + tokenizer.with_decoder(Some(decoder_sequence)); + + let result = TokenProcessor::new(&tokenizer); + assert!(result.is_err()); + if let Err(e) = result { + assert_eq!(e, Error::UnsupportedByTokenProcessor) + } + } }