diff --git a/python/outlines_core/fsm/regex.py b/python/outlines_core/fsm/regex.py index 00169c72..834b5880 100644 --- a/python/outlines_core/fsm/regex.py +++ b/python/outlines_core/fsm/regex.py @@ -4,7 +4,6 @@ TYPE_CHECKING, Dict, FrozenSet, - Generator, Iterable, List, Optional, @@ -18,7 +17,6 @@ from interegular.fsm import ( FSM, Alphabet, - OblivionError, State, TransitionKey, _AnythingElseCls, @@ -270,17 +268,6 @@ def create_seq_transitions( ) -def make_byte_level_better_fsm(fsm: BetterFSM, keep_utf8=False) -> BetterFSM: - new_fsm = make_byte_level_fsm(fsm, keep_utf8) - return BetterFSM( - alphabet=BetterAlphabet(new_fsm.alphabet._symbol_mapping), - states=new_fsm.states, - initial=new_fsm.initial, - finals=new_fsm.finals, - map=new_fsm.map, - ) - - def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: """Construct an equivalent FSM with deterministic state labels.""" old_to_new_trans_keys = { @@ -355,201 +342,6 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: return new_fsm, old_to_new_states -def walk_fsm( - fsm: BetterFSM, - token_transition_keys: Sequence[int], - start_state: int, - full_match: bool = True, -) -> List[int]: - fsm_finals = fsm.finals - - state = start_state - accepted_states: List[int] = [] - last_final_idx: int = 0 - - fsm_transitions = fsm.flat_transition_map - - # Iterate over token transition key sequence. The transition key - # sequence represents the FSM traversal rules of the tokens symbols. - for i, trans_key in enumerate(token_transition_keys): - new_state = fsm_transitions.get((state, trans_key)) - - if new_state is None: - if not full_match and last_final_idx > 0: - return accepted_states[:last_final_idx] - - return [] - - state = new_state - - if state in fsm_finals: - last_final_idx = i + 1 - - accepted_states.append(state) - - if full_match and last_final_idx - 1 != i: - return [] - - return accepted_states - - -def fsm_union( - fsms: Sequence[FSM], -) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]: - """Construct an FSM representing the union of the FSMs in `fsms`. - - This is an updated version of `interegular.fsm.FSM.union` made to return an - extra map of component FSMs to the sets of state transitions that - correspond to them in the new FSM. - - """ - - alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms]) - - indexed_fsms = tuple(enumerate(fsms)) - - initial = {i: fsm.initial for (i, fsm) in indexed_fsms} - - # Dedicated function accepting a "superset" and returning the next - # "superset" obtained by following this transition in the new FSM - def follow(current_state, new_transition: int): - next = {} - for i, f in indexed_fsms: - old_transition = new_to_old[i][new_transition] - if ( - i in current_state - and current_state[i] in f.map - and old_transition in f.map[current_state[i]] - ): - next[i] = f.map[current_state[i]][old_transition] - if not next: - raise OblivionError - return next - - states = [initial] - finals: Set[int] = set() - map: Dict[int, Dict[int, int]] = {} - - # Map component FSMs to their new state-to-state transitions, finals, and a - # map translating component FSM states to aggregate FSM states - fsms_to_trans_finals: Dict[ - int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] - ] = {} - - i = 0 - while i < len(states): - state = states[i] - - # Add to the finals of the aggregate FSM whenever we hit a final in a - # component FSM - if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms): - finals.add(i) - - # Compute the map for this state - map[i] = {} - for transition in alphabet.by_transition: - try: - next = follow(state, transition) - except OblivionError: - # Reached an oblivion state; don't list it - continue - else: - try: - # TODO: Seems like this could--and should--be avoided - j = states.index(next) - except ValueError: - j = len(states) - states.append(next) - - map[i][transition] = j - - for fsm_id, fsm_state in next.items(): - ( - fsm_transitions, - fsm_finals, - fsm_old_to_new, - ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {})) - old_from = state[fsm_id] - old_to = fsm_state - fsm_old_to_new.setdefault(old_from, set()).add(i) - fsm_old_to_new.setdefault(old_to, set()).add(j) - fsm_transitions.add((i, j)) - if fsm_state in fsms[fsm_id].finals: - fsm_finals.add(j) - - i += 1 - - fsm = FSM( - alphabet=alphabet, - states=range(len(states)), - initial=0, - finals=finals, - map=map, - __no_validation__=True, - ) - - fsm, old_to_new_states = make_deterministic_fsm(fsm) - _fsms_to_trans_finals = { - fsm_id: ( - {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions}, - {old_to_new_states[s] for s in finals}, - { - old_state: {old_to_new_states[new_state] for new_state in new_states} - for old_state, new_states in old_to_new.items() - }, - ) - for fsm_id, (transitions, finals, old_to_new) in sorted( - fsms_to_trans_finals.items(), key=lambda x: x[0] - ) - } - - return ( - fsm, - _fsms_to_trans_finals, - ) - - -def get_sub_fsms_from_seq( - state_seq: Sequence[int], - fsms_to_trans_finals: Dict[ - int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] - ], -) -> Generator[Tuple[int, bool, bool], None, None]: - """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`. - - Parameters - ---------- - state_seq - A state sequence. - fsms_to_trans_finals - A map from FSM indices to tuples containing sets of their state transitions - and sets of the final/accept states. - - Returns - ------- - A generator returning tuples containing each sub-FSM index (in the order - they were union-ed to construct `fsm`) and booleans indicating whether or - not there is another valid transition from the last state in the sequence - for the associated sub-FSM (i.e. if the FSM can continue - accepting/matching) and whether or not the sequence ends in a final state - of the sub-FSM. - """ - state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:])) - last_fsm_state = state_seq[-1] - yield from ( - ( - # The sub-FMS index - fsm_idx, - # Is there another possible transition in this sub-FSM? - any(last_fsm_state == from_s for (from_s, to_s) in transitions), - # Is this sub-FSM in a final state? - state_seq[-1] in finals, - ) - for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items() - if state_seq_transitions.issubset(transitions) - ) - - re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") # The "▁*" prefix is required to handle Gemma and GPT-SW3 tokenizers, and the "\.*" diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py index 7b0018bb..b711ba48 100644 --- a/tests/fsm/test_regex.py +++ b/tests/fsm/test_regex.py @@ -2,18 +2,16 @@ import numpy as np import pytest from outlines_core.fsm.regex import ( + BetterAlphabet, + BetterFSM, _walk_fsm, create_fsm_index_end_to_end, create_fsm_index_tokenizer, - fsm_union, - get_sub_fsms_from_seq, get_token_transition_keys, get_vocabulary_transition_keys, - make_byte_level_better_fsm, make_byte_level_fsm, make_deterministic_fsm, reduced_vocabulary, - walk_fsm, ) from outlines_core.integrations.utils import adapt_tokenizer from outlines_core.models.transformers import TransformerTokenizer @@ -40,20 +38,6 @@ def token_str_to_trans_key(fsm, input_string): ) -def walk_fsm_from_token_str( - fsm, - input_string: str, - start_state: int, - full_match: bool = True, -): - return walk_fsm( - fsm, - token_str_to_trans_key(fsm, input_string), - start_state, - full_match, - ) - - def walk_fsm_from_token_str_rust( fsm, input_string: str, @@ -70,63 +54,82 @@ def walk_fsm_from_token_str_rust( ) -@pytest.mark.parametrize( - "function", - [ - walk_fsm_from_token_str, - walk_fsm_from_token_str_rust, - ], -) -def test_walk_fsm(function): +def make_byte_level_better_fsm(fsm: BetterFSM, keep_utf8=False) -> BetterFSM: + new_fsm = make_byte_level_fsm(fsm, keep_utf8) + return BetterFSM( + alphabet=BetterAlphabet(new_fsm.alphabet._symbol_mapping), + states=new_fsm.states, + initial=new_fsm.initial, + finals=new_fsm.finals, + map=new_fsm.map, + ) + + +def test_walk_fsm(): regex_pattern = interegular.parse_pattern("0|[1-9][2-9]*") regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - res = tuple(function(regex_fsm, "0", regex_fsm.initial, full_match=True)) + res = tuple( + walk_fsm_from_token_str_rust(regex_fsm, "0", regex_fsm.initial, full_match=True) + ) assert res == (1,) - res = tuple(function(regex_fsm, "00", regex_fsm.initial, full_match=False)) + res = tuple( + walk_fsm_from_token_str_rust( + regex_fsm, "00", regex_fsm.initial, full_match=False + ) + ) assert res == (1,) - res = tuple(function(regex_fsm, "!", regex_fsm.initial, full_match=True)) + res = tuple( + walk_fsm_from_token_str_rust(regex_fsm, "!", regex_fsm.initial, full_match=True) + ) assert res == tuple() - res = tuple(function(regex_fsm, "00", regex_fsm.initial, full_match=True)) + res = tuple( + walk_fsm_from_token_str_rust( + regex_fsm, "00", regex_fsm.initial, full_match=True + ) + ) assert res == tuple() # This should fail, because state `1` reads nothing - res = tuple(function(regex_fsm, "0", 1, full_match=True)) + res = tuple(walk_fsm_from_token_str_rust(regex_fsm, "0", 1, full_match=True)) assert res == tuple() regex_pattern = interegular.parse_pattern("0|[1-9][2-9]+") regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) - res = tuple(function(regex_fsm, "1", regex_fsm.initial, full_match=True)) + res = tuple( + walk_fsm_from_token_str_rust(regex_fsm, "1", regex_fsm.initial, full_match=True) + ) assert res == tuple() - res = tuple(function(regex_fsm, "1", regex_fsm.initial, full_match=False)) + res = tuple( + walk_fsm_from_token_str_rust( + regex_fsm, "1", regex_fsm.initial, full_match=False + ) + ) assert res == (2,) - res = tuple(function(regex_fsm, "12", regex_fsm.initial, full_match=True)) + res = tuple( + walk_fsm_from_token_str_rust( + regex_fsm, "12", regex_fsm.initial, full_match=True + ) + ) assert res == (2, 3) pattern = interegular.parse_pattern(r"(?:[^\W\d]\w*|[\t \x0c]+)") fsm, _ = make_deterministic_fsm(pattern.to_fsm().reduce()) - res = tuple(function(fsm, "x ", fsm.initial, full_match=False)) + res = tuple(walk_fsm_from_token_str_rust(fsm, "x ", fsm.initial, full_match=False)) assert res == (2,) start_state = list(fsm.finals)[0] - res = tuple(function(fsm, "!", start_state, full_match=False)) + res = tuple(walk_fsm_from_token_str_rust(fsm, "!", start_state, full_match=False)) assert res == tuple() -@pytest.mark.parametrize( - "function", - [ - walk_fsm_from_token_str, - walk_fsm_from_token_str_rust, - ], -) @pytest.mark.parametrize( "transform", [ @@ -134,20 +137,20 @@ def test_walk_fsm(function): to_bytes, ], ) -def test_walk_fsm_multi_bytes(function, transform): +def test_walk_fsm_multi_bytes(transform): regex_pattern = interegular.parse_pattern("😂|[😇-😍][😈-😍]*") str_regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) regex_fsm = make_byte_level_better_fsm(str_regex_fsm, keep_utf8=True) res = tuple( - function( + walk_fsm_from_token_str_rust( regex_fsm, merge_symbols(transform("😂")), regex_fsm.initial, full_match=True ) ) assert res[-1:] == (1,) res = tuple( - function( + walk_fsm_from_token_str_rust( regex_fsm, merge_symbols(transform("😂😂")), regex_fsm.initial, @@ -157,14 +160,14 @@ def test_walk_fsm_multi_bytes(function, transform): assert res[-1:] == (1,) res = tuple( - function( + walk_fsm_from_token_str_rust( regex_fsm, merge_symbols(transform("!")), regex_fsm.initial, full_match=True ) ) assert res == tuple() res = tuple( - function( + walk_fsm_from_token_str_rust( regex_fsm, merge_symbols(transform("😂😂")), regex_fsm.initial, @@ -174,161 +177,6 @@ def test_walk_fsm_multi_bytes(function, transform): assert res == tuple() -def test_get_sub_fsms_from_seq(): - name_pattern = interegular.parse_pattern(r"[^\W\d]\w*") - name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce()) - - def_pattern = interegular.parse_pattern("def") - def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce()) - - match_pattern = interegular.parse_pattern("match") - match_fsm, _ = make_deterministic_fsm(match_pattern.to_fsm().reduce()) - - peq_pattern = interegular.parse_pattern(r"\+=") - peq_fsm, _ = make_deterministic_fsm(peq_pattern.to_fsm().reduce()) - - plus_pattern = interegular.parse_pattern(r"\+") - plus_fsm, _ = make_deterministic_fsm(plus_pattern.to_fsm().reduce()) - - fsms = [def_fsm, match_fsm, name_fsm, peq_fsm, plus_fsm] - - fsm, fsms_to_trans_finals = fsm_union(fsms) - - assert fsms_to_trans_finals == { - 0: ({(0, 3), (3, 9), (9, 10)}, {10}, {0: {0}, 1: {3}, 2: {9}, 3: {10}}), - 1: ( - {(0, 4), (4, 5), (5, 6), (6, 7), (7, 8)}, - {8}, - {0: {0}, 1: {4}, 2: {5}, 3: {6}, 4: {7}, 5: {8}}, - ), - 2: ( - { - (0, 2), - (0, 3), - (0, 4), - (2, 2), - (3, 2), - (3, 9), - (4, 2), - (4, 5), - (5, 2), - (5, 6), - (6, 2), - (6, 7), - (7, 2), - (7, 8), - (8, 2), - (9, 2), - (9, 10), - (10, 2), - }, - {2, 3, 4, 5, 6, 7, 8, 9, 10}, - {0: {0}, 1: {2, 3, 4, 5, 6, 7, 8, 9, 10}}, - ), - 3: ({(0, 1), (1, 11)}, {11}, {0: {0}, 1: {1}, 2: {11}}), - 4: ({(0, 1)}, {1}, {0: {0}, 1: {1}}), - } - - assert not fsm.accepts("1a") - assert fsm.accepts("a1") - assert fsm.accepts("def") - assert fsm.accepts("match") - assert fsm.accepts("+=") - assert fsm.accepts("+") - - state_seq = walk_fsm_from_token_str(fsm, "def", fsm.initial) - state_seq.insert(0, fsm.fsm_info.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, False, True), (2, True, True)] - - # Make sure the old-to-new state map is correct - def_state_seq = walk_fsm_from_token_str(def_fsm, "def", fsm.initial) - def_state_seq.insert(0, fsm.fsm_info.initial) - - def_old_to_new_states = fsms_to_trans_finals[0][2] - assert all( - new_state in def_old_to_new_states[old_state] - for old_state, new_state in zip(def_state_seq, state_seq) - ) - - state_seq = walk_fsm_from_token_str(fsm, "ef", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(2, True, True)] - - name_state_seq = walk_fsm_from_token_str(name_fsm, "ef", fsm.initial) - name_state_seq.insert(0, fsm.initial) - - name_old_to_new_states = fsms_to_trans_finals[2][2] - assert all( - new_state in name_old_to_new_states[old_state] - for old_state, new_state in zip(name_state_seq, state_seq) - ) - - state_seq = walk_fsm_from_token_str(fsm, "match", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(1, False, True), (2, True, True)] - - match_state_seq = walk_fsm_from_token_str(match_fsm, "match", fsm.initial) - match_state_seq.insert(0, fsm.initial) - - match_old_to_new_states = fsms_to_trans_finals[1][2] - assert all( - new_state in match_old_to_new_states[old_state] - for old_state, new_state in zip(match_state_seq, state_seq) - ) - - state_seq = walk_fsm_from_token_str(fsm, "defa", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(2, True, True)] - - state_seq = walk_fsm_from_token_str(fsm, "de", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, True, False), (2, True, True)] - - state_seq = walk_fsm_from_token_str(fsm, "+", fsm.initial, False) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(3, True, False), (4, False, True)] - - state_seq = walk_fsm_from_token_str(fsm, "+=", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(3, False, True)] - - # Test some overlapping patterns - join_fsms = [ - interegular.parse_pattern(r"JOIN").to_fsm().reduce(), - interegular.parse_pattern(r"JOIN LEFT").to_fsm().reduce(), - ] - fsm, fsms_to_trans_finals = fsm_union(join_fsms) - - # Matching "OI" - state_seq = [1, 2, 3] - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, True, False), (1, True, False)] - - # Matching "N" - state_seq = [3, 4] - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, False, True), (1, True, False)] - - # Matching " " - state_seq = [4, 5] - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(1, True, False)] - - def test_create_fsm_index_end_to_end(): regex_str = "0|[1-9][0-9]*"