Skip to content

Commit

Permalink
Return hash map for vocabulary transition keys
Browse files Browse the repository at this point in the history
This was required because the vocabulary is no longer ordered, so returning a vector caused ordering issues: the positional pairing of tokens with their transition keys could no longer be relied upon.
  • Loading branch information
umut-sahin authored and brandonwillard committed Sep 24, 2024
1 parent 4313ec6 commit f3fb1fb
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 26 deletions.
4 changes: 2 additions & 2 deletions python/outlines_core/fsm/outlines_core_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def state_scan_tokens(
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: List[Tuple[str, List[int]]],
vocabulary_transition_keys: List[List[int]],
vocabulary_transition_keys: Dict[str, List[int]],
start_state: int,
) -> Set[Tuple[int, int]]: ...
def get_token_transition_keys(
Expand All @@ -46,7 +46,7 @@ def get_vocabulary_transition_keys(
alphabet_anything_value: int,
vocabulary: List[Tuple[str, List[int]]],
frozen_tokens: Set[str],
) -> List[List[int]]: ...
) -> Dict[str, List[int]]: ...
def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: List[Tuple[str, List[int]]],
Expand Down
4 changes: 2 additions & 2 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ pub fn state_scan_tokens_py(
fsm_initial: State,
fsm_finals: HashSet<State>,
vocabulary: Vec<(String, Vec<TokenId>)>,
vocabulary_transition_keys: Vec<Vec<TransitionKey>>,
vocabulary_transition_keys: HashMap<String, Vec<TransitionKey>>,
start_state: State,
) -> PyResult<HashSet<(TokenId, State)>> {
let vocabulary = Vocabulary::from_iter(vocabulary);
Expand Down Expand Up @@ -131,7 +131,7 @@ pub fn get_vocabulary_transition_keys_py(
alphabet_anything_value: TransitionKey,
vocabulary: Vec<(String, Vec<TokenId>)>,
frozen_tokens: HashSet<String>,
) -> PyResult<Vec<Vec<TransitionKey>>> {
) -> PyResult<HashMap<String, Vec<TransitionKey>>> {
let vocabulary = Vocabulary::from_iter(vocabulary);
Ok(get_vocabulary_transition_keys(
&alphabet_symbol_mapping,
Expand Down
17 changes: 7 additions & 10 deletions src/regex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,13 @@ pub fn state_scan_tokens(
fsm_initial: State,
fsm_finals: &HashSet<State>,
vocabulary: &Vocabulary,
vocabulary_transition_keys: &[Vec<TransitionKey>],
vocabulary_transition_keys: &HashMap<Token, Vec<TransitionKey>>,
start_state: State,
) -> HashSet<(TokenId, State)> {
let mut res = HashSet::new();

for (vocab_item, token_transition_keys) in
vocabulary.iter().zip(vocabulary_transition_keys.iter())
{
let token_ids: Vec<TokenId> = vocab_item.1.clone();

for (token, token_ids) in vocabulary.iter() {
let token_transition_keys = &vocabulary_transition_keys[token];
let state_seq = walk_fsm(
fsm_transitions,
fsm_initial,
Expand All @@ -66,7 +63,7 @@ pub fn state_scan_tokens(
continue;
}

for &token_id in &token_ids {
for &token_id in token_ids {
res.insert((token_id, *state_seq.last().unwrap()));
}
}
Expand Down Expand Up @@ -112,8 +109,8 @@ pub fn get_vocabulary_transition_keys(
alphabet_anything_value: TransitionKey,
vocabulary: &Vocabulary,
frozen_tokens: &HashSet<String>,
) -> Vec<Vec<TransitionKey>> {
let mut vocab_transition_keys: Vec<Vec<TransitionKey>> = Vec::new();
) -> HashMap<Token, Vec<TransitionKey>> {
let mut vocab_transition_keys = HashMap::new();

for item in vocabulary.iter() {
let token_str = item.0.clone();
Expand All @@ -137,7 +134,7 @@ pub fn get_vocabulary_transition_keys(
);
}

vocab_transition_keys.push(token_transition_keys);
vocab_transition_keys.insert(token_str, token_transition_keys);
}

vocab_transition_keys
Expand Down
15 changes: 3 additions & 12 deletions tests/fsm/test_regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,17 +434,13 @@ def convert_token_to_string(self, token):
interegular_fsm = regex_pattern.to_fsm().reduce()
regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
vocabulary, _ = reduced_vocabulary(tokenizer)
token_trans_keys = get_vocabulary_transition_keys(
token_str_to_tranition_keys = get_vocabulary_transition_keys(
regex_fsm.fsm_info.alphabet_symbol_mapping,
regex_fsm.fsm_info.alphabet_anything_value,
list(vocabulary.items()),
frozenset(),
)

token_str_to_tranition_keys = {
token_str: trans_key_seq
for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys)
}
# `a` and `b` both are workable, but `z` has distinct transition rules
assert interegular_fsm.accepts("zaz")
assert interegular_fsm.accepts("zbz")
Expand All @@ -470,22 +466,17 @@ def convert_token_to_string(self, token):
interegular_fsm = regex_pattern.to_fsm().reduce()
regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
vocabulary, _ = reduced_vocabulary(tokenizer)
token_trans_keys = get_vocabulary_transition_keys(
token_str_to_tranition_keys = get_vocabulary_transition_keys(
regex_fsm.fsm_info.alphabet_symbol_mapping,
regex_fsm.fsm_info.alphabet_anything_value,
list(vocabulary.items()),
frozenset(),
)

token_str_trans_key_seq = {
token_str: trans_key_seq
for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys)
}

# verify initial state valid only for "ab" and "ac" using transition key seq
token_acceptance = {"ab": True, "ac": True, "az": False}
for token, should_accept in token_acceptance.items():
token_trans_key_seq = token_str_trans_key_seq[token]
token_trans_key_seq = token_str_to_tranition_keys[token]
state_seq = _walk_fsm(
regex_fsm.fsm_info.transitions,
regex_fsm.fsm_info.initial,
Expand Down

0 comments on commit f3fb1fb

Please sign in to comment.