diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi
index f970980b..4a913d40 100644
--- a/python/outlines_core/fsm/outlines_core_rs.pyi
+++ b/python/outlines_core/fsm/outlines_core_rs.pyi
@@ -33,7 +33,7 @@ def state_scan_tokens(
     fsm_initial: int,
     fsm_finals: Set[int],
     vocabulary: List[Tuple[str, List[int]]],
-    vocabulary_transition_keys: List[List[int]],
+    vocabulary_transition_keys: Dict[str, List[int]],
     start_state: int,
 ) -> Set[Tuple[int, int]]: ...
 def get_token_transition_keys(
@@ -46,7 +46,7 @@ def get_vocabulary_transition_keys(
     alphabet_anything_value: int,
     vocabulary: List[Tuple[str, List[int]]],
     frozen_tokens: Set[str],
-) -> List[List[int]]: ...
+) -> Dict[str, List[int]]: ...
 def create_fsm_index_end_to_end(
     fsm_info: FSMInfo,
     vocabulary: List[Tuple[str, List[int]]],
diff --git a/src/lib.rs b/src/lib.rs
index 5811ff7a..5c7b632f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,5 +5,9 @@ pub mod regex;
 mod python_bindings;
 
 mod primitives;
+pub use primitives::{State, Token, TokenId, TransitionKey};
 
-pub use crate::primitives::{State, TokenId, TransitionKey};
+mod vocabulary;
+pub use vocabulary::Vocabulary;
+
+pub(crate) use {std::collections::HashMap, std::ops::Deref};
diff --git a/src/primitives.rs b/src/primitives.rs
index bbc77700..e12bf036 100644
--- a/src/primitives.rs
+++ b/src/primitives.rs
@@ -1,6 +1,9 @@
 /// Interegular transition key.
 pub type TransitionKey = u32;
 
+/// Token content.
+pub type Token = String;
+
 /// Token identifier.
 pub type TokenId = u32;
 
diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs
index 251300f6..5017b1de 100644
--- a/src/python_bindings/mod.rs
+++ b/src/python_bindings/mod.rs
@@ -94,9 +94,10 @@ pub fn state_scan_tokens_py(
     fsm_initial: State,
     fsm_finals: HashSet<State>,
     vocabulary: Vec<(String, Vec<TokenId>)>,
-    vocabulary_transition_keys: Vec<Vec<TransitionKey>>,
+    vocabulary_transition_keys: HashMap<String, Vec<TransitionKey>>,
     start_state: State,
 ) -> PyResult<HashSet<(TokenId, State)>> {
+    let vocabulary = Vocabulary::from_iter(vocabulary);
     Ok(state_scan_tokens(
         &fsm_transitions,
         fsm_initial,
@@ -130,7 +131,8 @@ pub fn get_vocabulary_transition_keys_py(
     alphabet_anything_value: TransitionKey,
     vocabulary: Vec<(String, Vec<TokenId>)>,
     frozen_tokens: HashSet<String>,
-) -> PyResult<Vec<Vec<TransitionKey>>> {
+) -> PyResult<HashMap<String, Vec<TransitionKey>>> {
+    let vocabulary = Vocabulary::from_iter(vocabulary);
     Ok(get_vocabulary_transition_keys(
         &alphabet_symbol_mapping,
         alphabet_anything_value,
@@ -147,6 +149,8 @@ pub fn create_fsm_index_end_to_end_py<'py>(
     vocabulary: Vec<(String, Vec<TokenId>)>,
     frozen_tokens: HashSet<String>,
 ) -> PyResult<Bound<'py, PyDict>> {
+    let vocabulary = Vocabulary::from_iter(vocabulary);
+
     let states_to_token_subsets = PyDict::new_bound(py);
     let mut seen: HashSet<State> = HashSet::new();
     let mut next_states: HashSet<State> = HashSet::from_iter(vec![fsm_info.initial]);
diff --git a/src/regex.rs b/src/regex.rs
index 314ff0cf..aac6467c 100644
--- a/src/regex.rs
+++ b/src/regex.rs
@@ -42,17 +42,14 @@ pub fn state_scan_tokens(
     fsm_transitions: &HashMap<(State, TransitionKey), State>,
     fsm_initial: State,
     fsm_finals: &HashSet<State>,
-    vocabulary: &[(String, Vec<TokenId>)],
-    vocabulary_transition_keys: &[Vec<TransitionKey>],
+    vocabulary: &Vocabulary,
+    vocabulary_transition_keys: &HashMap<Token, Vec<TransitionKey>>,
     start_state: State,
 ) -> HashSet<(TokenId, State)> {
     let mut res = HashSet::new();
 
-    for (vocab_item, token_transition_keys) in
-        vocabulary.iter().zip(vocabulary_transition_keys.iter())
-    {
-        let token_ids: Vec<TokenId> = vocab_item.1.clone();
-
+    for (token, token_ids) in vocabulary.iter() {
+        let token_transition_keys = &vocabulary_transition_keys[token];
         let state_seq = walk_fsm(
             fsm_transitions,
             fsm_initial,
@@ -66,7 +63,7 @@ pub fn state_scan_tokens(
             continue;
         }
 
-        for &token_id in &token_ids {
+        for &token_id in token_ids {
             res.insert((token_id, *state_seq.last().unwrap()));
         }
     }
@@ -110,10 +107,10 @@ pub fn get_token_transition_keys(
 pub fn get_vocabulary_transition_keys(
     alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
     alphabet_anything_value: TransitionKey,
-    vocabulary: &[(String, Vec<TokenId>)],
+    vocabulary: &Vocabulary,
     frozen_tokens: &HashSet<String>,
-) -> Vec<Vec<TransitionKey>> {
-    let mut vocab_transition_keys: Vec<Vec<TransitionKey>> = Vec::new();
+) -> HashMap<Token, Vec<TransitionKey>> {
+    let mut vocab_transition_keys = HashMap::new();
 
     for item in vocabulary.iter() {
         let token_str = item.0.clone();
@@ -137,7 +134,7 @@ pub fn get_vocabulary_transition_keys(
             );
         }
 
-        vocab_transition_keys.push(token_transition_keys);
+        vocab_transition_keys.insert(token_str, token_transition_keys);
     }
 
     vocab_transition_keys
diff --git a/src/vocabulary.rs b/src/vocabulary.rs
new file mode 100644
index 00000000..f9263390
--- /dev/null
+++ b/src/vocabulary.rs
@@ -0,0 +1,101 @@
+use crate::*;
+
+/// Vocabulary of an LLM.
+///
+/// ## Examples
+///
+/// ```rust
+/// # use outlines_core::*;
+/// #
+/// let vocabulary = Vocabulary::new()
+///     .insert(0, "blah")
+///     .insert(1, "1a")
+///     .insert(2, "2")
+///     .insert(3, "0");
+/// ```
+#[derive(Clone, Debug, Default)]
+pub struct Vocabulary(HashMap<Token, Vec<TokenId>>);
+
+impl Vocabulary {
+    /// Creates an empty vocabulary.
+    pub fn new() -> Vocabulary {
+        Vocabulary::default()
+    }
+}
+
+impl Vocabulary {
+    /// Inserts a token to the vocabulary with the specified identifier.
+    pub fn insert(mut self, id: TokenId, token: impl Into<Token>) -> Vocabulary {
+        let token = token.into();
+        self.0.entry(token).or_default().push(id);
+        self
+    }
+
+    /// Extends the vocabulary with tokens and their identifiers.
+    pub fn extend<T: Into<Token>, I: IntoIterator<Item = TokenId>>(
+        mut self,
+        tokens_and_ids: impl IntoIterator<Item = (T, I)>,
+    ) -> Vocabulary {
+        for (token, ids) in tokens_and_ids.into_iter() {
+            let token = token.into();
+            for id in ids {
+                self = self.insert(id, token.clone());
+            }
+        }
+        self
+    }
+}
+
+impl Deref for Vocabulary {
+    type Target = HashMap<Token, Vec<TokenId>>;
+
+    fn deref(&self) -> &HashMap<Token, Vec<TokenId>> {
+        &self.0
+    }
+}
+
+impl<T, I> FromIterator<(T, I)> for Vocabulary
+where
+    T: Into<Token>,
+    I: IntoIterator<Item = TokenId>,
+{
+    fn from_iter<A: IntoIterator<Item = (T, I)>>(tokens_and_ids: A) -> Self {
+        Vocabulary::new().extend(tokens_and_ids)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::*;
+
+    #[test]
+    fn insert() {
+        let vocabulary = Vocabulary::new()
+            .insert(0, "blah")
+            .insert(1, "1a")
+            .insert(2, "2")
+            .insert(3, "0");
+
+        assert_eq!(vocabulary.len(), 4);
+        assert_eq!(vocabulary["blah"], &[0]);
+        assert_eq!(vocabulary["1a"], &[1]);
+        assert_eq!(vocabulary["2"], &[2]);
+        assert_eq!(vocabulary["0"], &[3]);
+    }
+
+    #[test]
+    fn extend() {
+        let vocabulary = Vocabulary::new().extend([
+            ("blah", vec![0]),
+            ("1a", vec![1]),
+            ("2", vec![2]),
+            ("0", vec![3]),
+        ]);
+
+        assert_eq!(vocabulary.len(), 4);
+        assert_eq!(vocabulary["blah"], &[0]);
+        assert_eq!(vocabulary["1a"], &[1]);
+        assert_eq!(vocabulary["2"], &[2]);
+        assert_eq!(vocabulary["0"], &[3]);
+    }
+}
diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py
index f3a2e651..6e18e477 100644
--- a/tests/fsm/test_regex.py
+++ b/tests/fsm/test_regex.py
@@ -434,17 +434,13 @@ def convert_token_to_string(self, token):
     interegular_fsm = regex_pattern.to_fsm().reduce()
     regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
    vocabulary, _ = reduced_vocabulary(tokenizer)
-    token_trans_keys = get_vocabulary_transition_keys(
+    token_str_to_tranition_keys = get_vocabulary_transition_keys(
         regex_fsm.fsm_info.alphabet_symbol_mapping,
         regex_fsm.fsm_info.alphabet_anything_value,
         list(vocabulary.items()),
         frozenset(),
     )
 
-    token_str_to_tranition_keys = {
-        token_str: trans_key_seq
-        for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys)
-    }
     # `a` and `b` both are workable, but `z` has distinct transition rules
     assert interegular_fsm.accepts("zaz")
     assert interegular_fsm.accepts("zbz")
@@ -470,22 +466,17 @@ def convert_token_to_string(self, token):
     interegular_fsm = regex_pattern.to_fsm().reduce()
     regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
     vocabulary, _ = reduced_vocabulary(tokenizer)
-    token_trans_keys = get_vocabulary_transition_keys(
+    token_str_to_tranition_keys = get_vocabulary_transition_keys(
         regex_fsm.fsm_info.alphabet_symbol_mapping,
         regex_fsm.fsm_info.alphabet_anything_value,
         list(vocabulary.items()),
         frozenset(),
     )
 
-    token_str_trans_key_seq = {
-        token_str: trans_key_seq
-        for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys)
-    }
-
     # verify initial state valid only for "ab" and "ac" using transition key seq
     token_acceptance = {"ab": True, "ac": True, "az": False}
     for token, should_accept in token_acceptance.items():
-        token_trans_key_seq = token_str_trans_key_seq[token]
+        token_trans_key_seq = token_str_to_tranition_keys[token]
         state_seq = _walk_fsm(
             regex_fsm.fsm_info.transitions,
             regex_fsm.fsm_info.initial,
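
Usage note (not part of the patch): the sketch below is a minimal illustration of the `Vocabulary` builder introduced in `src/vocabulary.rs` above. It uses only the API shown in the diff; the token strings and ids are made up for illustration, and the crate is assumed to be importable as `outlines_core`, as in the new doc example.

```rust
// Illustrative only: exercises the Vocabulary API added in src/vocabulary.rs.
use outlines_core::*;

fn main() {
    // Chained inserts; repeated tokens accumulate ids under a single key.
    let vocabulary = Vocabulary::new()
        .insert(0, "ab")
        .insert(1, "ac")
        .insert(2, "ab"); // "ab" now maps to [0, 2]

    // The Python bindings build the same structure from (token, ids) pairs
    // via `Vocabulary::from_iter(vocabulary)`.
    let from_pairs = Vocabulary::from_iter([("az", vec![3, 4])]);

    // `Deref` exposes the inner HashMap<Token, Vec<TokenId>>, so iteration and
    // per-token lookup work the way `state_scan_tokens` now consumes them.
    for (token, token_ids) in vocabulary.iter() {
        println!("{token}: {token_ids:?}");
    }
    assert_eq!(vocabulary["ab"], vec![0, 2]);
    assert_eq!(from_pairs["az"], vec![3, 4]);
}
```

Keying `vocabulary_transition_keys` by token string (rather than relying on positional `zip` with the vocabulary list) is what lets the Python-facing return type change from `List[List[int]]` to `Dict[str, List[int]]` and removes the manual dict reconstruction from the tests.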