diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi index f970980b..7933de90 100644 --- a/python/outlines_core/fsm/outlines_core_rs.pyi +++ b/python/outlines_core/fsm/outlines_core_rs.pyi @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple class FSMInfo: initial: int @@ -52,6 +52,7 @@ def create_fsm_index_end_to_end( vocabulary: List[Tuple[str, List[int]]], frozen_tokens: frozenset[str], ) -> Dict[int, Dict[int, int]]: ... +def parse_pattern(pattern: str) -> Any: ... BOOLEAN: str DATE: str diff --git a/src/interegular/fsm.rs b/src/interegular/fsm.rs index f217bc64..0371e45d 100644 --- a/src/interegular/fsm.rs +++ b/src/interegular/fsm.rs @@ -25,10 +25,19 @@ impl From for usize { } } +impl From for u32 { + fn from(c: TransitionKey) -> Self { + match c { + TransitionKey::Symbol(i) => i as u32, + _ => panic!("Cannot convert `anything else` to u32"), + } + } +} + pub trait SymbolTrait: Eq + Hash + Clone + Debug + From {} impl> SymbolTrait for T {} -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Alphabet { pub symbol_mapping: HashMap, pub by_transition: HashMap>, @@ -49,6 +58,14 @@ impl Alphabet { } } + #[must_use] + pub fn empty() -> Self { + Alphabet { + symbol_mapping: HashMap::new(), + by_transition: HashMap::new(), + } + } + pub fn get(&self, item: &T) -> TransitionKey { match self.symbol_mapping.get(item) { Some(x) => *x, @@ -119,9 +136,9 @@ impl Alphabet { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Fsm { - alphabet: Alphabet, + pub alphabet: Alphabet, pub states: HashSet, pub initial: TransitionKey, pub finals: HashSet, diff --git a/src/interegular/patterns.rs b/src/interegular/patterns.rs index 5f5840cb..8e97604f 100644 --- a/src/interegular/patterns.rs +++ b/src/interegular/patterns.rs @@ -6,6 +6,7 @@ use std::rc::Rc; use std::vec; use crate::interegular::fsm::SymbolTrait; +use crate::interegular::fsm::TransitionKey; use crate::interegular::fsm::{Alphabet, Fsm}; const SPECIAL_CHARS_INNER: [&str; 2] = ["\\", "]"]; @@ -136,6 +137,34 @@ impl RegexElement { flags: Option>, ) -> Fsm { match self { + RegexElement::Literal(c) => { + let alphabet = alphabet + .unwrap_or_else(|| self.get_alphabet(&flags.clone().unwrap_or_default())); + let prefix_postfix = prefix_postfix.unwrap_or_else(|| self.get_prefix_postfix()); + + let case_insensitive = flags + .clone() + .as_ref() + .map_or(false, |f| f.contains(&Flag::CaseInsensitive)); + + let mut mapping = HashMap::<_, HashMap<_, _>>::new(); + let symbol = alphabet.get(c); + + let mut m = std::collections::HashMap::new(); + m.insert(symbol, TransitionKey::Symbol(1_usize)); + mapping.insert(TransitionKey::Symbol(0_usize), m); + + let states = (0..=1).map(std::convert::Into::into).collect(); + let finals = (1..=1).map(std::convert::Into::into).collect(); + + Fsm::new( + alphabet, + states, // {0, 1} + 0.into(), + finals, // {1} + mapping, + ) + } RegexElement::CharGroup { chars, inverted } => { let alphabet = alphabet .unwrap_or_else(|| self.get_alphabet(&flags.clone().unwrap_or_default())); @@ -151,7 +180,41 @@ impl RegexElement { .as_ref() .map_or(false, |f| f.contains(&Flag::CaseInsensitive)); - let mapping = HashMap::<_, HashMap<_, _>>::new(); + let mut mapping = HashMap::<_, HashMap<_, _>>::new(); + + if *inverted { + let chars = chars.clone(); + let alphabet = alphabet.clone(); + let alphabet_set = alphabet + .clone() + .by_transition + .keys() + .copied() + .collect::>(); + + let char_as_usize = chars + .iter() + .map(|c| TransitionKey::Symbol(*c as usize)) + .collect(); + let diff = alphabet_set + .difference(&char_as_usize) + .copied() + .collect::>(); + + let mut m = std::collections::HashMap::new(); + for symbol in diff { + m.insert(symbol, TransitionKey::Symbol(1_usize)); + } + mapping.insert(TransitionKey::Symbol(0_usize), m); + } else { + let chars = chars.clone(); + for symbol in chars { + let mut m = std::collections::HashMap::new(); + let symbol_value = alphabet.get(&symbol); + m.insert(symbol_value, TransitionKey::Symbol(1_usize)); + mapping.insert(TransitionKey::Symbol(0_usize), m); + } + } let states = (0..=1).map(std::convert::Into::into).collect(); let finals = (1..=1).map(std::convert::Into::into).collect(); @@ -164,7 +227,62 @@ impl RegexElement { mapping, ) } - // Implement other variants as needed + RegexElement::Repeated { element, min, max } => { + let unit = element.to_fsm(alphabet.clone(), None, flags.clone()); + let alphabet = alphabet + .unwrap_or_else(|| self.get_alphabet(&flags.clone().unwrap_or_default())); + let mandatory = std::iter::repeat(unit.clone()).take(*min).fold( + Fsm::new( + // TODO: fix if alphabet is None + alphabet.clone(), + HashSet::new(), + 0.into(), + HashSet::new(), + std::collections::HashMap::new(), + ), + |acc, f| Fsm::concatenate(&[acc, f]), + ); + + let optional = if max.is_none() { + unit.star() + } else { + let mut optional = unit.clone(); + optional.finals.insert(optional.initial); + optional = std::iter::repeat(optional.clone()) + .take(max.unwrap() - min) + .fold( + Fsm::new( + alphabet.clone(), + HashSet::new(), + 0.into(), + HashSet::new(), + std::collections::HashMap::new(), + ), + |acc, f| Fsm::concatenate(&[acc, f]), + ); + + optional + }; + + Fsm::concatenate(&[mandatory, optional]) + } + RegexElement::Concatenation(parts) => { + let mut current = vec![]; + for part in parts { + current.push(part.to_fsm(alphabet.clone(), None, flags.clone())); + } + + Fsm::concatenate(¤t) + } + RegexElement::Alternation(options) => { + let mut current = vec![]; + for option in options { + current.push(option.to_fsm(alphabet.clone(), None, flags.clone())); + } + + Fsm::union(¤t) + } + // throw on non implemented variants _ => unimplemented!("FSM conversion not implemented for this variant"), } } @@ -186,6 +304,25 @@ impl RegexElement { Alphabet::from_groups(&[relevant, HashSet::from(['\0'.into()])]) } RegexElement::Literal(c) => Alphabet::from_groups(&[HashSet::from([(*c).into()])]), + RegexElement::Repeated { element, .. } => element.get_alphabet(flags), + RegexElement::Alternation(options) => { + let mut alphabet = Alphabet::empty(); + for option in options { + let alphabets = vec![alphabet, option.get_alphabet(flags)]; + let (res, new_to_old) = Alphabet::union(alphabets.as_slice()); + alphabet = res; + } + alphabet + } + RegexElement::Concatenation(parts) => { + let mut alphabet = Alphabet::empty(); + for part in parts { + let alphabets = vec![alphabet, part.get_alphabet(flags)]; + let (res, new_to_old) = Alphabet::union(alphabets.as_slice()); + alphabet = res; + } + alphabet + } _ => unimplemented!("Alphabet not implemented for this variant"), } } @@ -685,7 +822,7 @@ mod tests { #[test] fn test_parse_pattern_simple() { - let pattern = "a"; + let pattern: &str = "a"; let result = parse_pattern(pattern); assert_eq!( result, @@ -960,4 +1097,29 @@ mod tests { let result = parse_pattern(pattern); assert!(result.is_err()); } + #[test] + fn test_parse_pattern_simple_to_fsm() { + let pattern: &str = "a"; + let result = parse_pattern(pattern).unwrap(); + let result = result.to_fsm(None, None, None); + + let expected = Fsm { + alphabet: Alphabet { + symbol_mapping: HashMap::from([('a', TransitionKey::Symbol(0))]), + by_transition: HashMap::from([(TransitionKey::Symbol(0), vec!['a'])]), + }, + states: HashSet::from([TransitionKey::Symbol(0), TransitionKey::Symbol(1)]), + initial: TransitionKey::Symbol(0), + finals: HashSet::from([TransitionKey::Symbol(1)]), + map: HashMap::from([ + ( + TransitionKey::Symbol(0), + HashMap::from([(TransitionKey::Symbol(0), TransitionKey::Symbol(1))]), + ), + (TransitionKey::Symbol(1), HashMap::new()), + ]), + }; + + assert_eq!(result, expected); + } } diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 16600b6b..05c7eee9 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -1,3 +1,4 @@ +use crate::interegular::fsm::Fsm; use crate::interegular::patterns::parse_pattern; use crate::interegular::patterns::RegexElement; use crate::json_schema; @@ -456,6 +457,50 @@ pub fn parse_pattern_internal(py: Python, pattern: &str) -> PyResult { } } +#[pyclass] +pub struct InteregularFSMInfo { + #[pyo3(get)] + initial: u32, + #[pyo3(get)] + finals: HashSet, + #[pyo3(get)] + states: HashSet, + #[pyo3(get)] + map: HashMap>, +} + +#[pyfunction(name = "parse_pattern_to_fsm")] +#[pyo3(text_signature = "(pattern: &str)")] +pub fn parse_pattern_to_fsm_internal(py: Python, pattern: &str) -> PyResult { + let regex_element = + parse_pattern(pattern).map_err(|_| PyValueError::new_err("Invalid pattern"))?; + + let alphabet = None; + let prefix_postfix = None; + let flags = None; + + let fsm_info = regex_element.to_fsm(alphabet, prefix_postfix, flags); + let map: HashMap> = fsm_info + .map + .iter() + .map(|(key, map)| { + let u32_key = u32::from(*key); + let map_as_u32s = map + .iter() + .map(|(key, value)| (u32::from(*key), u32::from(*value))) + .collect(); + (u32_key, map_as_u32s) + }) + .collect(); + + Ok(InteregularFSMInfo { + initial: fsm_info.initial.into(), + finals: fsm_info.finals.iter().map(|f| (*f).into()).collect(), + states: fsm_info.states.iter().map(|s| (*s).into()).collect(), + map, + }) +} + #[pymodule] fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(walk_fsm_py, m)?)?; @@ -474,6 +519,9 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_function(wrap_pyfunction!(parse_pattern_to_fsm_internal, m)?)?; + m.add_class::()?; + m.add_class::()?; m.add("BOOLEAN", json_schema::BOOLEAN)?;