Improve documentation and visibilities
torymur committed Nov 11, 2024
1 parent 9b42797 commit d09ea69
Showing 3 changed files with 27 additions and 19 deletions.
12 changes: 8 additions & 4 deletions src/vocabulary/locator.rs
@@ -52,16 +52,20 @@ const COMMON_LOCATIONS: &[EosTokenLocation] = &[
},
];

+ /// `Id` kind of `EosTokenField`, when `eos_token_id` provided as an id.
#[derive(Debug, Serialize, Deserialize)]
struct Id {
eos_token_id: u64,
}

+ /// `Value` kind of `EosTokenField`, when `eos_token` provided as a text, so that its id
+ /// will be fetched from the tokenizer.
#[derive(Debug, Serialize, Deserialize)]
struct Value {
eos_token: String,
}

+ /// `Object` kind of `EosTokenField`, when `eos_token` provided as a `Content`.
#[derive(Debug, Serialize, Deserialize)]
struct Object {
eos_token: Content,
@@ -72,14 +76,14 @@ struct Content {
content: String,
}

- /// Which part in config's json to check for eos token id.
+ /// Specifies in which part in config's json to check for eos token id.
enum EosTokenField {
Id,
Value,
Object,
}

- /// Location of the end of sentence token id in a config file.
+ /// Defines location of the end of sentence token id in the config file.
struct EosTokenLocation {
file: &'static str,
location: EosTokenField,
@@ -101,7 +105,7 @@ impl EosTokenLocator {
}

impl EosTokenLocation {
- /// Finds eos token within defined location in related config file.
+ /// Finds eos token within defined location in a related config file.
fn lookup(
&self,
model: &str,
@@ -127,7 +131,7 @@ impl EosTokenLocation {
}
}

- /// Downloads a config file from Hugging Face Hub.
+ /// Downloads related config file from Hugging Face Hub.
fn download_config(
project: &str,
file: &str,
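
Taken together, the `Id`, `Value` and `Object` kinds above describe the three shapes in which a config file may carry the eos token. Below is a minimal sketch of how such structs deserialize each shape; the struct definitions mirror the ones added in this file, while the concrete JSON values ("</s>", 2) are illustrative assumptions, not taken from the commit.

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Id { eos_token_id: u64 }

#[derive(Debug, Deserialize)]
struct Value { eos_token: String }

#[derive(Debug, Deserialize)]
struct Content { content: String }

#[derive(Debug, Deserialize)]
struct Object { eos_token: Content }

fn main() -> serde_json::Result<()> {
    // `Id` kind: the config stores the eos token id directly.
    let id: Id = serde_json::from_str(r#"{ "eos_token_id": 2 }"#)?;
    // `Value` kind: the config stores the token text; its id is then fetched from the tokenizer.
    let value: Value = serde_json::from_str(r#"{ "eos_token": "</s>" }"#)?;
    // `Object` kind: the config stores an object whose `content` field holds the token text.
    let object: Object = serde_json::from_str(r#"{ "eos_token": { "content": "</s>" } }"#)?;
    println!("{id:?} {value:?} {object:?}");
    Ok(())
}
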
11 changes: 6 additions & 5 deletions src/vocabulary/mod.rs
@@ -52,7 +52,7 @@ impl Vocabulary {
source: TokenizerError(error),
}
})?;
- Self::filter_normalizers(&mut tokenizer);
+ Self::filter_prepend_normalizers(&mut tokenizer);

let eos_token_id = EosTokenLocator::locate(model, &tokenizer, &parameters);
let Some(eos_token_id) = eos_token_id else {
@@ -64,17 +64,18 @@
Vocabulary::try_from((&mut tokenizer, eos_token_id))
}

- /// Per provided token returns vector of `TokenId`s if available in vocabulary.
+ /// Per provided token returns vector of `TokenId`s if available in the vocabulary.
pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
self.map.get(token)
}

- /// Gets the identifier of the special end of sentence token.
+ /// Gets the identifier of the special end of the sentence token.
pub fn eos_token_id(&self) -> Option<TokenId> {
self.eos_token_id
}

- fn filter_normalizers(tokenizer: &mut Tokenizer) {
+ /// Filters out `Prepend` kind of tokenizer's normalizers.
+ fn filter_prepend_normalizers(tokenizer: &mut Tokenizer) {
// Main concern is prepend normalizers, for example https://github.com/google/sentencepiece
// In `sentencepiece` tokenizer, `▁` is used to denote spaces in the source text,
// e.g. `Hello World.` could be tokenized as: [Hello] [▁Wor] [ld] [.]
@@ -348,7 +349,7 @@ mod tests {
for normalizer in [prepend_normalizer, sequence_normalizer] {
let mut normalized_t = tokenizer.clone();
normalized_t.with_normalizer(Some(normalizer));
- Vocabulary::filter_normalizers(&mut normalized_t);
+ Vocabulary::filter_prepend_normalizers(&mut normalized_t);
if let Some(n) = normalized_t.get_normalizer() {
match n {
NormalizerWrapper::Sequence(seq) => {
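
A side note on `filter_prepend_normalizers`: as the comment above explains, sentencepiece-style tokenizers install a prepend normalizer (the `▁` space marker), which has to be dropped so decoding does not re-prepend it. Below is a minimal sketch of the simplest, top-level case, assuming the `tokenizers` crate's `NormalizerWrapper::Prepend` variant; the actual function also inspects `Sequence` normalizers, as the test above suggests.

use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::Tokenizer;

// Sketch only: drop a top-level `Prepend` normalizer. The real implementation
// also walks `Sequence` normalizers (see the test above).
fn drop_top_level_prepend(tokenizer: &mut Tokenizer) {
    let is_prepend = matches!(
        tokenizer.get_normalizer(),
        Some(NormalizerWrapper::Prepend(_))
    );
    if is_prepend {
        tokenizer.with_normalizer(None::<NormalizerWrapper>);
    }
}
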
23 changes: 13 additions & 10 deletions src/vocabulary/processor.rs
@@ -7,11 +7,12 @@ use tokenizers::{DecoderWrapper, Tokenizer};

use crate::TokenProcessorError;

- pub type Result<T, E = TokenProcessorError> = std::result::Result<T, E>;
+ type Result<T, E = TokenProcessorError> = std::result::Result<T, E>;

/// GPT2-like tokenizers have multibyte tokens that can have a mix of full and incomplete
- /// utf-8 characters. For example, b` \xf0` can be one token. These tokenizers map each
- /// byte to a valid UTF-8 character. And we need to map back those characters into bytes.
+ /// UTF-8 characters, for example, byte ` \xf0` can be one token. These tokenizers map each
+ /// byte to a valid UTF-8 character, `TokenProcessor` of `ByteFallback` level will be used
+ /// to map back these type of characters into bytes, based on `CHAR_MAP`.
///
/// "ĠO" = [U+0120, U+004F] should be interpreted as [0x20, 0x4F] = " O"
/// or
@@ -84,9 +85,9 @@ pub(crate) struct TokenProcessor {
level: TokenProcessorLevel,
}

- /// Recognized tokenizer's levels.
+ /// Recognizes different tokenizer's levels.
#[derive(Debug, Clone, PartialEq)]
- pub enum TokenProcessorLevel {
+ pub(crate) enum TokenProcessorLevel {
/// Matches byte level tokenizer (e.g., gpt2).
Byte,
/// Matches byte fallback tokenizer (e.g., llama), which have <0x__> tokens for
@@ -103,9 +104,9 @@ impl std::fmt::Display for TokenProcessorLevel {
}
}

- /// Modifications to be applied by `ByteFallback` `TokenProcessorLevel`.
+ /// Modifications to be applied by `TokenProcessor` of `ByteFallback` level.
#[derive(Debug, Clone, PartialEq)]
- pub struct Mods {
+ pub(crate) struct Mods {
spacechar: char,
}

@@ -120,6 +121,7 @@ impl Mods {
}
}

+ /// Local structure to be deserialized into from HF's `ReplaceDecoder` in order to get a replace pattern.
#[derive(Debug, Deserialize)]
struct ReplaceDecoder {
content: String,
@@ -147,7 +149,7 @@ impl ReplaceDecoder {
}

#[derive(Debug, Deserialize)]
- pub enum ReplacePattern {
+ enum ReplacePattern {
String(String),
}

@@ -194,7 +196,7 @@ impl TokenProcessor {
}
}

- /// Process each token based on the level of `TokenProcessor`.
+ /// Operates on each token based on the level of `TokenProcessor`.
pub(crate) fn process(&self, token: String) -> Result<Vec<u8>> {
match &self.level {
TokenProcessorLevel::Byte => {
@@ -222,7 +224,8 @@ impl TokenProcessor {
}
}

- /// Since all fields of `Replace` are private with no getters, we'll have to unpack it into our own.
+ /// Since all fields of HF's `Replace` are private with no getters, it needs to be unpacked
+ /// into local `ReplaceDecoder` structure.
fn unpack_decoder(decoder: &Replace) -> Result<ReplaceDecoder> {
match serde_json::to_value(decoder) {
Err(_) => Err(TokenProcessorError::DecoderUnpackingFailed),
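
A side note on the byte-level mapping described at the top of this file: the "ĠO" → " O" example follows the GPT-2 style byte-to-unicode table, where printable bytes map to themselves and the rest are shifted past U+00FF, so U+0120 ('Ġ') stands for byte 0x20. Below is a small self-contained sketch of that idea (an assumption about the scheme, not the crate's actual `CHAR_MAP` constant):

use std::collections::HashMap;

// Illustrative char -> byte table: printable bytes map to themselves,
// all other bytes are shifted into U+0100.., so 'Ġ' (U+0120) stands for 0x20.
fn char_to_byte_map() -> HashMap<char, u8> {
    let mut map = HashMap::new();
    let mut shift = 0u32;
    for byte in 0u32..=255 {
        let printable = (0x21..=0x7E).contains(&byte)
            || (0xA1..=0xAC).contains(&byte)
            || (0xAE..=0xFF).contains(&byte);
        let code = if printable {
            byte
        } else {
            shift += 1;
            255 + shift
        };
        map.insert(char::from_u32(code).unwrap(), byte as u8);
    }
    map
}

fn main() {
    let map = char_to_byte_map();
    // "ĠO" = [U+0120, U+004F] maps back to [0x20, 0x4F] = " O".
    let bytes: Vec<u8> = "ĠO".chars().map(|c| map[&c]).collect();
    assert_eq!(bytes, vec![0x20, 0x4F]);
    println!("{:?}", String::from_utf8(bytes).unwrap());
}
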
