Cleaner from_regex logic
torymur committed Dec 13, 2024
1 parent ac33a29 commit 52b1093
Showing 3 changed files with 54 additions and 58 deletions.
2 changes: 1 addition & 1 deletion src/error.rs
@@ -12,7 +12,7 @@ pub enum Error {
     #[error("Failed to build DFA {0}")]
     IndexDfaError(#[from] Box<regex_automata::dfa::dense::BuildError>),
     #[error("Index failed since anchored universal start state doesn't exist")]
-    IndexNoAnchoredUniversalStartState,
+    DfaHasNoStartState,
     #[error(transparent)]
     TokenizersError(#[from] tokenizers::Error),
     #[error("Unsupported tokenizer for {model}: {reason}, please open an issue with the full error message: https://github.com/dottxt-ai/outlines-core/issues")]
108 changes: 52 additions & 56 deletions src/index.rs
@@ -7,24 +7,24 @@ use bincode::{Decode, Encode};
 use regex_automata::dfa::{dense::DFA, Automaton};
 use regex_automata::util::primitives::StateID as AutomataStateId;
 use regex_automata::Anchored;
-use rustc_hash::{FxHashMap, FxHashSet};
+use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};

 #[derive(Debug)]
 pub struct FSMInfo {
     pub(crate) initial: State,
-    pub(crate) finals: FxHashSet<State>,
-    pub(crate) transitions: FxHashMap<(State, TransitionKey), State>,
+    pub(crate) finals: HashSet<State>,
+    pub(crate) transitions: HashMap<(State, TransitionKey), State>,
     pub(crate) alphabet_anything_value: TransitionKey,
-    pub(crate) alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
+    pub(crate) alphabet_symbol_mapping: HashMap<String, TransitionKey>,
 }

 impl FSMInfo {
     pub fn new(
         initial: State,
-        finals: FxHashSet<State>,
-        transitions: FxHashMap<(State, TransitionKey), State>,
+        finals: HashSet<State>,
+        transitions: HashMap<(State, TransitionKey), State>,
         alphabet_anything_value: TransitionKey,
-        alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
+        alphabet_symbol_mapping: HashMap<String, TransitionKey>,
     ) -> Self {
         Self {
             initial,
@@ -39,8 +39,8 @@ impl FSMInfo {
 #[derive(Debug, Encode, Decode)]
 pub struct Index {
     initial: u32,
-    finals: FxHashSet<u32>,
-    states_to_token_subsets: FxHashMap<u32, FxHashMap<u32, u32>>,
+    finals: HashSet<u32>,
+    states_to_token_subsets: HashMap<u32, HashMap<u32, u32>>,
     eos_token_id: u32,
 }

@@ -49,11 +49,11 @@ impl Index {
         fsm_info: &FSMInfo,
         vocabulary: &Vocabulary,
         eos_token_id: u32,
-        frozen_tokens: FxHashSet<String>,
+        frozen_tokens: HashSet<String>,
     ) -> Result<Self> {
-        let mut states_to_token_subsets: FxHashMap<u32, FxHashMap<u32, u32>> = FxHashMap::default();
-        let mut seen: FxHashSet<State> = FxHashSet::default();
-        let mut next_states: FxHashSet<State> = FxHashSet::from_iter([fsm_info.initial]);
+        let mut states_to_token_subsets: HashMap<u32, HashMap<u32, u32>> = HashMap::default();
+        let mut seen: HashSet<State> = HashSet::default();
+        let mut next_states: HashSet<State> = HashSet::from_iter([fsm_info.initial]);

         let vocabulary_transition_keys = get_vocabulary_transition_keys(
             &fsm_info.alphabet_symbol_mapping,
@@ -111,85 +111,77 @@ impl Index {
     pub(crate) fn from_regex(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
         let eos_token_id = match vocabulary.eos_token_id() {
             Some(s) => s,
-            // TODO: this error will be removed once eos_token_id for vocabulary won't be optional
             None => return Err(Error::IndexEosTokenIdNotAvailable),
         };

-        let dfa = DFA::builder().build(regex).map_err(Box::new)?;
+        let dfa = DFA::new(regex).map_err(Box::new)?;
         let start_state = match dfa.universal_start_state(Anchored::Yes) {
             Some(s) => s,
-            None => return Err(Error::IndexNoAnchoredUniversalStartState),
+            None => return Err(Error::DfaHasNoStartState),
         };

-        let mut index: FxHashMap<State, FxHashMap<TokenId, State>> = FxHashMap::default();
-        let mut seen: FxHashSet<AutomataStateId> = FxHashSet::default();
-        let mut final_states: FxHashSet<State> = FxHashSet::default();
-        let mut next_states: FxHashSet<AutomataStateId> = FxHashSet::from_iter([start_state]);
+        let mut transitions: HashMap<State, HashMap<TokenId, State>> = HashMap::default();
+        let mut final_states: HashSet<State> = HashSet::default();
+
+        let mut seen: HashSet<AutomataStateId> = HashSet::from_iter([start_state]);
+        let mut next_states: Vec<AutomataStateId> = vec![start_state];

-        while let Some(start_state) = next_states.iter().cloned().next() {
-            next_states.remove(&start_state);
-            seen.insert(start_state);
-
-            if dfa.is_match_state(dfa.next_eoi_state(start_state)) {
-                final_states.insert(start_state.as_u32());
+        while let Some(current_state) = next_states.pop() {
+            if dfa.is_match_state(dfa.next_eoi_state(current_state)) {
+                final_states.insert(current_state.as_u32());
             }

             'token_loop: for (token, ids) in vocabulary.tokens_to_ids().iter() {
                 if ids.contains(&eos_token_id) {
                     continue;
                 }

-                let mut next_state = start_state;
+                let mut next_state = current_state;
                 for transition_byte in token.as_bytes() {
                     next_state = dfa.next_state(next_state, *transition_byte);
                     if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) {
                         continue 'token_loop;
                     }
                 }

-                if dfa.is_match_state(next_state) {
-                    // Token either matched or matched except the last character.
-                    // Check what happens if the input suddenly ends after reaching this state.
-                    // If the automata still matches, then token is exactly matched, if not
-                    // then token didn't match.
-                    let next_eoi_state = dfa.next_eoi_state(next_state);
-                    let token_matched = dfa.is_match_state(next_eoi_state);
-                    if !token_matched {
-                        continue;
+                let is_intermediate_state = !dfa.is_match_state(next_state);
+                let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state));
+                if is_intermediate_state || is_full_match_state {
+                    for token_id in ids {
+                        transitions
+                            .entry(current_state.as_u32())
+                            .or_default()
+                            .insert(*token_id, next_state.as_u32());
                     }
                 }

-                for token_id in ids {
-                    let mapping = index.entry(start_state.as_u32()).or_default();
-                    mapping.insert(*token_id, next_state.as_u32());
-
-                    if !seen.contains(&next_state) {
-                        next_states.insert(next_state);
-                    }
+                if !seen.contains(&next_state) {
+                    seen.insert(next_state);
+                    next_states.push(next_state);
                 }
             }
         }

-        let start_state = start_state.as_u32();
-
-        // Populate `index` with mappings from `final_states` to `eos_token_id`
+        // Populate `transitions` with mappings from `final_states` to `eos_token_id`
         for &final_state in &final_states {
-            index
+            transitions
                 .entry(final_state)
                 .or_default()
                 .insert(eos_token_id, final_state);
         }

         // Check if there is at least one valid mapping
-        let is_valid = index.values().any(|mapping| {
+        let is_valid = transitions.values().any(|mapping| {
             mapping
                 .values()
                 .any(|end_state| final_states.contains(end_state))
         });

         if is_valid {
             Ok(Self {
-                initial: start_state,
+                initial: start_state.as_u32(),
                 finals: final_states,
-                states_to_token_subsets: index,
+                states_to_token_subsets: transitions,
                 eos_token_id,
             })
         } else {
@@ -218,11 +210,11 @@ impl Index {
         self.finals.contains(&state)
     }

-    pub(crate) fn final_states(&self) -> &FxHashSet<State> {
+    pub(crate) fn final_states(&self) -> &HashSet<State> {
         &self.finals
     }

-    pub(crate) fn transitions(&self) -> &FxHashMap<u32, FxHashMap<u32, u32>> {
+    pub(crate) fn transitions(&self) -> &HashMap<u32, HashMap<u32, u32>> {
         &self.states_to_token_subsets
     }
 }
@@ -243,10 +235,14 @@ mod tests {

         let index = Index::from_regex(regex, &vocabulary).expect("Index failed");
         assert_eq!(index.initial(), 40);
-        assert_eq!(index.final_states(), &FxHashSet::from_iter([24, 48, 56]));
-        assert_eq!(
-            "{24: {3: 24, 4: 24, 2: 24}, 48: {4: 48}, 40: {3: 48, 2: 56}, 56: {3: 24, 4: 56, 2: 24}}",
-            format!("{:?}", index.transitions())
-        );
+        assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));
+
+        let expected: HashMap<u32, HashMap<u32, u32>> = HashMap::from_iter([
+            (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])),
+            (48, HashMap::from_iter([(4, 48)])),
+            (40, HashMap::from_iter([(3, 48), (2, 56)])),
+            (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])),
+        ]);
+        assert_eq!(&expected, index.transitions());
     }
 }
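
The heart of the reworked `from_regex` loop is the per-token DFA walk: from the current state, feed the token's bytes one at a time, skip the token as soon as a dead or quit state is hit, and then probe `next_eoi_state` to separate full matches from intermediate, still-extendable states. Below is a minimal standalone sketch of that walk, assuming regex_automata 0.4 with default features; the pattern and tokens are illustrative only and are not part of this commit.

use regex_automata::dfa::{dense::DFA, Automaton};
use regex_automata::Anchored;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Dense DFA for an illustrative pattern (a quoted lowercase word);
    // `Index::from_regex` builds its DFA the same way via `DFA::new`.
    let dfa = DFA::new(r#""[a-z]+""#)?;
    let start = dfa
        .universal_start_state(Anchored::Yes)
        .ok_or("anchored universal start state doesn't exist")?;

    'token_loop: for token in ["\"ab", "\"ab\"", "\"ab\"x", "7"] {
        // Feed the token byte by byte; reaching a dead or quit state means
        // this token can never lead to a match, so it is skipped outright.
        let mut state = start;
        for &byte in token.as_bytes() {
            state = dfa.next_state(state, byte);
            if dfa.is_dead_state(state) || dfa.is_quit_state(state) {
                println!("{token:?}: dead end, skipped");
                continue 'token_loop;
            }
        }
        // `next_eoi_state` asks: would this be a match if input ended here?
        let is_intermediate = !dfa.is_match_state(state);
        let is_full_match = dfa.is_match_state(dfa.next_eoi_state(state));
        if is_intermediate || is_full_match {
            println!("{token:?}: kept (intermediate: {is_intermediate}, full match: {is_full_match})");
        } else {
            println!("{token:?}: prefix-only match, skipped");
        }
    }
    Ok(())
}

Note that dense DFAs report matches with a one-byte delay, which is why the end-of-input probe via `next_eoi_state`, rather than `is_match_state` on the final state alone, decides whether a token is a full match: the "\"ab\"x" token above lands in a match state yet fails the probe, exactly the prefix-only situation the removed `token_matched` check guarded against.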
2 changes: 1 addition & 1 deletion src/python_bindings/mod.rs
@@ -10,8 +10,8 @@ use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::types::PyDict;
 use pyo3::wrap_pyfunction;
-use serde_json::Value;
 use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
+use serde_json::Value;

 #[pyclass(name = "FSMInfo")]
 pub struct PyFSMInfo {