Skip to content

Commit

Permalink
utilities for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikKaum authored and brandonwillard committed Oct 1, 2024
1 parent 58914c9 commit 063ea6b
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 17 deletions.
33 changes: 26 additions & 7 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,10 @@ pub struct InteregularFSMInfo {
states: HashSet<u32>,
#[pyo3(get)]
map: HashMap<u32, HashMap<u32, u32>>,
#[pyo3(get)]
symbol_mapping: HashMap<char, usize>,
#[pyo3(get)]
by_transition: HashMap<usize, Vec<char>>,
}

use crate::interegular::fsm::Alphabet;
Expand All @@ -494,7 +498,7 @@ use crate::interegular::patterns::Flag;

#[pyfunction(name = "parse_pattern_to_fsm")]
#[pyo3(text_signature = "(pattern: &str)")]
pub fn parse_pattern_to_fsm_internal(py: Python, pattern: &str) -> PyResult<InteregularFSMInfo> {
pub fn parse_pattern_to_fsm_internal(pattern: &str) -> PyResult<InteregularFSMInfo> {
let regex_element =
parse_pattern(pattern).map_err(|_| PyValueError::new_err("Invalid pattern"))?;

Expand All @@ -512,14 +516,15 @@ pub fn parse_pattern_to_fsm_internal(py: Python, pattern: &str) -> PyResult<Inte
my_new_symbol_mapping.insert('\0', TransitionKey::Symbol(0)); // add \0 as the anything symbol at 0

let mut counter = 1;
for (symbol, transition_key) in patterns_alphabet.symbol_mapping.iter() {
let transition_key_inc_by_one = TransitionKey::Symbol(counter);
my_new_symbol_mapping.insert(symbol.clone(), transition_key_inc_by_one);
counter += 1;
for (symbol, _) in patterns_alphabet.symbol_mapping.iter() {
if *symbol != '\0' {
my_new_symbol_mapping.insert(*symbol, TransitionKey::Symbol(counter));
counter += 1;
}
}
let alphabet = Alphabet::new(my_new_symbol_mapping);

let fsm_info = regex_element.to_fsm(Some(alphabet), prefix_postfix, flags);
let alphabet = Alphabet::new(my_new_symbol_mapping);
let fsm_info = regex_element.to_fsm(Some(alphabet.clone()), prefix_postfix, flags);

// convert into u32 for python
let map: HashMap<u32, HashMap<u32, u32>> = fsm_info
Expand All @@ -535,11 +540,25 @@ pub fn parse_pattern_to_fsm_internal(py: Python, pattern: &str) -> PyResult<Inte
})
.collect();

let python_symbol_mapping: HashMap<char, usize> = alphabet
.symbol_mapping
.iter()
.map(|(k, v)| (*k, (*v).into()))
.collect();

let python_by_transition: HashMap<usize, Vec<char>> = alphabet
.by_transition
.iter()
.map(|(k, v)| (usize::from(*k), v.iter().map(|&c| c).collect()))
.collect();

Ok(InteregularFSMInfo {
initial: fsm_info.initial.into(),
finals: fsm_info.finals.iter().map(|f| (*f).into()).collect(),
states: fsm_info.states.iter().map(|s| (*s).into()).collect(),
map,
symbol_mapping: python_symbol_mapping,
by_transition: python_by_transition,
})
}

Expand Down
92 changes: 82 additions & 10 deletions tests/interegular/test_parse_pattern_to_fsm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,66 @@
# TODO: THIS IS A WORK IN PROGRESS AND WILL BE COMPLETELY REFACTORED BEFORE MERGING
from interegular.fsm import anything_else
from outlines_core.fsm.regex import parse_pattern_to_fsm

import interegular


class InteregularFSMInfo:
def __init__(self, initial, finals, states, map, symbol_mapping, by_transition):
self.initial = initial
self.finals = finals
self.states = states
self.map = map
self.symbol_mapping = symbol_mapping
self.by_transition = by_transition


def map_states_with_symbols(state_map, symbol_mapping):
inv_symbol_mapping = {v: k for k, v in symbol_mapping.items()}

mapped_states = {}
for state, transitions in state_map.items():
mapped_transitions = {}
for symbol, next_state in transitions.items():
mapped_symbol = inv_symbol_mapping.get(symbol, symbol)
mapped_transitions[mapped_symbol] = next_state
mapped_states[state] = mapped_transitions

return mapped_states


def make_fsm_comparable(fsm):
# Create a new symbol mapping
new_symbol_mapping = {}
for symbol, value in fsm.symbol_mapping.items():
if symbol == "\x00":
new_symbol_mapping[anything_else] = value
else:
new_symbol_mapping[symbol] = value

# Create a new map
new_map = {}
for state, transitions in fsm.map.items():
new_transitions = {}
for symbol, next_state in transitions.items():
if symbol == b"\x00":
new_transitions[anything_else] = next_state
else:
new_transitions[symbol] = next_state
new_map[state] = new_transitions

new_fsm = InteregularFSMInfo(
states=fsm.states,
initial=fsm.initial,
finals=fsm.finals,
map=new_map,
symbol_mapping=new_symbol_mapping,
by_transition=fsm.by_transition,
)

return new_fsm


def compare_sets(set1, set2):
# ensure that the sets are equal
return frozenset(set1) == frozenset(set2)
Expand All @@ -18,6 +75,7 @@ def sort_map(map):

def test_parse_pattern_to_fsm(pattern):
fsm = parse_pattern_to_fsm(pattern)
fsm = make_fsm_comparable(fsm)

ref_pattern = interegular.parse_pattern(pattern)

Expand Down Expand Up @@ -47,19 +105,28 @@ def test_parse_pattern_to_fsm(pattern):
# assert fsm.finals == ref_fsm.finals
# assert fsm.map == ref_fsm.map

# make maps deterministic (sort by key)
fsm_map = sort_map(fsm.map)
ref_map = sort_map(ref_fsm.map)

equal_states = frozenset(fsm.states) == frozenset(ref_fsm.states)
equal_initial = fsm.initial == ref_fsm.initial
equal_finals = frozenset(fsm.finals) == frozenset(ref_fsm.finals)
# equal_map = fsm.map == ref_fsm.map
equal_map = map_states_with_symbols(
fsm.map, fsm.symbol_mapping
) == map_states_with_symbols(ref_fsm.map, ref_fsm.alphabet._symbol_mapping)

print()
if equal_states and equal_initial and equal_finals: # and equal_map:
if equal_states and equal_initial and equal_finals and equal_map:
print(f"✅ Test passed for pattern: {pattern}")
else:
print(f"❌ Test failed for pattern: {pattern}")

print("_symbol_mapping\n", ref_fsm.alphabet._symbol_mapping)
print("by_transition\n", ref_fsm.alphabet.by_transition)
print("fsm: symbol_mapping\n", fsm.symbol_mapping)
print("fsm: by_transition\n", fsm.by_transition)

print("ref: symbol_mapping\n", ref_fsm.alphabet._symbol_mapping)
print("ref: by_transition\n", ref_fsm.alphabet.by_transition)

print("States")
print(f" fsm: {frozenset(fsm.states)}")
Expand All @@ -75,24 +142,29 @@ def test_parse_pattern_to_fsm(pattern):

print("Map")

# make maps deterministic (sort by key)
fsm_map = sort_map(fsm.map)
ref_map = sort_map(ref_fsm.map)

print(f" fsm: {fsm_map}")
print(f" ref: {ref_map}")

print("Map with symbols")
fsm_map_with_symbols = map_states_with_symbols(fsm_map, fsm.symbol_mapping)
print(f" fsm: {sort_map(fsm_map_with_symbols)}")

ref_map_with_symbols = map_states_with_symbols(
ref_map, ref_fsm.alphabet._symbol_mapping
)
print(f" ref: {sort_map(ref_map_with_symbols)}")

return True


# TODO: remove if not needed
# tests copied so they can be run as a standalone script
if __name__ == "__main__":
test_cases = [
"a",
# "a",
# "ab",
# "a|b",
# "[ab]",
"[ab]",
# TODO: long simple patterns (should work)
# "aaaaa",
# "davidholtz",
Expand Down

0 comments on commit 063ea6b

Please sign in to comment.