From d09ea690f6f81d04dceda70c4ce99b69014664c2 Mon Sep 17 00:00:00 2001
From: "Victoria Terenina (torymur)"
Date: Mon, 11 Nov 2024 20:54:45 +0000
Subject: [PATCH] Improve documentation and visibilities

---
 src/vocabulary/locator.rs   | 12 ++++++++----
 src/vocabulary/mod.rs       | 11 ++++++-----
 src/vocabulary/processor.rs | 23 +++++++++++++----------
 3 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/src/vocabulary/locator.rs b/src/vocabulary/locator.rs
index 61cc581e..b3c6f068 100644
--- a/src/vocabulary/locator.rs
+++ b/src/vocabulary/locator.rs
@@ -52,16 +52,20 @@ const COMMON_LOCATIONS: &[EosTokenLocation] = &[
     },
 ];
 
+/// `Id` kind of `EosTokenField`, used when `eos_token_id` is provided as an id.
 #[derive(Debug, Serialize, Deserialize)]
 struct Id {
     eos_token_id: u64,
 }
 
+/// `Value` kind of `EosTokenField`, used when `eos_token` is provided as text, so that its id
+/// will be fetched from the tokenizer.
 #[derive(Debug, Serialize, Deserialize)]
 struct Value {
     eos_token: String,
 }
 
+/// `Object` kind of `EosTokenField`, used when `eos_token` is provided as a `Content`.
 #[derive(Debug, Serialize, Deserialize)]
 struct Object {
     eos_token: Content,
@@ -72,14 +76,14 @@ struct Content {
     content: String,
 }
 
-/// Which part in config's json to check for eos token id.
+/// Specifies which part of the config's json to check for the eos token id.
 enum EosTokenField {
     Id,
     Value,
     Object,
 }
 
-/// Location of the end of sentence token id in a config file.
+/// Defines the location of the end-of-sentence token id in the config file.
 struct EosTokenLocation {
     file: &'static str,
     location: EosTokenField,
@@ -101,7 +105,7 @@ impl EosTokenLocator {
 }
 
 impl EosTokenLocation {
-    /// Finds eos token within defined location in related config file.
+    /// Finds the eos token within the defined location in a related config file.
     fn lookup(
         &self,
         model: &str,
@@ -127,7 +131,7 @@ impl EosTokenLocation {
         }
     }
 
-    /// Downloads a config file from Hugging Face Hub.
+    /// Downloads the related config file from the Hugging Face Hub.
     fn download_config(
         project: &str,
         file: &str,
diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index fef311f4..b62c22e7 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -52,7 +52,7 @@ impl Vocabulary {
                 source: TokenizerError(error),
             }
         })?;
-        Self::filter_normalizers(&mut tokenizer);
+        Self::filter_prepend_normalizers(&mut tokenizer);
 
         let eos_token_id = EosTokenLocator::locate(model, &tokenizer, &parameters);
         let Some(eos_token_id) = eos_token_id else {
@@ -64,17 +64,18 @@ impl Vocabulary {
         Vocabulary::try_from((&mut tokenizer, eos_token_id))
     }
 
-    /// Per provided token returns vector of `TokenId`s if available in vocabulary.
+    /// Per provided token, returns a vector of `TokenId`s if available in the vocabulary.
     pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
         self.map.get(token)
     }
 
-    /// Gets the identifier of the special end of sentence token.
+    /// Gets the identifier of the special end-of-sentence token.
     pub fn eos_token_id(&self) -> Option<TokenId> {
         self.eos_token_id
     }
 
-    fn filter_normalizers(tokenizer: &mut Tokenizer) {
+    /// Filters out the `Prepend` kind of the tokenizer's normalizers.
+    fn filter_prepend_normalizers(tokenizer: &mut Tokenizer) {
         // Main concern is prepend normalizers, for example https://github.com/google/sentencepiece
         // In `sentencepiece` tokenizer, `▁` is used to denote spaces in the source text,
         // e.g. `Hello World.` could be tokenized as: [Hello] [▁Wor] [ld] [.]
@@ -348,7 +349,7 @@ mod tests {
         for normalizer in [prepend_normalizer, sequence_normalizer] {
             let mut normalized_t = tokenizer.clone();
             normalized_t.with_normalizer(Some(normalizer));
-            Vocabulary::filter_normalizers(&mut normalized_t);
+            Vocabulary::filter_prepend_normalizers(&mut normalized_t);
             if let Some(n) = normalized_t.get_normalizer() {
                 match n {
                     NormalizerWrapper::Sequence(seq) => {
diff --git a/src/vocabulary/processor.rs b/src/vocabulary/processor.rs
index ce149f80..9488a78f 100644
--- a/src/vocabulary/processor.rs
+++ b/src/vocabulary/processor.rs
@@ -7,11 +7,12 @@ use tokenizers::{DecoderWrapper, Tokenizer};
 
 use crate::TokenProcessorError;
 
-pub type Result<T> = std::result::Result<T, TokenProcessorError>;
+type Result<T> = std::result::Result<T, TokenProcessorError>;
 
 /// GPT2-like tokenizers have multibyte tokens that can have a mix of full and incomplete
-/// utf-8 characters. For example, b` \xf0` can be one token. These tokenizers map each
-/// byte to a valid UTF-8 character. And we need to map back those characters into bytes.
+/// UTF-8 characters; for example, byte ` \xf0` can be one token. These tokenizers map each
+/// byte to a valid UTF-8 character, and a `TokenProcessor` of `Byte` level is used to map
+/// such characters back into bytes, based on `CHAR_MAP`.
 ///
 /// "ĠO" = [U+0120, U+004F] should be interpreted as [0x20, 0x4F] = " O"
 /// or
@@ -84,9 +85,9 @@ pub(crate) struct TokenProcessor {
     level: TokenProcessorLevel,
 }
 
-/// Recognized tokenizer's levels.
+/// Recognizes the different levels of tokenizers.
 #[derive(Debug, Clone, PartialEq)]
-pub enum TokenProcessorLevel {
+pub(crate) enum TokenProcessorLevel {
     /// Matches byte level tokenizer (e.g., gpt2).
     Byte,
     /// Matches byte fallback tokenizer (e.g., llama), which have <0x__> tokens for
@@ -103,9 +104,9 @@ impl std::fmt::Display for TokenProcessorLevel {
     }
 }
 
-/// Modifications to be applied by `ByteFallback` `TokenProcessorLevel`.
+/// Modifications to be applied by a `TokenProcessor` of `ByteFallback` level.
 #[derive(Debug, Clone, PartialEq)]
-pub struct Mods {
+pub(crate) struct Mods {
     spacechar: char,
 }
 
@@ -120,6 +121,7 @@ impl Mods {
     }
 }
 
+/// Local structure into which HF's `ReplaceDecoder` is deserialized in order to get the replace pattern.
 #[derive(Debug, Deserialize)]
 struct ReplaceDecoder {
     content: String,
@@ -147,7 +149,7 @@ impl ReplaceDecoder {
 }
 
 #[derive(Debug, Deserialize)]
-pub enum ReplacePattern {
+enum ReplacePattern {
     String(String),
 }
 
@@ -194,7 +196,7 @@ impl TokenProcessor {
         }
     }
 
-    /// Process each token based on the level of `TokenProcessor`.
+    /// Operates on each token based on the level of the `TokenProcessor`.
     pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
         match &self.level {
             TokenProcessorLevel::Byte => {
@@ -222,7 +224,8 @@ impl TokenProcessor {
         }
     }
 
-    /// Since all fields of `Replace` are private with no getters, we'll have to unpack it into our own.
+    /// Since all fields of HF's `Replace` are private with no getters, it needs to be unpacked
+    /// into the local `ReplaceDecoder` structure.
     fn unpack_decoder(decoder: &Replace) -> Result<ReplaceDecoder> {
         match serde_json::to_value(decoder) {
             Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),