Skip to content

Commit

Permalink
Introduce eos token locator
Browse files Browse the repository at this point in the history
  • Loading branch information
torymur committed Nov 5, 2024
1 parent 86b7d86 commit a27d8e2
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 1 deletion.
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ thiserror = "1.0"
pyo3 = { version = "0.22.0", features = ["extension-module"], optional = true }
regex = "1.10.6"
serde-pyobject = "0.4.0"
serde_json = { version = "1.0.125", features = ["preserve_order"] }
serde_json = { version = "1.0", features = ["preserve_order"] }
serde = {version = "1", features = ["derive"]}
hf-hub = "0.3.2"
tokenizers = { version = "0.20.0", features = ["http"] }

[features]
python-bindings = ["pyo3"]
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ pub mod primitives;
pub mod regex;
pub mod vocabulary;

mod locator;

#[cfg(feature = "python-bindings")]
mod python_bindings;

Expand Down
215 changes: 215 additions & 0 deletions src/locator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use serde::{Deserialize, Serialize};
use tokenizers::{FromPretrainedParameters, Tokenizer};

use crate::primitives::*;

/// List of common eos token locations appearing on hugging face hub, ordered by priority.
const COMMON_LOCATIONS: &[EosTokenLocation] = &[
// Most projects have `generation_config.json` that looks like:
// {
// ...
// "eos_token_id": 50256,
// ...
// }
// So it's the first place we look for the eos token id.
//
// For example:
// - https://huggingface.co/openai-community/gpt2/blob/main/generation_config.json
EosTokenLocation {
file: "generation_config.json",
location: EosTokenField::Id,
},
// The ones that don't have `generation_config.json` usually have `tokenizer_config.json`:
// {
// ...
// "eos_token": "<|endoftext|>",
// ...
// }
// Once we have the eos token content, we can get its id from the tokenizer.
//
// For example:
// - https://huggingface.co/microsoft/phi-2/blob/main/tokenizer_config.json
EosTokenLocation {
file: "tokenizer_config.json",
location: EosTokenField::Value,
},
// Sometimes `tokenizer_config.json` can have the following format as well:
// {
// "eos_token": {
// ...
// "content": "</s>",
// ...
// },
// }
// Once we have the eos token content, we can get its id from the tokenizer.
//
// For example:
// - https://huggingface.co/hf-internal-testing/llama-tokenizer/blob/main/tokenizer_config.json
EosTokenLocation {
file: "tokenizer_config.json",
location: EosTokenField::Object,
},
];

#[derive(Debug, Serialize, Deserialize)]
struct Id {
eos_token_id: u64,
}

#[derive(Debug, Serialize, Deserialize)]
struct Value {
eos_token: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct Object {
eos_token: Content,
}

#[derive(Debug, Serialize, Deserialize)]
struct Content {
content: String,
}

/// Kind of the json field which will be checked for eos token id.
enum EosTokenField {
Id,
Value,
Object,
}

/// Location of the end of sentence token id in a config file.
struct EosTokenLocation {
file: &'static str,
location: EosTokenField,
}

pub(crate) struct EosTokenLocator;

impl EosTokenLocator {
pub(crate) fn locate(
model: &str,
tokenizer: &Tokenizer,
parameters: &Option<FromPretrainedParameters>,
) -> Option<TokenId> {
COMMON_LOCATIONS
.iter()
.find_map(|location| location.lookup(model, tokenizer, parameters))
}
}

impl EosTokenLocation {
/// Finds eos token within defined location in related config file.
fn lookup(
&self,
model: &str,
tokenizer: &Tokenizer,
parameters: &Option<FromPretrainedParameters>,
) -> Option<TokenId> {
let file_path = Self::download_config(model, self.file, parameters).ok()?;
let file = std::fs::File::open(file_path).ok()?;

match self.location {
EosTokenField::Id => {
let config: Id = serde_json::from_reader(file).ok()?;
u32::try_from(config.eos_token_id).ok()
}
EosTokenField::Value => {
let config: Value = serde_json::from_reader(file).ok()?;
tokenizer.token_to_id(&config.eos_token)
}
EosTokenField::Object => {
let config: Object = serde_json::from_reader(file).ok()?;
tokenizer.token_to_id(&config.eos_token.content)
}
}
}

/// Downloads a config file from Hugging Face Hub.
fn download_config(
project: &str,
file: &str,
parameters: &Option<FromPretrainedParameters>,
) -> tokenizers::Result<std::path::PathBuf> {
// Adapted from
// https://github.com/huggingface/tokenizers/blob/9b77c054ef4297c7057fa8db875368c7c02f1bfc/tokenizers/src/utils/from_pretrained.rs#L26

let params = parameters.clone().unwrap_or_default();

Self::validate(project)?;
Self::validate(&params.revision)?;

let repo = Repo::with_revision(project.to_string(), RepoType::Model, params.revision);
let api = ApiBuilder::new()
.with_token(params.auth_token)
.build()?
.repo(repo);

Ok(api.get(file)?)
}

fn validate(input: &str) -> tokenizers::Result<()> {
let valid_chars = ['-', '_', '.', '/'];

if !input
.chars()
.all(|c: char| c.is_alphanumeric() || valid_chars.contains(&c))
{
return Err(format!(
"Input {input} contains invalid characters, expected only alphanumeric or {}",
valid_chars
.iter()
.map(|x| format!("'{}'", x))
.collect::<Vec<_>>()
.join(", ")
)
.into());
}
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn common_locations() {
for (model, expected_token_id, expected_token) in &[
("openai-community/gpt2", 50256, "<|endoftext|>"),
("microsoft/phi-2", 50256, "<|endoftext|>"),
("hf-internal-testing/llama-tokenizer", 2, "</s>"),
] {
let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
let located =
EosTokenLocator::locate(model, &tokenizer, &None).expect("Token id is not located");

assert_eq!(located, *expected_token_id);
assert_eq!(
tokenizer.id_to_token(located).expect("Token is not found"),
expected_token.to_string()
);
}
}

#[test]
fn bad_location() {
let bad_location = EosTokenLocation {
file: "tokenizer_config.json",
location: EosTokenField::Id,
};
let model = "microsoft/phi-2";
let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");

let token_id = bad_location.lookup(model, &tokenizer, &None);
assert!(token_id.is_none());

let bad_file = EosTokenLocation {
file: "generation_config.json",
location: EosTokenField::Value,
};
let token_id = bad_file.lookup(model, &tokenizer, &None);
assert!(token_id.is_none());
}
}

0 comments on commit a27d8e2

Please sign in to comment.