Skip to content

Commit

Permalink
Remove unused functions in outlines_core.fsm.regex
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Aug 21, 2024
1 parent e66b111 commit 2262942
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 328 deletions.
170 changes: 0 additions & 170 deletions python/outlines_core/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
TYPE_CHECKING,
Dict,
FrozenSet,
Generator,
Iterable,
List,
Optional,
Expand All @@ -18,7 +17,6 @@
from interegular.fsm import (
FSM,
Alphabet,
OblivionError,
State,
TransitionKey,
_AnythingElseCls,
Expand Down Expand Up @@ -270,17 +268,6 @@ def create_seq_transitions(
)


def make_byte_level_better_fsm(fsm: BetterFSM, keep_utf8=False) -> BetterFSM:
    """Build a byte-level `BetterFSM` from `fsm`.

    The byte-level conversion itself is delegated to `make_byte_level_fsm`;
    the result is then rebuilt as a `BetterFSM` with its alphabet wrapped in
    a `BetterAlphabet`.
    """
    byte_fsm = make_byte_level_fsm(fsm, keep_utf8)
    better_alphabet = BetterAlphabet(byte_fsm.alphabet._symbol_mapping)
    return BetterFSM(
        alphabet=better_alphabet,
        states=byte_fsm.states,
        initial=byte_fsm.initial,
        finals=byte_fsm.finals,
        map=byte_fsm.map,
    )


def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
"""Construct an equivalent FSM with deterministic state labels."""
old_to_new_trans_keys = {
Expand Down Expand Up @@ -393,163 +380,6 @@ def walk_fsm(
return accepted_states


def fsm_union(
    fsms: Sequence[FSM],
) -> Tuple[FSM, Dict[int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]]]:
    """Construct an FSM representing the union of the FSMs in `fsms`.

    This is an updated version of `interegular.fsm.FSM.union` made to return an
    extra map of component FSMs to the sets of state transitions that
    correspond to them in the new FSM.

    Parameters
    ----------
    fsms
        The component FSMs to union.

    Returns
    -------
    A tuple of the union FSM (relabeled via `make_deterministic_fsm`) and a
    map from each component FSM's index in `fsms` to a tuple containing its
    state-to-state transitions in the union FSM, its final states in the
    union FSM, and a map from its original states to the corresponding
    union-FSM states.
    """

    # Combine the component alphabets; `new_to_old` translates each combined
    # transition key back to a component FSM's original transition key.
    alphabet, new_to_old = Alphabet.union(*[fsm.alphabet for fsm in fsms])

    indexed_fsms = tuple(enumerate(fsms))

    # An aggregate state is a dict mapping a component FSM's index to that
    # component's current state; components with no valid move are dropped.
    initial = {i: fsm.initial for (i, fsm) in indexed_fsms}

    # Dedicated function accepting a "superset" and returning the next
    # "superset" obtained by following this transition in the new FSM
    def follow(current_state, new_transition: int):
        next = {}
        for i, f in indexed_fsms:
            old_transition = new_to_old[i][new_transition]
            if (
                i in current_state
                and current_state[i] in f.map
                and old_transition in f.map[current_state[i]]
            ):
                next[i] = f.map[current_state[i]][old_transition]
        if not next:
            # No component FSM can follow this transition
            raise OblivionError
        return next

    states = [initial]
    finals: Set[int] = set()
    map: Dict[int, Dict[int, int]] = {}

    # Map component FSMs to their new state-to-state transitions, finals, and a
    # map translating component FSM states to aggregate FSM states
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ] = {}

    # Breadth-first exploration of the reachable aggregate states; `states`
    # grows as new aggregate states are discovered.
    i = 0
    while i < len(states):
        state = states[i]

        # Add to the finals of the aggregate FSM whenever we hit a final in a
        # component FSM
        if any(state.get(j, -1) in fsm.finals for (j, fsm) in indexed_fsms):
            finals.add(i)

        # Compute the map for this state
        map[i] = {}
        for transition in alphabet.by_transition:
            try:
                next = follow(state, transition)
            except OblivionError:
                # Reached an oblivion state; don't list it
                continue
            else:
                try:
                    # TODO: Seems like this could--and should--be avoided
                    j = states.index(next)
                except ValueError:
                    j = len(states)
                    states.append(next)

                map[i][transition] = j

                # Record, for each component FSM participating in this move,
                # the aggregate transition taken and which aggregate states
                # its old states correspond to
                for fsm_id, fsm_state in next.items():
                    (
                        fsm_transitions,
                        fsm_finals,
                        fsm_old_to_new,
                    ) = fsms_to_trans_finals.setdefault(fsm_id, (set(), set(), {}))
                    old_from = state[fsm_id]
                    old_to = fsm_state
                    fsm_old_to_new.setdefault(old_from, set()).add(i)
                    fsm_old_to_new.setdefault(old_to, set()).add(j)
                    fsm_transitions.add((i, j))
                    if fsm_state in fsms[fsm_id].finals:
                        fsm_finals.add(j)

        i += 1

    fsm = FSM(
        alphabet=alphabet,
        states=range(len(states)),
        initial=0,
        finals=finals,
        map=map,
        __no_validation__=True,
    )

    # Determinization relabels the states, so translate all recorded
    # transition/final/state-map information into the new labels.
    fsm, old_to_new_states = make_deterministic_fsm(fsm)
    _fsms_to_trans_finals = {
        fsm_id: (
            {(old_to_new_states[s1], old_to_new_states[s2]) for s1, s2 in transitions},
            {old_to_new_states[s] for s in finals},
            {
                old_state: {old_to_new_states[new_state] for new_state in new_states}
                for old_state, new_states in old_to_new.items()
            },
        )
        for fsm_id, (transitions, finals, old_to_new) in sorted(
            fsms_to_trans_finals.items(), key=lambda x: x[0]
        )
    }

    return (
        fsm,
        _fsms_to_trans_finals,
    )


def get_sub_fsms_from_seq(
    state_seq: Sequence[int],
    fsms_to_trans_finals: Dict[
        int, Tuple[Set[Tuple[int, int]], Set[int], Dict[int, Set[int]]]
    ],
) -> Generator[Tuple[int, bool, bool], None, None]:
    """Get the indices of the sub-FSMs in `fsm` that could have matched the state sequence `state_seq`.

    Parameters
    ----------
    state_seq
        A state sequence.
    fsms_to_trans_finals
        A map from FSM indices to tuples containing sets of their state transitions
        and sets of the final/accept states.

    Returns
    -------
    A generator returning tuples containing each sub-FSM index (in the order
    they were union-ed to construct `fsm`) and booleans indicating whether or
    not there is another valid transition from the last state in the sequence
    for the associated sub-FSM (i.e. if the FSM can continue
    accepting/matching) and whether or not the sequence ends in a final state
    of the sub-FSM.
    """
    observed_transitions = set(zip(state_seq[:-1], state_seq[1:]))
    last_state = state_seq[-1]

    for fsm_idx, (transitions, finals, _) in fsms_to_trans_finals.items():
        # Only sub-FSMs whose transition sets cover every observed
        # transition could have produced this sequence.
        if not observed_transitions.issubset(transitions):
            continue
        can_continue = any(src == last_state for src, _dst in transitions)
        ends_in_final = last_state in finals
        yield fsm_idx, can_continue, ends_in_final


re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")

# The "▁*" prefix is required to handle Gemma and GPT-SW3 tokenizers, and the "\.*"
Expand Down
171 changes: 13 additions & 158 deletions tests/fsm/test_regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import numpy as np
import pytest
from outlines_core.fsm.regex import (
BetterAlphabet,
BetterFSM,
_walk_fsm,
create_fsm_index_end_to_end,
create_fsm_index_tokenizer,
fsm_union,
get_sub_fsms_from_seq,
get_token_transition_keys,
get_vocabulary_transition_keys,
make_byte_level_better_fsm,
make_byte_level_fsm,
make_deterministic_fsm,
reduced_vocabulary,
Expand Down Expand Up @@ -70,6 +69,17 @@ def walk_fsm_from_token_str_rust(
)


def make_byte_level_better_fsm(fsm: BetterFSM, keep_utf8=False) -> BetterFSM:
    """Build a byte-level `BetterFSM` from `fsm`.

    Delegates the byte-level conversion to `make_byte_level_fsm`, then
    rebuilds the result as a `BetterFSM` whose alphabet is wrapped in a
    `BetterAlphabet`.
    """
    new_fsm = make_byte_level_fsm(fsm, keep_utf8)
    return BetterFSM(
        alphabet=BetterAlphabet(new_fsm.alphabet._symbol_mapping),
        states=new_fsm.states,
        initial=new_fsm.initial,
        finals=new_fsm.finals,
        map=new_fsm.map,
    )


@pytest.mark.parametrize(
"function",
[
Expand Down Expand Up @@ -174,161 +184,6 @@ def test_walk_fsm_multi_bytes(function, transform):
assert res == tuple()


def test_get_sub_fsms_from_seq():
    """Exercise `get_sub_fsms_from_seq` on a union of Python-token-like FSMs.

    Builds a union FSM out of keyword (`def`, `match`), identifier, `+=`,
    and `+` patterns, then checks which sub-FSMs could have matched various
    walked state sequences, and whether each can continue matching and/or
    ends in a final state.
    """
    name_pattern = interegular.parse_pattern(r"[^\W\d]\w*")
    name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce())

    def_pattern = interegular.parse_pattern("def")
    def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce())

    match_pattern = interegular.parse_pattern("match")
    match_fsm, _ = make_deterministic_fsm(match_pattern.to_fsm().reduce())

    peq_pattern = interegular.parse_pattern(r"\+=")
    peq_fsm, _ = make_deterministic_fsm(peq_pattern.to_fsm().reduce())

    plus_pattern = interegular.parse_pattern(r"\+")
    plus_fsm, _ = make_deterministic_fsm(plus_pattern.to_fsm().reduce())

    fsms = [def_fsm, match_fsm, name_fsm, peq_fsm, plus_fsm]

    fsm, fsms_to_trans_finals = fsm_union(fsms)

    # Expected (transitions, finals, old-to-new state map) triple for each
    # sub-FSM, keyed by its index in `fsms`
    assert fsms_to_trans_finals == {
        0: ({(0, 3), (3, 9), (9, 10)}, {10}, {0: {0}, 1: {3}, 2: {9}, 3: {10}}),
        1: (
            {(0, 4), (4, 5), (5, 6), (6, 7), (7, 8)},
            {8},
            {0: {0}, 1: {4}, 2: {5}, 3: {6}, 4: {7}, 5: {8}},
        ),
        2: (
            {
                (0, 2),
                (0, 3),
                (0, 4),
                (2, 2),
                (3, 2),
                (3, 9),
                (4, 2),
                (4, 5),
                (5, 2),
                (5, 6),
                (6, 2),
                (6, 7),
                (7, 2),
                (7, 8),
                (8, 2),
                (9, 2),
                (9, 10),
                (10, 2),
            },
            {2, 3, 4, 5, 6, 7, 8, 9, 10},
            {0: {0}, 1: {2, 3, 4, 5, 6, 7, 8, 9, 10}},
        ),
        3: ({(0, 1), (1, 11)}, {11}, {0: {0}, 1: {1}, 2: {11}}),
        4: ({(0, 1)}, {1}, {0: {0}, 1: {1}}),
    }

    assert not fsm.accepts("1a")
    assert fsm.accepts("a1")
    assert fsm.accepts("def")
    assert fsm.accepts("match")
    assert fsm.accepts("+=")
    assert fsm.accepts("+")

    # "def" is accepted by the `def` FSM (0, final) and the name FSM (2,
    # final but can keep matching)
    state_seq = walk_fsm_from_token_str(fsm, "def", fsm.initial)
    state_seq.insert(0, fsm.fsm_info.initial)

    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
    assert res == [(0, False, True), (2, True, True)]

    # Make sure the old-to-new state map is correct
    def_state_seq = walk_fsm_from_token_str(def_fsm, "def", fsm.initial)
    def_state_seq.insert(0, fsm.fsm_info.initial)

    def_old_to_new_states = fsms_to_trans_finals[0][2]
    assert all(
        new_state in def_old_to_new_states[old_state]
        for old_state, new_state in zip(def_state_seq, state_seq)
    )

    # "ef" only matches the name FSM
    state_seq = walk_fsm_from_token_str(fsm, "ef", fsm.initial)
    state_seq.insert(0, fsm.initial)

    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
    assert res == [(2, True, True)]

    name_state_seq = walk_fsm_from_token_str(name_fsm, "ef", fsm.initial)
    name_state_seq.insert(0, fsm.initial)

    name_old_to_new_states = fsms_to_trans_finals[2][2]
    assert all(
        new_state in name_old_to_new_states[old_state]
        for old_state, new_state in zip(name_state_seq, state_seq)
    )

    # "match" is accepted by the `match` FSM (1) and the name FSM (2)
    state_seq = walk_fsm_from_token_str(fsm, "match", fsm.initial)
    state_seq.insert(0, fsm.initial)

    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
    assert res == [(1, False, True), (2, True, True)]

    match_state_seq = walk_fsm_from_token_str(match_fsm, "match", fsm.initial)
    match_state_seq.insert(0, fsm.initial)

    match_old_to_new_states = fsms_to_trans_finals[1][2]
    assert all(
        new_state in match_old_to_new_states[old_state]
        for old_state, new_state in zip(match_state_seq, state_seq)
    )

    # "defa" has outgrown "def"; only the name FSM still matches
    state_seq = walk_fsm_from_token_str(fsm, "defa", fsm.initial)
    state_seq.insert(0, fsm.initial)

    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
    assert res == [(2, True, True)]

    # "de" is a name (final) and a proper prefix of "def" (non-final)
    state_seq = walk_fsm_from_token_str(fsm, "de", fsm.initial)
    state_seq.insert(0, fsm.initial)

    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
    assert res == [(0, True, False), (2, True, True)]

    # "+" is a complete `+` token (4) and a prefix of "+=" (3)
    state_seq = walk_fsm_from_token_str(fsm, "+", fsm.initial, False)
    state_seq.insert(0, fsm.initial)

    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
    assert res == [(3, True, False), (4, False, True)]

    state_seq = walk_fsm_from_token_str(fsm, "+=", fsm.initial)
    state_seq.insert(0, fsm.initial)

    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
    assert res == [(3, False, True)]

    # Test some overlapping patterns
    join_fsms = [
        interegular.parse_pattern(r"JOIN").to_fsm().reduce(),
        interegular.parse_pattern(r"JOIN LEFT").to_fsm().reduce(),
    ]
    fsm, fsms_to_trans_finals = fsm_union(join_fsms)

    # Matching "OI"
    state_seq = [1, 2, 3]
    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
    assert res == [(0, True, False), (1, True, False)]

    # Matching "N"
    state_seq = [3, 4]
    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
    assert res == [(0, False, True), (1, True, False)]

    # Matching " "
    state_seq = [4, 5]
    res = list(get_sub_fsms_from_seq(state_seq, fsms_to_trans_finals))
    assert res == [(1, True, False)]


def test_create_fsm_index_end_to_end():
regex_str = "0|[1-9][0-9]*"

Expand Down

0 comments on commit 2262942

Please sign in to comment.