-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
221 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; | ||
use serde::{Deserialize, Serialize}; | ||
use tokenizers::{FromPretrainedParameters, Tokenizer}; | ||
|
||
use crate::primitives::*; | ||
|
||
/// List of common eos token locations appearing on hugging face hub, ordered by priority. | ||
const COMMON_LOCATIONS: &[EosTokenLocation] = &[ | ||
// Most projects have `generation_config.json` that looks like: | ||
// { | ||
// ... | ||
// "eos_token_id": 50256, | ||
// ... | ||
// } | ||
// So it's the first place we look for the eos token id. | ||
// | ||
// For example: | ||
// - https://huggingface.co/openai-community/gpt2/blob/main/generation_config.json | ||
EosTokenLocation { | ||
file: "generation_config.json", | ||
location: EosTokenField::Id, | ||
}, | ||
// The ones that don't have `generation_config.json` usually have `tokenizer_config.json`: | ||
// { | ||
// ... | ||
// "eos_token": "<|endoftext|>", | ||
// ... | ||
// } | ||
// Once we have the eos token content, we can get its id from the tokenizer. | ||
// | ||
// For example: | ||
// - https://huggingface.co/microsoft/phi-2/blob/main/tokenizer_config.json | ||
EosTokenLocation { | ||
file: "tokenizer_config.json", | ||
location: EosTokenField::Value, | ||
}, | ||
// Sometimes `tokenizer_config.json` can have the following format as well: | ||
// { | ||
// "eos_token": { | ||
// ... | ||
// "content": "</s>", | ||
// ... | ||
// }, | ||
// } | ||
// Once we have the eos token content, we can get its id from the tokenizer. | ||
// | ||
// For example: | ||
// - https://huggingface.co/hf-internal-testing/llama-tokenizer/blob/main/tokenizer_config.json | ||
EosTokenLocation { | ||
file: "tokenizer_config.json", | ||
location: EosTokenField::Object, | ||
}, | ||
]; | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
struct Id { | ||
eos_token_id: u64, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
struct Value { | ||
eos_token: String, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
struct Object { | ||
eos_token: Content, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
struct Content { | ||
content: String, | ||
} | ||
|
||
/// Shape of the JSON field that is inspected to find the eos token.
enum EosTokenField {
    /// The file holds a numeric `eos_token_id`.
    Id,
    /// The file holds `eos_token` as a plain string.
    Value,
    /// The file holds `eos_token` as an object with a `content` string.
    Object,
}
|
||
/// Location of the end of sentence token id in a config file. | ||
struct EosTokenLocation { | ||
file: &'static str, | ||
location: EosTokenField, | ||
} | ||
|
||
pub(crate) struct EosTokenLocator; | ||
|
||
impl EosTokenLocator { | ||
pub(crate) fn locate( | ||
model: &str, | ||
tokenizer: &Tokenizer, | ||
parameters: &Option<FromPretrainedParameters>, | ||
) -> Option<TokenId> { | ||
COMMON_LOCATIONS | ||
.iter() | ||
.find_map(|location| location.lookup(model, tokenizer, parameters)) | ||
} | ||
} | ||
|
||
impl EosTokenLocation { | ||
/// Finds eos token within defined location in related config file. | ||
fn lookup( | ||
&self, | ||
model: &str, | ||
tokenizer: &Tokenizer, | ||
parameters: &Option<FromPretrainedParameters>, | ||
) -> Option<TokenId> { | ||
let file_path = Self::download_config(model, self.file, parameters).ok()?; | ||
let file = std::fs::File::open(file_path).ok()?; | ||
|
||
match self.location { | ||
EosTokenField::Id => { | ||
let config: Id = serde_json::from_reader(file).ok()?; | ||
u32::try_from(config.eos_token_id).ok() | ||
} | ||
EosTokenField::Value => { | ||
let config: Value = serde_json::from_reader(file).ok()?; | ||
tokenizer.token_to_id(&config.eos_token) | ||
} | ||
EosTokenField::Object => { | ||
let config: Object = serde_json::from_reader(file).ok()?; | ||
tokenizer.token_to_id(&config.eos_token.content) | ||
} | ||
} | ||
} | ||
|
||
/// Downloads a config file from Hugging Face Hub. | ||
fn download_config( | ||
project: &str, | ||
file: &str, | ||
parameters: &Option<FromPretrainedParameters>, | ||
) -> tokenizers::Result<std::path::PathBuf> { | ||
// Adapted from | ||
// https://github.com/huggingface/tokenizers/blob/9b77c054ef4297c7057fa8db875368c7c02f1bfc/tokenizers/src/utils/from_pretrained.rs#L26 | ||
|
||
let params = parameters.clone().unwrap_or_default(); | ||
|
||
Self::validate(project)?; | ||
Self::validate(¶ms.revision)?; | ||
|
||
let repo = Repo::with_revision(project.to_string(), RepoType::Model, params.revision); | ||
let api = ApiBuilder::new() | ||
.with_token(params.auth_token) | ||
.build()? | ||
.repo(repo); | ||
|
||
Ok(api.get(file)?) | ||
} | ||
|
||
fn validate(input: &str) -> tokenizers::Result<()> { | ||
let valid_chars = ['-', '_', '.', '/']; | ||
|
||
if !input | ||
.chars() | ||
.all(|c: char| c.is_alphanumeric() || valid_chars.contains(&c)) | ||
{ | ||
return Err(format!( | ||
"Input {input} contains invalid characters, expected only alphanumeric or {}", | ||
valid_chars | ||
.iter() | ||
.map(|x| format!("'{}'", x)) | ||
.collect::<Vec<_>>() | ||
.join(", ") | ||
) | ||
.into()); | ||
} | ||
Ok(()) | ||
} | ||
} | ||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Every common location resolves the expected eos token for a model that uses it.
    ///
    /// Fetches the tokenizers and config files of the listed models.
    #[test]
    fn common_locations() {
        // (model, expected eos token id, expected eos token text)
        let cases = [
            ("openai-community/gpt2", 50256, "<|endoftext|>"),
            ("microsoft/phi-2", 50256, "<|endoftext|>"),
            ("hf-internal-testing/llama-tokenizer", 2, "</s>"),
        ];

        for (model, expected_id, expected_text) in cases {
            let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
            let eos_id =
                EosTokenLocator::locate(model, &tokenizer, &None).expect("Token id is not located");

            assert_eq!(eos_id, expected_id);
            assert_eq!(
                tokenizer.id_to_token(eos_id).expect("Token is not found"),
                expected_text.to_string()
            );
        }
    }

    /// A location whose file and field shape do not match yields no token id.
    #[test]
    fn bad_location() {
        let mismatched = [
            // `tokenizer_config.json` is probed for a numeric id.
            EosTokenLocation {
                file: "tokenizer_config.json",
                location: EosTokenField::Id,
            },
            // `generation_config.json` is probed for a string token.
            EosTokenLocation {
                file: "generation_config.json",
                location: EosTokenField::Value,
            },
        ];
        let model = "microsoft/phi-2";
        let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");

        for location in &mismatched {
            let token_id = location.lookup(model, &tokenizer, &None);
            assert!(token_id.is_none());
        }
    }
}