Apply suggestions from CR
torymur committed Nov 19, 2024
1 parent 7de9e28 commit 105fd18
Showing 3 changed files with 28 additions and 32 deletions.
7 changes: 4 additions & 3 deletions Cargo.toml
@@ -9,14 +9,15 @@ repository = "https://github.com/dottxt-ai/outlines-core"
[dependencies]
once_cell = "1.20"
anyhow = "1.0.86"
thiserror = "1.0"
thiserror = "2.0"
pyo3 = { version = "0.22.0", features = ["extension-module"], optional = true }
regex = "1.10.6"
serde-pyobject = "0.4.0"
serde_json = { version = "1.0", features = ["preserve_order"] }
serde = {version = "1", features = ["derive"]}
serde = {version = "1.0", features = ["derive"]}
+# Fragile dependencies, minor updates often break the code
hf-hub = "=0.3.2"
tokenizers = { version = "=0.20.0", features = ["http"] }
tokenizers = { version = "=0.20.3", features = ["http"] }

[features]
python-bindings = ["pyo3"]
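The thiserror bump from 1.0 to 2.0 crosses a major version, but plain derive usage is source-compatible across it. A minimal sketch, assuming the crate's error enum carries variants like the UnableToLocateEosTokenId referenced in mod.rs below (the message string is illustrative):

use thiserror::Error;

// Sketch: this derive surface is identical under thiserror 1.x and 2.0,
// so the version bump requires no code changes for usage like this.
#[derive(Error, Debug)]
pub enum Error {
    #[error("unable to locate eos token id of {model}")]
    UnableToLocateEosTokenId { model: String },
}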
26 changes: 11 additions & 15 deletions src/vocabulary/locator.rs
@@ -89,19 +89,15 @@ struct EosTokenLocation
location: EosTokenField,
}

-pub(crate) struct EosTokenLocator;
-
-impl EosTokenLocator {
-    /// Locates eos token id by searching in defined common locations.
-    pub(crate) fn locate(
-        model: &str,
-        tokenizer: &Tokenizer,
-        parameters: &Option<FromPretrainedParameters>,
-    ) -> Option<TokenId> {
-        COMMON_LOCATIONS
-            .iter()
-            .find_map(|location| location.lookup(model, tokenizer, parameters))
-    }
-}
+/// Locates eos token id by searching in defined common locations.
+pub(crate) fn locate_eos_token_id(
+    model: &str,
+    tokenizer: &Tokenizer,
+    parameters: &Option<FromPretrainedParameters>,
+) -> Option<TokenId> {
+    COMMON_LOCATIONS
+        .iter()
+        .find_map(|location| location.lookup(model, tokenizer, parameters))
+}

impl EosTokenLocation {
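With the single-method EosTokenLocator struct removed, callers invoke the free function directly. A sketch of the new call shape, mirroring the test further down; the model id is an illustrative placeholder:

use tokenizers::Tokenizer;

// Hypothetical crate-internal call site for the free function above.
let model = "openai-community/gpt2"; // placeholder model id
let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
let eos_id: Option<TokenId> = locate_eos_token_id(model, &tokenizer, &None);

Dropping the unit struct removes a layer of indirection without changing behavior: the function body is unchanged.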
@@ -147,7 +143,7 @@ impl EosTokenLocation

let repo = Repo::with_revision(project.to_string(), RepoType::Model, params.revision);
let api = ApiBuilder::new()
-.with_token(params.auth_token)
+.with_token(params.token)
.build()?
.repo(repo);
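The auth_token → token change tracks a field rename in tokenizers' FromPretrainedParameters between the 0.20.0 and 0.20.3 pins in Cargo.toml — exactly the kind of breakage the "fragile dependencies" comment warns about. A sketch of building the parameters under that assumption:

use tokenizers::FromPretrainedParameters;

// Assumes tokenizers 0.20.3, where the field is `token` (formerly `auth_token`).
let params = FromPretrainedParameters {
    token: std::env::var("HF_TOKEN").ok(), // optional Hugging Face access token
    ..Default::default()
};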

@@ -188,7 +184,7 @@ mod tests {
] {
let tokenizer = Tokenizer::from_pretrained(model, None).expect("Tokenizer failed");
let located =
-EosTokenLocator::locate(model, &tokenizer, &None).expect("Token id is not located");
+locate_eos_token_id(model, &tokenizer, &None).expect("Token id is not located");

assert_eq!(located, *expected_token_id);
assert_eq!(
27 changes: 13 additions & 14 deletions src/vocabulary/mod.rs
@@ -6,7 +6,6 @@ use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
use crate::prelude::*;
use crate::{Error, Result};

-use locator::EosTokenLocator;
use processor::TokenProcessor;

mod locator;
@@ -29,15 +28,15 @@ mod processor;
pub struct Vocabulary {
// TODO: Option is temp for back compatibility
eos_token_id: Option<TokenId>,
-map: HashMap<Token, Vec<TokenId>>,
+tokens: HashMap<Token, Vec<TokenId>>,
}

impl Vocabulary {
/// Creates an empty vocabulary.
pub fn new(eos_token_id: Option<TokenId>) -> Self {
Self {
eos_token_id,
-map: HashMap::new(),
+tokens: HashMap::new(),
}
}

@@ -54,19 +53,19 @@ impl Vocabulary {
})?;
Self::filter_prepend_normalizers(&mut tokenizer);

-let eos_token_id = EosTokenLocator::locate(model, &tokenizer, &parameters);
+let eos_token_id = locator::locate_eos_token_id(model, &tokenizer, &parameters);
let Some(eos_token_id) = eos_token_id else {
return Err(Error::UnableToLocateEosTokenId {
model: model.to_string(),
});
};

-Vocabulary::try_from((&mut tokenizer, eos_token_id))
+Vocabulary::try_from((tokenizer, eos_token_id))
}

/// Per provided token returns vector of `TokenId`s if available in the vocabulary.
pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
-self.map.get(token)
+self.tokens.get(token)
}

/// Gets the identifier of the special end of the sentence token.
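End to end, from_pretrained now resolves the EOS id through the free locator function and stores tokens under the renamed field. A usage sketch; the model id is a placeholder, and the calls run inside a function returning Result:

// Hypothetical caller of the API shown above.
let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None)?;
let ids = vocabulary.token_to_ids("the"); // Option<&Vec<TokenId>>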
@@ -105,10 +104,10 @@ impl Vocabulary {
}
}

-impl TryFrom<(&mut Tokenizer, u32)> for Vocabulary {
+impl TryFrom<(Tokenizer, u32)> for Vocabulary {
type Error = Error;

-fn try_from(value: (&mut Tokenizer, u32)) -> Result<Self> {
+fn try_from(value: (Tokenizer, u32)) -> Result<Self> {
let (tokenizer, eos_token_id) = value;

let mut vocabulary = Vocabulary::new(Some(eos_token_id));
@@ -118,7 +117,7 @@ impl TryFrom<(&mut Tokenizer, u32)> for Vocabulary {
}
}

-let processor = TokenProcessor::new(tokenizer)?;
+let processor = TokenProcessor::new(&tokenizer)?;
for (token, token_id) in tokenizer.get_vocab(false) {
let token_bytes = processor.process(token)?;
// TODO: lossy is temp:
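Because TryFrom now takes the Tokenizer by value rather than &mut, the conversion consumes the tokenizer, and TokenProcessor::new borrows it instead. A sketch, assuming GPT-2's conventional EOS id of 50256 and a surrounding function returning Result:

use tokenizers::Tokenizer;

let tokenizer = Tokenizer::from_pretrained("openai-community/gpt2", None).expect("Tokenizer failed");
// The tuple takes ownership: `tokenizer` is moved and unusable afterwards.
let vocabulary = Vocabulary::try_from((tokenizer, 50256_u32))?;

Ownership is the more honest signature here: nothing mutates the tokenizer during conversion, and no caller needs it back afterwards.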
@@ -154,7 +153,7 @@ impl Vocabulary {
pub fn insert_in_place(&mut self, token: impl Into<Token>, id: TokenId) {
// TODO: return error if eos token id is inserted
let token = token.into();
-self.map.entry(token).or_default().push(id);
+self.tokens.entry(token).or_default().push(id);
}

/// Extends the vocabulary with tokens and their identifiers, in place.
@@ -164,7 +163,7 @@
) {
for (token, ids) in tokens_and_ids.into_iter() {
let token = token.into();
-self.map.entry(token).or_default().extend(ids);
+self.tokens.entry(token).or_default().extend(ids);
}
}
}
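Both builders funnel through entry().or_default(), so repeated tokens accumulate ids rather than overwrite them. A sketch, assuming extend_in_place accepts an iterator of (token, ids) pairs as the loop above suggests:

let mut vocabulary = Vocabulary::new(Some(0)); // 0 as a placeholder eos id
vocabulary.insert_in_place("hello", 1);
vocabulary.extend_in_place(vec![("hello", vec![2, 3]), ("world", vec![4])]);
// "hello" now maps to [1, 2, 3]; "world" to [4].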
@@ -173,7 +172,7 @@ impl std::ops::Deref for Vocabulary {
type Target = HashMap<Token, Vec<TokenId>>;

fn deref(&self) -> &HashMap<Token, Vec<TokenId>> {
-&self.map
+&self.tokens
}
}
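The Deref impl exposes the HashMap read API directly on Vocabulary, which also keeps the map → tokens rename invisible to downstream callers. A sketch; the token literal is illustrative:

// HashMap methods resolve through Deref coercion.
let n_tokens = vocabulary.len();
let has_eos = vocabulary.contains_key("<|endoftext|>");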

@@ -191,10 +190,10 @@ impl std::fmt::Display for Vocabulary {
}

impl From<HashMap<Token, Vec<TokenId>>> for Vocabulary {
-fn from(map: HashMap<Token, Vec<TokenId>>) -> Vocabulary {
+fn from(tokens: HashMap<Token, Vec<TokenId>>) -> Vocabulary {
Vocabulary {
eos_token_id: None,
-map,
+tokens,
}
}
}
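This From conversion yields a vocabulary whose eos_token_id is None — the back-compatibility case flagged by the TODO on the struct above. A sketch, assuming Token is the crate's string-like alias:

use std::collections::HashMap;

let map: HashMap<Token, Vec<TokenId>> = HashMap::from([("a".to_string(), vec![1])]);
let vocabulary = Vocabulary::from(map); // eos_token_id is None here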
