Skip to content

Commit

Permalink
Align very simple FSM creation and return FSM to Python
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh authored and brandonwillard committed Sep 24, 2024
1 parent 729e922 commit f8d1b37
Show file tree
Hide file tree
Showing 6 changed files with 513 additions and 22 deletions.
1 change: 1 addition & 0 deletions python/outlines_core/fsm/outlines_core_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def create_fsm_index_end_to_end(
frozen_tokens: frozenset[str],
) -> Dict[int, Dict[int, int]]: ...
def parse_pattern(pattern: str) -> Any: ...
def parse_pattern_to_fsm(pattern: str) -> Any: ...

BOOLEAN: str
DATE: str
Expand Down
3 changes: 2 additions & 1 deletion python/outlines_core/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
anything_else,
)

from .outlines_core_rs import ( # noqa: F401
from .outlines_core_rs import ( # noqa: F401; TODO: likely temporary; just to ensure that the fsm creation works
FSMInfo,
_walk_fsm,
create_fsm_index_end_to_end,
get_token_transition_keys,
get_vocabulary_transition_keys,
parse_pattern,
parse_pattern_to_fsm,
state_scan_tokens,
)

Expand Down
227 changes: 216 additions & 11 deletions src/interegular/fsm.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use core::panic;
use std::collections::{HashMap, HashSet, VecDeque};
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
use std::fmt::Debug;
use std::hash::Hash;
use std::iter::from_fn;

#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, Ord, PartialOrd)]
pub enum TransitionKey {
Symbol(usize),
AnythingElse,
// TODO: avoid using AnythingElse in favor of Symbol(0) or char '\0'
// This is only due to the incomplete implementation of the conversion from char to TransitionKey
// AnythingElse,
}

impl From<usize> for TransitionKey {
Expand All @@ -20,7 +21,6 @@ impl From<TransitionKey> for usize {
fn from(c: TransitionKey) -> Self {
match c {
TransitionKey::Symbol(i) => i,
_ => panic!("Cannot convert `anything else` to usize"),
}
}
}
Expand All @@ -29,7 +29,6 @@ impl From<TransitionKey> for u32 {
fn from(c: TransitionKey) -> Self {
match c {
TransitionKey::Symbol(i) => i as u32,
_ => panic!("Cannot convert `anything else` to u32"),
}
}
}
Expand Down Expand Up @@ -69,7 +68,7 @@ impl<T: SymbolTrait> Alphabet<T> {
pub fn get(&self, item: &T) -> TransitionKey {
match self.symbol_mapping.get(item) {
Some(x) => *x,
None => TransitionKey::AnythingElse,
None => TransitionKey::Symbol(0),
}
}

Expand Down Expand Up @@ -102,7 +101,7 @@ impl<T: SymbolTrait> Alphabet<T> {
);
}

let mut keys_to_symbols = HashMap::new();
let mut keys_to_symbols = BTreeMap::new(); // btree keeps the order
for (symbol, keys) in symbol_to_keys {
keys_to_symbols
.entry(keys.clone())
Expand Down Expand Up @@ -136,6 +135,15 @@ impl<T: SymbolTrait> Alphabet<T> {
}
}

impl Default for Alphabet<char> {
fn default() -> Self {
let mut symbol_mapping = HashMap::new();
// only insert \0 for anything_else
symbol_mapping.insert('\0', TransitionKey::Symbol(0));
Alphabet::new(symbol_mapping)
}
}

#[derive(Debug, Clone, PartialEq)]
pub struct Fsm<T: SymbolTrait> {
pub alphabet: Alphabet<T>,
Expand Down Expand Up @@ -339,7 +347,7 @@ impl<T: SymbolTrait> Fsm<T> {
while _current_i < last_index && fsms[_current_i].finals.contains(&current_substate) {
_current_i += 1;
current_substate = fsms[_current_i].initial;
result.insert((current_i, current_substate));
result.insert((_current_i.into(), current_substate));
}

result
Expand Down Expand Up @@ -587,7 +595,7 @@ where
F: Fn(&C) -> bool,
G: Fn(&C, &TransitionKey) -> Option<C>,
I: Clone + Eq + Hash + std::fmt::Debug,
C: IntoIterator<Item = I> + FromIterator<I> + Clone + PartialEq,
C: IntoIterator<Item = I> + FromIterator<I> + Clone + PartialEq + std::fmt::Debug,
{
let mut states = VecDeque::new();
states.push_back(initial);
Expand Down Expand Up @@ -639,6 +647,17 @@ where
mod tests {
use super::*;

#[test]
fn test_create_default_alphabet() {
let default_alphabet = Alphabet::<char>::default();
assert_eq!(default_alphabet.symbol_mapping.len(), 1);
assert_eq!(default_alphabet.by_transition.len(), 1);
assert_eq!(
default_alphabet.by_transition[&TransitionKey::Symbol(0)],
vec!['\0']
);
}

fn create_simple_fsm() -> Fsm<char> {
let mut symbol_mapping = HashMap::new();
symbol_mapping.insert('a', 0.into());
Expand Down Expand Up @@ -753,10 +772,196 @@ mod tests {

assert!(union.accepts(&['a']));
assert!(union.accepts(&['b']));
assert!(!union.accepts(&[' ']));
assert!(!union.accepts(&['a', 'a']));
}

#[test]
fn test_union_of_single_character_fsms() {
// Create alphabet for FSM1 ('a' and anything_else)
let mut symbol_mapping1 = HashMap::new();
symbol_mapping1.insert('\0', 0.into()); // '\0' represents anything_else
symbol_mapping1.insert('a', 1.into());
let alphabet1 = Alphabet::new(symbol_mapping1);

// Create alphabet for FSM2 ('b' and anything_else)
let mut symbol_mapping2 = HashMap::new();
symbol_mapping2.insert('\0', 0.into()); // '\0' represents anything_else
symbol_mapping2.insert('b', 1.into());
let alphabet2 = Alphabet::new(symbol_mapping2);

let fsm1 = Fsm::new(
alphabet1.clone(),
[0.into(), 1.into()].iter().copied().collect(),
0.into(),
[1.into()].iter().copied().collect(),
[
//
(0.into(), [(1.into(), 1.into())].iter().copied().collect()),
(1.into(), [].iter().copied().collect()),
]
.iter()
.cloned()
.collect(),
);

let fsm2 = Fsm::new(
alphabet2.clone(),
[0.into(), 1.into()].iter().copied().collect(),
0.into(),
[1.into()].iter().copied().collect(),
[
(0.into(), [(1.into(), 1.into())].iter().copied().collect()),
(1.into(), [].iter().copied().collect()),
]
.iter()
.cloned()
.collect(),
);

assert_eq!(
fsm1.map,
[
(0.into(), [(1.into(), 1.into()),].iter().copied().collect()),
(1.into(), [].iter().copied().collect()),
]
.iter()
.cloned()
.collect()
);

assert_eq!(
fsm2.map,
[
(0.into(), [(1.into(), 1.into()),].iter().copied().collect()),
(1.into(), [].iter().copied().collect()),
]
.iter()
.cloned()
.collect()
);

let union_fsm = Fsm::union(&[fsm1, fsm2]);

assert_eq!(union_fsm.alphabet.symbol_mapping.len(), 3);
assert_eq!(
union_fsm.states,
[0.into(), 1.into(), 2.into()].iter().copied().collect()
);
assert_eq!(union_fsm.initial, 0.into());
assert_eq!(
union_fsm.finals,
[2.into(), 1.into()].iter().copied().collect()
);

// compare states
assert_eq!(
union_fsm.states,
[0.into(), 1.into(), 2.into()].iter().copied().collect()
);

let expected_map: HashMap<TransitionKey, HashMap<TransitionKey, TransitionKey>> = [
(
0.into(),
[(1.into(), 1.into()), (2.into(), 2.into())]
.iter()
.copied()
.collect(),
),
(1.into(), [].iter().copied().collect()),
(2.into(), [].iter().copied().collect()),
]
.into();

assert_eq!(union_fsm.map.get(&2.into()), Some(&expected_map[&2.into()]));
assert_eq!(union_fsm.map.get(&1.into()), Some(&expected_map[&1.into()]));
}

#[test]
fn test_concatenate_of_single_character_fsms() {
// Create alphabet for FSM1 ('a' and anything_else)
let mut symbol_mapping1 = HashMap::new();
symbol_mapping1.insert('\0', 0.into()); // '\0' represents anything_else
symbol_mapping1.insert('a', 1.into());
let alphabet1 = Alphabet::new(symbol_mapping1);

// Create alphabet for FSM2 ('b' and anything_else)
let mut symbol_mapping2 = HashMap::new();
symbol_mapping2.insert('\0', 0.into()); // '\0' represents anything_else
symbol_mapping2.insert('b', 1.into());
let alphabet2 = Alphabet::new(symbol_mapping2);

// Create FSM for "a"
let fsm1 = Fsm::new(
alphabet1.clone(),
[0.into(), 1.into()].iter().copied().collect(),
0.into(),
[1.into()].iter().copied().collect(),
[
(0.into(), [(1.into(), 1.into())].iter().copied().collect()),
(1.into(), [].iter().copied().collect()),
]
.iter()
.cloned()
.collect(),
);

// Create FSM for "b"
let fsm2 = Fsm::new(
alphabet2.clone(),
[0.into(), 1.into()].iter().copied().collect(),
0.into(),
[1.into()].iter().copied().collect(),
[
(0.into(), [(1.into(), 1.into())].iter().copied().collect()),
(1.into(), [].iter().copied().collect()),
]
.iter()
.cloned()
.collect(),
);

let concat_fsm = Fsm::concatenate(&[fsm1, fsm2]);

let expected = Fsm {
alphabet: Alphabet {
symbol_mapping: HashMap::from([
('\0', TransitionKey::Symbol(0)),
('a', TransitionKey::Symbol(1)),
('b', TransitionKey::Symbol(2)),
]),
by_transition: HashMap::from([
(TransitionKey::Symbol(0), vec!['\0']),
(TransitionKey::Symbol(1), vec!['a']),
(TransitionKey::Symbol(2), vec!['b']),
]),
},
states: HashSet::from([
TransitionKey::Symbol(0),
TransitionKey::Symbol(1),
TransitionKey::Symbol(2),
]),
initial: TransitionKey::Symbol(0),
finals: HashSet::from([TransitionKey::Symbol(2)]),
map: HashMap::from([
(
TransitionKey::Symbol(0),
HashMap::from([(TransitionKey::Symbol(1), TransitionKey::Symbol(1))]),
),
(
TransitionKey::Symbol(1),
HashMap::from([(TransitionKey::Symbol(2), TransitionKey::Symbol(2))]),
),
(TransitionKey::Symbol(2), HashMap::new()),
]),
};

assert_eq!(concat_fsm.states, expected.states);
assert_eq!(concat_fsm.initial, expected.initial);
assert_eq!(concat_fsm.finals, expected.finals);
assert_eq!(concat_fsm.map.get(&2.into()), expected.map.get(&2.into()));
assert_eq!(concat_fsm.map.get(&1.into()), expected.map.get(&1.into()));
}

#[test]
fn test_intersection() {
let fsm1 = Fsm::new(
Expand Down
Loading

0 comments on commit f8d1b37

Please sign in to comment.