Skip to content

Commit

Permalink
Initial to_fsm logic for core regex elements
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh authored and brandonwillard committed Sep 24, 2024
1 parent af9221a commit 729e922
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 7 deletions.
3 changes: 2 additions & 1 deletion python/outlines_core/fsm/outlines_core_rs.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple

class FSMInfo:
initial: int
Expand Down Expand Up @@ -52,6 +52,7 @@ def create_fsm_index_end_to_end(
vocabulary: List[Tuple[str, List[int]]],
frozen_tokens: frozenset[str],
) -> Dict[int, Dict[int, int]]: ...
# Stub for the pattern parser provided by the native `outlines_core_rs` extension module.
# NOTE(review): the return type is `Any`; presumably the binding returns an opaque
# parsed-pattern object -- confirm against the Rust binding before tightening it.
def parse_pattern(pattern: str) -> Any: ...

BOOLEAN: str
DATE: str
Expand Down
23 changes: 20 additions & 3 deletions src/interegular/fsm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,19 @@ impl From<TransitionKey> for usize {
}
}

impl From<TransitionKey> for u32 {
    /// Converts a symbol key into its `u32` index.
    ///
    /// # Panics
    ///
    /// Panics if `c` is not a `TransitionKey::Symbol`, or if the symbol index does not
    /// fit in a `u32` (a plain `as` cast would silently truncate it instead).
    fn from(c: TransitionKey) -> Self {
        match c {
            TransitionKey::Symbol(i) => {
                u32::try_from(i).expect("TransitionKey symbol index does not fit in u32")
            }
            _ => panic!("Cannot convert `anything else` to u32"),
        }
    }
}

/// Marker trait bundling the bounds a type needs to serve as an alphabet symbol:
/// equality, hashing, cloning, debug formatting and construction from a `char`.
pub trait SymbolTrait: Eq + Hash + Clone + Debug + From<char> {}

/// Blanket implementation: any type meeting the bounds is a `SymbolTrait`.
impl<T> SymbolTrait for T where T: Eq + Hash + Clone + Debug + From<char> {}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct Alphabet<T: SymbolTrait> {
pub symbol_mapping: HashMap<T, TransitionKey>,
pub by_transition: HashMap<TransitionKey, Vec<T>>,
Expand All @@ -49,6 +58,14 @@ impl<T: SymbolTrait> Alphabet<T> {
}
}

#[must_use]
pub fn empty() -> Self {
Alphabet {
symbol_mapping: HashMap::new(),
by_transition: HashMap::new(),
}
}

pub fn get(&self, item: &T) -> TransitionKey {
match self.symbol_mapping.get(item) {
Some(x) => *x,
Expand Down Expand Up @@ -119,9 +136,9 @@ impl<T: SymbolTrait> Alphabet<T> {
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct Fsm<T: SymbolTrait> {
alphabet: Alphabet<T>,
pub alphabet: Alphabet<T>,
pub states: HashSet<TransitionKey>,
pub initial: TransitionKey,
pub finals: HashSet<TransitionKey>,
Expand Down
168 changes: 165 additions & 3 deletions src/interegular/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::rc::Rc;
use std::vec;

use crate::interegular::fsm::SymbolTrait;
use crate::interegular::fsm::TransitionKey;
use crate::interegular::fsm::{Alphabet, Fsm};

const SPECIAL_CHARS_INNER: [&str; 2] = ["\\", "]"];
Expand Down Expand Up @@ -136,6 +137,34 @@ impl RegexElement {
flags: Option<HashSet<Flag>>,
) -> Fsm<char> {
match self {
RegexElement::Literal(c) => {
let alphabet = alphabet
.unwrap_or_else(|| self.get_alphabet(&flags.clone().unwrap_or_default()));
let prefix_postfix = prefix_postfix.unwrap_or_else(|| self.get_prefix_postfix());

let case_insensitive = flags
.clone()
.as_ref()
.map_or(false, |f| f.contains(&Flag::CaseInsensitive));

let mut mapping = HashMap::<_, HashMap<_, _>>::new();
let symbol = alphabet.get(c);

let mut m = std::collections::HashMap::new();
m.insert(symbol, TransitionKey::Symbol(1_usize));
mapping.insert(TransitionKey::Symbol(0_usize), m);

let states = (0..=1).map(std::convert::Into::into).collect();
let finals = (1..=1).map(std::convert::Into::into).collect();

Fsm::new(
alphabet,
states, // {0, 1}
0.into(),
finals, // {1}
mapping,
)
}
RegexElement::CharGroup { chars, inverted } => {
let alphabet = alphabet
.unwrap_or_else(|| self.get_alphabet(&flags.clone().unwrap_or_default()));
Expand All @@ -151,7 +180,41 @@ impl RegexElement {
.as_ref()
.map_or(false, |f| f.contains(&Flag::CaseInsensitive));

let mapping = HashMap::<_, HashMap<_, _>>::new();
let mut mapping = HashMap::<_, HashMap<_, _>>::new();

// Every transition of a character group leaves state 0 and enters state 1; the two
// branches only differ in which alphabet symbols are allowed to take that step.
if *inverted {
    // Inverted group: allow every alphabet symbol that is NOT listed in `chars`.
    // Membership is tested on the symbols themselves. Casting a char's code point to a
    // `TransitionKey` is wrong because keys are alphabet indices, so the listed
    // characters would never be excluded.
    let m = alphabet
        .symbol_mapping
        .iter()
        .filter(|(symbol, _)| !chars.contains(*symbol))
        .map(|(_, key)| (*key, TransitionKey::Symbol(1_usize)))
        .collect::<HashMap<_, _>>();
    mapping.insert(TransitionKey::Symbol(0_usize), m);
} else {
    // Accumulate into the single row for state 0. Inserting a fresh one-entry map per
    // character would overwrite the previous row and keep only the last character.
    for symbol in chars.iter() {
        mapping
            .entry(TransitionKey::Symbol(0_usize))
            .or_default()
            .insert(alphabet.get(symbol), TransitionKey::Symbol(1_usize));
    }
}

let states = (0..=1).map(std::convert::Into::into).collect();
let finals = (1..=1).map(std::convert::Into::into).collect();
Expand All @@ -164,7 +227,62 @@ impl RegexElement {
mapping,
)
}
// Implement other variants as needed
RegexElement::Repeated { element, min, max } => {
let unit = element.to_fsm(alphabet.clone(), None, flags.clone());
let alphabet = alphabet
.unwrap_or_else(|| self.get_alphabet(&flags.clone().unwrap_or_default()));
let mandatory = std::iter::repeat(unit.clone()).take(*min).fold(
Fsm::new(
// TODO: fix if alphabet is None
alphabet.clone(),
HashSet::new(),
0.into(),
HashSet::new(),
std::collections::HashMap::new(),
),
|acc, f| Fsm::concatenate(&[acc, f]),
);

let optional = if max.is_none() {
unit.star()
} else {
let mut optional = unit.clone();
optional.finals.insert(optional.initial);
optional = std::iter::repeat(optional.clone())
.take(max.unwrap() - min)
.fold(
Fsm::new(
alphabet.clone(),
HashSet::new(),
0.into(),
HashSet::new(),
std::collections::HashMap::new(),
),
|acc, f| Fsm::concatenate(&[acc, f]),
);

optional
};

Fsm::concatenate(&[mandatory, optional])
}
RegexElement::Concatenation(parts) => {
let mut current = vec![];
for part in parts {
current.push(part.to_fsm(alphabet.clone(), None, flags.clone()));
}

Fsm::concatenate(&current)
}
RegexElement::Alternation(options) => {
let mut current = vec![];
for option in options {
current.push(option.to_fsm(alphabet.clone(), None, flags.clone()));
}

Fsm::union(&current)
}
// Panic on variants that are not implemented yet
_ => unimplemented!("FSM conversion not implemented for this variant"),
}
}
Expand All @@ -186,6 +304,25 @@ impl RegexElement {
Alphabet::from_groups(&[relevant, HashSet::from(['\0'.into()])])
}
RegexElement::Literal(c) => Alphabet::from_groups(&[HashSet::from([(*c).into()])]),
RegexElement::Repeated { element, .. } => element.get_alphabet(flags),
RegexElement::Alternation(options) => {
let mut alphabet = Alphabet::empty();
for option in options {
let alphabets = vec![alphabet, option.get_alphabet(flags)];
let (res, new_to_old) = Alphabet::union(alphabets.as_slice());
alphabet = res;
}
alphabet
}
RegexElement::Concatenation(parts) => {
let mut alphabet = Alphabet::empty();
for part in parts {
let alphabets = vec![alphabet, part.get_alphabet(flags)];
let (res, new_to_old) = Alphabet::union(alphabets.as_slice());
alphabet = res;
}
alphabet
}
_ => unimplemented!("Alphabet not implemented for this variant"),
}
}
Expand Down Expand Up @@ -685,7 +822,7 @@ mod tests {

#[test]
fn test_parse_pattern_simple() {
let pattern = "a";
let pattern: &str = "a";
let result = parse_pattern(pattern);
assert_eq!(
result,
Expand Down Expand Up @@ -960,4 +1097,29 @@ mod tests {
let result = parse_pattern(pattern);
assert!(result.is_err());
}
/// The single-literal pattern `a` must convert to a two-state FSM: state 0 is initial,
/// state 1 is the only final state, and the sole transition goes from 0 to 1 on the
/// key assigned to `'a'`.
#[test]
fn test_parse_pattern_simple_to_fsm() {
    let fsm = parse_pattern("a").unwrap().to_fsm(None, None, None);

    // States and alphabet keys share the `TransitionKey` type; name them separately.
    let start = TransitionKey::Symbol(0);
    let accept = TransitionKey::Symbol(1);
    let key_a = TransitionKey::Symbol(0);

    let alphabet = Alphabet {
        symbol_mapping: HashMap::from([('a', key_a)]),
        by_transition: HashMap::from([(key_a, vec!['a'])]),
    };
    let map = HashMap::from([
        (start, HashMap::from([(key_a, accept)])),
        (accept, HashMap::new()),
    ]);
    let expected = Fsm {
        alphabet,
        states: HashSet::from([start, accept]),
        initial: start,
        finals: HashSet::from([accept]),
        map,
    };

    assert_eq!(fsm, expected);
}
}
48 changes: 48 additions & 0 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::interegular::fsm::Fsm;
use crate::interegular::patterns::parse_pattern;
use crate::interegular::patterns::RegexElement;
use crate::json_schema;
Expand Down Expand Up @@ -456,6 +457,50 @@ pub fn parse_pattern_internal(py: Python, pattern: &str) -> PyResult<PyObject> {
}
}

// Python-visible, plain-integer view of an `Fsm`, as built by
// `parse_pattern_to_fsm_internal`. Every `TransitionKey` is flattened to a `u32`, and
// each field is read-only from Python through `#[pyo3(get)]`.
// (Regular `//` comments are used on purpose so the class `__doc__` is left unchanged.)
#[pyclass]
pub struct InteregularFSMInfo {
    // Initial state of the automaton.
    #[pyo3(get)]
    initial: u32,
    // Set of final states.
    #[pyo3(get)]
    finals: HashSet<u32>,
    // Set of all states.
    #[pyo3(get)]
    states: HashSet<u32>,
    // Transition table copied from `Fsm::map`.
    // NOTE(review): presumably state -> (alphabet key -> next state); confirm against `Fsm`.
    #[pyo3(get)]
    map: HashMap<u32, HashMap<u32, u32>>,
}

// Parses `pattern`, converts the resulting element to an FSM using the default alphabet,
// prefix/postfix and flags (all `None`), and flattens every `TransitionKey` to `u32`.
//
// Returns a `ValueError("Invalid pattern")` when parsing fails. The `u32` conversions
// panic on keys that are not `TransitionKey::Symbol`.
#[pyfunction(name = "parse_pattern_to_fsm")]
#[pyo3(text_signature = "(pattern: &str)")]
pub fn parse_pattern_to_fsm_internal(py: Python, pattern: &str) -> PyResult<InteregularFSMInfo> {
    let fsm = match parse_pattern(pattern) {
        Ok(element) => element.to_fsm(None, None, None),
        Err(_) => return Err(PyValueError::new_err("Invalid pattern")),
    };

    // Rebuild the transition table row by row with integer keys and values.
    let mut map: HashMap<u32, HashMap<u32, u32>> = HashMap::new();
    for (state, transitions) in &fsm.map {
        let state_id = u32::from(*state);
        let row = transitions
            .iter()
            .map(|(symbol, target)| (u32::from(*symbol), u32::from(*target)))
            .collect();
        map.insert(state_id, row);
    }

    let initial = u32::from(fsm.initial);
    let finals = fsm.finals.iter().copied().map(u32::from).collect();
    let states = fsm.states.iter().copied().map(u32::from).collect();

    Ok(InteregularFSMInfo {
        initial,
        finals,
        states,
        map,
    })
}

#[pymodule]
fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(walk_fsm_py, m)?)?;
Expand All @@ -474,6 +519,9 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyAnchor>()?;
m.add_class::<PyFlag>()?;

m.add_function(wrap_pyfunction!(parse_pattern_to_fsm_internal, m)?)?;
m.add_class::<InteregularFSMInfo>()?;

m.add_class::<FSMInfo>()?;

m.add("BOOLEAN", json_schema::BOOLEAN)?;
Expand Down

0 comments on commit 729e922

Please sign in to comment.