From f3fb1fb4c87dd5b78f7c9d69ff8e5c3447e48fd0 Mon Sep 17 00:00:00 2001 From: Umut Date: Tue, 24 Sep 2024 18:24:13 +0300 Subject: [PATCH] Return hash map for vocabulary transition keys This was required because vocabulary was no longer ordered. So returning a vector was causing ordering issues. --- python/outlines_core/fsm/outlines_core_rs.pyi | 4 ++-- src/python_bindings/mod.rs | 4 ++-- src/regex.rs | 17 +++++++---------- tests/fsm/test_regex.py | 15 +++------------ 4 files changed, 14 insertions(+), 26 deletions(-) diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi index f970980b..4a913d40 100644 --- a/python/outlines_core/fsm/outlines_core_rs.pyi +++ b/python/outlines_core/fsm/outlines_core_rs.pyi @@ -33,7 +33,7 @@ def state_scan_tokens( fsm_initial: int, fsm_finals: Set[int], vocabulary: List[Tuple[str, List[int]]], - vocabulary_transition_keys: List[List[int]], + vocabulary_transition_keys: Dict[str, List[int]], start_state: int, ) -> Set[Tuple[int, int]]: ... def get_token_transition_keys( @@ -46,7 +46,7 @@ def get_vocabulary_transition_keys( alphabet_anything_value: int, vocabulary: List[Tuple[str, List[int]]], frozen_tokens: Set[str], -) -> List[List[int]]: ... +) -> Dict[str, List[int]]: ... def create_fsm_index_end_to_end( fsm_info: FSMInfo, vocabulary: List[Tuple[str, List[int]]], diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 34cbf4f3..5017b1de 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -94,7 +94,7 @@ pub fn state_scan_tokens_py( fsm_initial: State, fsm_finals: HashSet, vocabulary: Vec<(String, Vec)>, - vocabulary_transition_keys: Vec>, + vocabulary_transition_keys: HashMap>, start_state: State, ) -> PyResult> { let vocabulary = Vocabulary::from_iter(vocabulary); @@ -131,7 +131,7 @@ pub fn get_vocabulary_transition_keys_py( alphabet_anything_value: TransitionKey, vocabulary: Vec<(String, Vec)>, frozen_tokens: HashSet, -) -> PyResult>> { +) -> PyResult>> { let vocabulary = Vocabulary::from_iter(vocabulary); Ok(get_vocabulary_transition_keys( &alphabet_symbol_mapping, diff --git a/src/regex.rs b/src/regex.rs index c0eda767..aac6467c 100644 --- a/src/regex.rs +++ b/src/regex.rs @@ -43,16 +43,13 @@ pub fn state_scan_tokens( fsm_initial: State, fsm_finals: &HashSet, vocabulary: &Vocabulary, - vocabulary_transition_keys: &[Vec], + vocabulary_transition_keys: &HashMap>, start_state: State, ) -> HashSet<(TokenId, State)> { let mut res = HashSet::new(); - for (vocab_item, token_transition_keys) in - vocabulary.iter().zip(vocabulary_transition_keys.iter()) - { - let token_ids: Vec = vocab_item.1.clone(); - + for (token, token_ids) in vocabulary.iter() { + let token_transition_keys = &vocabulary_transition_keys[token]; let state_seq = walk_fsm( fsm_transitions, fsm_initial, @@ -66,7 +63,7 @@ pub fn state_scan_tokens( continue; } - for &token_id in &token_ids { + for &token_id in token_ids { res.insert((token_id, *state_seq.last().unwrap())); } } @@ -112,8 +109,8 @@ pub fn get_vocabulary_transition_keys( alphabet_anything_value: TransitionKey, vocabulary: &Vocabulary, frozen_tokens: &HashSet, -) -> Vec> { - let mut vocab_transition_keys: Vec> = Vec::new(); +) -> HashMap> { + let mut vocab_transition_keys = HashMap::new(); for item in vocabulary.iter() { let token_str = item.0.clone(); @@ -137,7 +134,7 @@ pub fn get_vocabulary_transition_keys( ); } - vocab_transition_keys.push(token_transition_keys); + vocab_transition_keys.insert(token_str, token_transition_keys); } vocab_transition_keys diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py index f3a2e651..6e18e477 100644 --- a/tests/fsm/test_regex.py +++ b/tests/fsm/test_regex.py @@ -434,17 +434,13 @@ def convert_token_to_string(self, token): interegular_fsm = regex_pattern.to_fsm().reduce() regex_fsm, _ = make_deterministic_fsm(interegular_fsm) vocabulary, _ = reduced_vocabulary(tokenizer) - token_trans_keys = get_vocabulary_transition_keys( + token_str_to_tranition_keys = get_vocabulary_transition_keys( regex_fsm.fsm_info.alphabet_symbol_mapping, regex_fsm.fsm_info.alphabet_anything_value, list(vocabulary.items()), frozenset(), ) - token_str_to_tranition_keys = { - token_str: trans_key_seq - for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys) - } # `a` and `b` both are workable, but `z` has distinct transition rules assert interegular_fsm.accepts("zaz") assert interegular_fsm.accepts("zbz") @@ -470,22 +466,17 @@ def convert_token_to_string(self, token): interegular_fsm = regex_pattern.to_fsm().reduce() regex_fsm, _ = make_deterministic_fsm(interegular_fsm) vocabulary, _ = reduced_vocabulary(tokenizer) - token_trans_keys = get_vocabulary_transition_keys( + token_str_to_tranition_keys = get_vocabulary_transition_keys( regex_fsm.fsm_info.alphabet_symbol_mapping, regex_fsm.fsm_info.alphabet_anything_value, list(vocabulary.items()), frozenset(), ) - token_str_trans_key_seq = { - token_str: trans_key_seq - for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys) - } - # verify initial state valid only for "ab" and "ac" using transition key seq token_acceptance = {"ab": True, "ac": True, "az": False} for token, should_accept in token_acceptance.items(): - token_trans_key_seq = token_str_trans_key_seq[token] + token_trans_key_seq = token_str_to_tranition_keys[token] state_seq = _walk_fsm( regex_fsm.fsm_info.transitions, regex_fsm.fsm_info.initial,