diff --git a/src/error.rs b/src/error.rs index 53a8728..4ffe7ed 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,9 +6,6 @@ pub type Result = std::result::Result; pub enum Error { #[error("The vocabulary does not allow to build an index that matches the input")] InsufficientVocabulary, - // TODO: this error will be removed once eos_token_id for vocabulary won't be optional - #[error("Index failed since vocabulary doesn't provide eos token id")] - IndexEosTokenIdNotAvailable, #[error("Failed to build DFA {0}")] IndexDfaError(#[from] Box), #[error("Index failed since anchored universal start state doesn't exist")] diff --git a/src/index.rs b/src/index.rs index fc35aa7..e8d839d 100644 --- a/src/index.rs +++ b/src/index.rs @@ -18,12 +18,7 @@ pub struct Index { impl Index { pub(crate) fn new(regex: &str, vocabulary: &Vocabulary) -> Result { - let eos_token_id = match vocabulary.eos_token_id() { - Some(s) => s, - // TODO: this error will be removed once eos_token_id for vocabulary won't be optional - None => return Err(Error::IndexEosTokenIdNotAvailable), - }; - + let eos_token_id = vocabulary.eos_token_id(); let dfa = DFA::new(regex).map_err(Box::new)?; let start_state = match dfa.universal_start_state(Anchored::Yes) { Some(s) => s, @@ -135,7 +130,7 @@ mod tests { #[test] fn index_from_regex() { let regex = "0|[1-9][0-9]*"; - let vocabulary = Vocabulary::new(Some(4)) + let vocabulary = Vocabulary::new(4) .insert("blah", 0) .insert("1a", 1) .insert("2", 2) @@ -157,7 +152,7 @@ mod tests { #[test] fn index_from_regex_initital_in_allowed() { let regex = "`\\n(\\.\\n)?`\\n"; - let vocabulary = Vocabulary::new(Some(104)) + let vocabulary = Vocabulary::new(104) .insert("\n", 103) .insert(".", 102) .insert("`", 101); @@ -172,7 +167,7 @@ mod tests { #[test] fn index_from_regex_multibyte() { let regex = "😇| [😈-😍][😇-😎]*"; - let vocabulary = Vocabulary::new(Some(8)) + let vocabulary = Vocabulary::new(8) .insert(" 😍", 5) .insert("blah", 0) .insert("😇", 2) diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index bfe7b28..657032c 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -182,7 +182,7 @@ impl PyVocabulary { Ok(PyVocabulary(v)) } - fn get_eos_token_id(&self) -> Option { + fn get_eos_token_id(&self) -> TokenId { self.0.eos_token_id() } diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs index 5488ac6..c1dd1ea 100644 --- a/src/vocabulary/mod.rs +++ b/src/vocabulary/mod.rs @@ -29,27 +29,19 @@ mod processor; /// ``` #[derive(Clone, Debug, Default, PartialEq, Encode, Decode)] pub struct Vocabulary { - // TODO: Option is temp for back compatibility - eos_token_id: Option, + eos_token_id: TokenId, tokens: HashMap>, } impl Vocabulary { /// Creates an empty vocabulary. - pub fn new(eos_token_id: Option) -> Self { + pub fn new(eos_token_id: TokenId) -> Self { Self { eos_token_id, tokens: HashMap::default(), } } - pub fn with_eos_token_id(self, eos_token_id: Option) -> Self { - Self { - eos_token_id, - ..self - } - } - /// Creates the vocabulary of pre-trained model from Hugging Face Hub. pub fn from_pretrained( model: &str, @@ -77,7 +69,7 @@ impl Vocabulary { }; // Start building the vocabulary from eos_token_id and added tokens. - let mut vocabulary = Vocabulary::new(Some(eos_token_id)); + let mut vocabulary = Vocabulary::new(eos_token_id); for (id, added_token) in tokenizer.get_added_tokens_decoder().iter() { if !added_token.special { vocabulary = vocabulary.insert(added_token.content.clone(), *id); @@ -110,7 +102,7 @@ impl Vocabulary { } /// Gets the identifier of the special end of the sentence token. - pub fn eos_token_id(&self) -> Option { + pub fn eos_token_id(&self) -> TokenId { self.eos_token_id } @@ -207,7 +199,7 @@ impl From<(TokenId, HashMap>)> for Vocabulary { fn from(values: (TokenId, HashMap>)) -> Vocabulary { let (eos_token_id, tokens) = values; Vocabulary { - eos_token_id: Some(eos_token_id), + eos_token_id, tokens, } } @@ -217,7 +209,7 @@ impl From<(TokenId, HashMap>)> for Vocabulary { fn from(values: (TokenId, HashMap>)) -> Vocabulary { let (eos_token_id, tokens) = values; Vocabulary { - eos_token_id: Some(eos_token_id), + eos_token_id, tokens: tokens .into_iter() .map(|(k, v)| (k.as_bytes().to_vec(), v)) @@ -226,16 +218,6 @@ impl From<(TokenId, HashMap>)> for Vocabulary { } } -impl FromIterator<(T, I)> for Vocabulary -where - T: Into, - I: IntoIterator, -{ - fn from_iter>(tokens_and_ids: A) -> Self { - Vocabulary::new(None).extend(tokens_and_ids) - } -} - #[cfg(test)] mod tests { use super::*; @@ -243,7 +225,7 @@ mod tests { #[test] fn insert() { - let vocabulary = Vocabulary::new(None) + let vocabulary = Vocabulary::new(4) .insert("blah", 0) .insert("1a", 1) .insert("2", 2) @@ -258,7 +240,7 @@ mod tests { #[test] fn extend() { - let vocabulary = Vocabulary::new(None).extend([ + let vocabulary = Vocabulary::new(4).extend([ ("blah", vec![0]), ("1a", vec![1]), ("2", vec![2]), @@ -274,28 +256,19 @@ mod tests { #[test] fn new_empty_vocabulary() { - let vocabulary = Vocabulary::new(None); - assert!(vocabulary.eos_token_id.is_none()); + let vocabulary = Vocabulary::new(1); + assert_eq!(vocabulary.eos_token_id, 1); assert!(vocabulary.tokens.is_empty()); } #[test] fn new_empty_vocabulary_from_hashmap() { - let vocabulary = Vocabulary::new(None); - assert!(vocabulary.eos_token_id.is_none()); + let map: HashMap> = HashMap::default(); + let vocabulary = Vocabulary::from((1_u32, map)); + assert_eq!(vocabulary.eos_token_id, 1); assert!(vocabulary.tokens.is_empty()); } - #[test] - fn new_vocabulary_from_iterator() { - let token: Token = "abc".as_bytes().to_vec(); - let id: Vec = vec![1]; - let it = vec![(token, id)]; - let vocabulary = Vocabulary::from_iter(it); - assert!(vocabulary.eos_token_id.is_none()); - assert!(!vocabulary.tokens.is_empty()); - } - #[test] fn supported_pretrained_models() { // Support is expected for these: @@ -315,7 +288,6 @@ mod tests { let vocabulary = Vocabulary::from_pretrained(model, None); match vocabulary { Ok(v) => { - assert!(v.eos_token_id().is_some()); assert_eq!(v.eos_token_id, v.eos_token_id()); assert!(!v.tokens.is_empty()); } @@ -332,9 +304,6 @@ mod tests { let v_eos = vocabulary.eos_token_id; assert_eq!(v_eos, vocabulary.eos_token_id()); - assert!(v_eos.is_some()); - - let v_eos = v_eos.unwrap(); assert_eq!(v_eos, 50256); assert_eq!( tokenizer.id_to_token(v_eos).expect("Token not found"), @@ -366,9 +335,6 @@ mod tests { let v_eos = vocabulary.eos_token_id; assert_eq!(v_eos, vocabulary.eos_token_id()); - assert!(v_eos.is_some()); - - let v_eos = v_eos.unwrap(); assert_eq!(v_eos, 2); assert_eq!( tokenizer.id_to_token(v_eos).expect("Token not found"), diff --git a/tests/fsm/test_vocabulary.py b/tests/fsm/test_vocabulary.py index 764e11c..a714eae 100644 --- a/tests/fsm/test_vocabulary.py +++ b/tests/fsm/test_vocabulary.py @@ -1,18 +1,20 @@ import pickle -import pytest +import pytest from outlines_core.fsm import Vocabulary + def test_supports_strings_as_keys(): eos_token_id = 3 tokens = {"1": [1], "a": [2]} vocabulary = Vocabulary.from_dict(eos_token_id, tokens) - + assert vocabulary.get_eos_token_id() == eos_token_id assert vocabulary.get("1") == [1] assert vocabulary.get(b"1") == [1] assert len(vocabulary) == 2 + def test_supports_bytes_as_keys(): eos_token_id = 3 tokens = {b"1": [1], b"a": [2]} @@ -23,16 +25,17 @@ def test_supports_bytes_as_keys(): assert vocabulary.get("1") == [1] assert len(vocabulary) == 2 + def test_do_not_supports_other_types_as_keys(): eos_token_id = 3 tokens = {1: [1], 2: [2]} with pytest.raises( - TypeError, - match="Expected a dictionary with keys of type String or Bytes" + TypeError, match="Expected a dictionary with keys of type String or Bytes" ): Vocabulary.from_dict(eos_token_id, tokens) + def test_pickling(): eos_token_id = 3 tokens = {"1": [1], "a": [2]}