diff --git a/python/outlines_core/fsm/regex.py b/python/outlines_core/fsm/regex.py index 00169c72..9d4a03fe 100644 --- a/python/outlines_core/fsm/regex.py +++ b/python/outlines_core/fsm/regex.py @@ -4,7 +4,6 @@ TYPE_CHECKING, Dict, FrozenSet, - Generator, Iterable, List, Optional, @@ -18,7 +17,6 @@ from interegular.fsm import ( FSM, Alphabet, - OblivionError, State, TransitionKey, _AnythingElseCls, @@ -270,17 +268,6 @@ def create_seq_transitions( ) -def make_byte_level_better_fsm(fsm: BetterFSM, keep_utf8=False) -> BetterFSM: - new_fsm = make_byte_level_fsm(fsm, keep_utf8) - return BetterFSM( - alphabet=BetterAlphabet(new_fsm.alphabet._symbol_mapping), - states=new_fsm.states, - initial=new_fsm.initial, - finals=new_fsm.finals, - map=new_fsm.map, - ) - - def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: """Construct an equivalent FSM with deterministic state labels.""" old_to_new_trans_keys = { @@ -393,163 +380,6 @@ def walk_fsm( return accepted_states -def fsm_union( - fsms: Sequence[FSM], -) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]: - """Construct an FSM representing the union of the FSMs in `fsms`. - - This is an updated version of `interegular.fsm.FSM.union` made to return an - extra map of component FSMs to the sets of state transitions that - correspond to them in the new FSM. - - """ - - alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms]) - - indexed_fsms = tuple(enumerate(fsms)) - - initial = {i: fsm.initial for (i, fsm) in indexed_fsms} - - # Dedicated function accepting a "superset" and returning the next - # "superset" obtained by following this transition in the new FSM - def follow(current_state, new_transition: int): - next = {} - for i, f in indexed_fsms: - old_transition = new_to_old[i][new_transition] - if ( - i in current_state - and current_state[i] in f.map - and old_transition in f.map[current_state[i]] - ): - next[i] = f.map[current_state[i]][old_transition] - if not next: - raise OblivionError - return next - - states = [initial] - finals: Set[int] = set() - map: Dict[int, Dict[int, int]] = {} - - # Map component FSMs to their new state-to-state transitions, finals, and a - # map translating component FSM states to aggregate FSM states - fsms_to_trans_finals: Dict[ - int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] - ] = {} - - i = 0 - while i < len(states): - state = states[i] - - # Add to the finals of the aggregate FSM whenever we hit a final in a - # component FSM - if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms): - finals.add(i) - - # Compute the map for this state - map[i] = {} - for transition in alphabet.by_transition: - try: - next = follow(state, transition) - except OblivionError: - # Reached an oblivion state; don't list it - continue - else: - try: - # TODO: Seems like this could--and should--be avoided - j = states.index(next) - except ValueError: - j = len(states) - states.append(next) - - map[i][transition] = j - - for fsm_id, fsm_state in next.items(): - ( - fsm_transitions, - fsm_finals, - fsm_old_to_new, - ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {})) - old_from = state[fsm_id] - old_to = fsm_state - fsm_old_to_new.setdefault(old_from, set()).add(i) - fsm_old_to_new.setdefault(old_to, set()).add(j) - fsm_transitions.add((i, j)) - if fsm_state in fsms[fsm_id].finals: - fsm_finals.add(j) - - i += 1 - - fsm = FSM( - alphabet=alphabet, - states=range(len(states)), - initial=0, - finals=finals, - map=map, - __no_validation__=True, - ) - - fsm, old_to_new_states = make_deterministic_fsm(fsm) - _fsms_to_trans_finals = { - fsm_id: ( - {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions}, - {old_to_new_states[s] for s in finals}, - { - old_state: {old_to_new_states[new_state] for new_state in new_states} - for old_state, new_states in old_to_new.items() - }, - ) - for fsm_id, (transitions, finals, old_to_new) in sorted( - fsms_to_trans_finals.items(), key=lambda x: x[0] - ) - } - - return ( - fsm, - _fsms_to_trans_finals, - ) - - -def get_sub_fsms_from_seq( - state_seq: Sequence[int], - fsms_to_trans_finals: Dict[ - int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]] - ], -) -> Generator[Tuple[int, bool, bool], None, None]: - """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`. - - Parameters - ---------- - state_seq - A state sequence. - fsms_to_trans_finals - A map from FSM indices to tuples containing sets of their state transitions - and sets of the final/accept states. - - Returns - ------- - A generator returning tuples containing each sub-FSM index (in the order - they were union-ed to construct `fsm`) and booleans indicating whether or - not there is another valid transition from the last state in the sequence - for the associated sub-FSM (i.e. if the FSM can continue - accepting/matching) and whether or not the sequence ends in a final state - of the sub-FSM. - """ - state_seq_transitions = set(zip(state_seq[:-1], state_seq[1:])) - last_fsm_state = state_seq[-1] - yield from ( - ( - # The sub-FMS index - fsm_idx, - # Is there another possible transition in this sub-FSM? - any(last_fsm_state == from_s for (from_s, to_s) in transitions), - # Is this sub-FSM in a final state? - state_seq[-1] in finals, - ) - for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items() - if state_seq_transitions.issubset(transitions) - ) - - re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") # The "▁*" prefix is required to handle Gemma and GPT-SW3 tokenizers, and the "\.*" diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py index 7b0018bb..7c9ac8d6 100644 --- a/tests/fsm/test_regex.py +++ b/tests/fsm/test_regex.py @@ -2,14 +2,13 @@ import numpy as np import pytest from outlines_core.fsm.regex import ( + BetterAlphabet, + BetterFSM, _walk_fsm, create_fsm_index_end_to_end, create_fsm_index_tokenizer, - fsm_union, - get_sub_fsms_from_seq, get_token_transition_keys, get_vocabulary_transition_keys, - make_byte_level_better_fsm, make_byte_level_fsm, make_deterministic_fsm, reduced_vocabulary, @@ -70,6 +69,17 @@ def walk_fsm_from_token_str_rust( ) +def make_byte_level_better_fsm(fsm: BetterFSM, keep_utf8=False) -> BetterFSM: + new_fsm = make_byte_level_fsm(fsm, keep_utf8) + return BetterFSM( + alphabet=BetterAlphabet(new_fsm.alphabet._symbol_mapping), + states=new_fsm.states, + initial=new_fsm.initial, + finals=new_fsm.finals, + map=new_fsm.map, + ) + + @pytest.mark.parametrize( "function", [ @@ -174,161 +184,6 @@ def test_walk_fsm_multi_bytes(function, transform): assert res == tuple() -def test_get_sub_fsms_from_seq(): - name_pattern = interegular.parse_pattern(r"[^\W\d]\w*") - name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce()) - - def_pattern = interegular.parse_pattern("def") - def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce()) - - match_pattern = interegular.parse_pattern("match") - match_fsm, _ = make_deterministic_fsm(match_pattern.to_fsm().reduce()) - - peq_pattern = interegular.parse_pattern(r"\+=") - peq_fsm, _ = make_deterministic_fsm(peq_pattern.to_fsm().reduce()) - - plus_pattern = interegular.parse_pattern(r"\+") - plus_fsm, _ = make_deterministic_fsm(plus_pattern.to_fsm().reduce()) - - fsms = [def_fsm, match_fsm, name_fsm, peq_fsm, plus_fsm] - - fsm, fsms_to_trans_finals = fsm_union(fsms) - - assert fsms_to_trans_finals == { - 0: ({(0, 3), (3, 9), (9, 10)}, {10}, {0: {0}, 1: {3}, 2: {9}, 3: {10}}), - 1: ( - {(0, 4), (4, 5), (5, 6), (6, 7), (7, 8)}, - {8}, - {0: {0}, 1: {4}, 2: {5}, 3: {6}, 4: {7}, 5: {8}}, - ), - 2: ( - { - (0, 2), - (0, 3), - (0, 4), - (2, 2), - (3, 2), - (3, 9), - (4, 2), - (4, 5), - (5, 2), - (5, 6), - (6, 2), - (6, 7), - (7, 2), - (7, 8), - (8, 2), - (9, 2), - (9, 10), - (10, 2), - }, - {2, 3, 4, 5, 6, 7, 8, 9, 10}, - {0: {0}, 1: {2, 3, 4, 5, 6, 7, 8, 9, 10}}, - ), - 3: ({(0, 1), (1, 11)}, {11}, {0: {0}, 1: {1}, 2: {11}}), - 4: ({(0, 1)}, {1}, {0: {0}, 1: {1}}), - } - - assert not fsm.accepts("1a") - assert fsm.accepts("a1") - assert fsm.accepts("def") - assert fsm.accepts("match") - assert fsm.accepts("+=") - assert fsm.accepts("+") - - state_seq = walk_fsm_from_token_str(fsm, "def", fsm.initial) - state_seq.insert(0, fsm.fsm_info.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, False, True), (2, True, True)] - - # Make sure the old-to-new state map is correct - def_state_seq = walk_fsm_from_token_str(def_fsm, "def", fsm.initial) - def_state_seq.insert(0, fsm.fsm_info.initial) - - def_old_to_new_states = fsms_to_trans_finals[0][2] - assert all( - new_state in def_old_to_new_states[old_state] - for old_state, new_state in zip(def_state_seq, state_seq) - ) - - state_seq = walk_fsm_from_token_str(fsm, "ef", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(2, True, True)] - - name_state_seq = walk_fsm_from_token_str(name_fsm, "ef", fsm.initial) - name_state_seq.insert(0, fsm.initial) - - name_old_to_new_states = fsms_to_trans_finals[2][2] - assert all( - new_state in name_old_to_new_states[old_state] - for old_state, new_state in zip(name_state_seq, state_seq) - ) - - state_seq = walk_fsm_from_token_str(fsm, "match", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(1, False, True), (2, True, True)] - - match_state_seq = walk_fsm_from_token_str(match_fsm, "match", fsm.initial) - match_state_seq.insert(0, fsm.initial) - - match_old_to_new_states = fsms_to_trans_finals[1][2] - assert all( - new_state in match_old_to_new_states[old_state] - for old_state, new_state in zip(match_state_seq, state_seq) - ) - - state_seq = walk_fsm_from_token_str(fsm, "defa", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(2, True, True)] - - state_seq = walk_fsm_from_token_str(fsm, "de", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, True, False), (2, True, True)] - - state_seq = walk_fsm_from_token_str(fsm, "+", fsm.initial, False) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(3, True, False), (4, False, True)] - - state_seq = walk_fsm_from_token_str(fsm, "+=", fsm.initial) - state_seq.insert(0, fsm.initial) - - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(3, False, True)] - - # Test some overlapping patterns - join_fsms = [ - interegular.parse_pattern(r"JOIN").to_fsm().reduce(), - interegular.parse_pattern(r"JOIN LEFT").to_fsm().reduce(), - ] - fsm, fsms_to_trans_finals = fsm_union(join_fsms) - - # Matching "OI" - state_seq = [1, 2, 3] - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, True, False), (1, True, False)] - - # Matching "N" - state_seq = [3, 4] - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(0, False, True), (1, True, False)] - - # Matching " " - state_seq = [4, 5] - res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals)) - assert res == [(1, True, False)] - - def test_create_fsm_index_end_to_end(): regex_str = "0|[1-9][0-9]*"