diff --git a/src/lib.rs b/src/lib.rs index 0bc900d2..5811ff7a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,3 +3,7 @@ pub mod regex; #[cfg(feature = "python-bindings")] mod python_bindings; + +mod primitives; + +pub use crate::primitives::{State, TokenId, TransitionKey}; diff --git a/src/primitives.rs b/src/primitives.rs new file mode 100644 index 00000000..bbc77700 --- /dev/null +++ b/src/primitives.rs @@ -0,0 +1,8 @@ +/// Interegular transition key. +pub type TransitionKey = u32; + +/// Token identifier. +pub type TokenId = u32; + +/// Interegular state. +pub type State = u32; diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 22bebe62..251300f6 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -3,6 +3,7 @@ use crate::regex::get_token_transition_keys; use crate::regex::get_vocabulary_transition_keys; use crate::regex::state_scan_tokens; use crate::regex::walk_fsm; +use crate::*; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -13,26 +14,26 @@ use std::collections::{HashMap, HashSet}; #[pyclass] pub struct FSMInfo { #[pyo3(get)] - initial: u32, + initial: State, #[pyo3(get)] - finals: HashSet, + finals: HashSet, #[pyo3(get)] - transitions: HashMap<(u32, u32), u32>, + transitions: HashMap<(State, TransitionKey), State>, #[pyo3(get)] - alphabet_anything_value: u32, + alphabet_anything_value: TransitionKey, #[pyo3(get)] - alphabet_symbol_mapping: HashMap, + alphabet_symbol_mapping: HashMap, } #[pymethods] impl FSMInfo { #[new] fn new( - initial: u32, - finals: HashSet, - transitions: HashMap<(u32, u32), u32>, - alphabet_anything_value: u32, - alphabet_symbol_mapping: HashMap, + initial: State, + finals: HashSet, + transitions: HashMap<(State, TransitionKey), State>, + alphabet_anything_value: TransitionKey, + alphabet_symbol_mapping: HashMap, ) -> Self { Self { initial, @@ -67,13 +68,13 @@ pub fn to_regex_py(json: Bound, whitespace_pattern: Option<&str>) -> PyR text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)" )] pub fn walk_fsm_py( - fsm_transitions: HashMap<(u32, u32), u32>, - fsm_initial: u32, - fsm_finals: HashSet, - token_transition_keys: Vec, - start_state: u32, + fsm_transitions: HashMap<(State, TransitionKey), State>, + fsm_initial: State, + fsm_finals: HashSet, + token_transition_keys: Vec, + start_state: State, full_match: bool, -) -> PyResult> { +) -> PyResult> { Ok(walk_fsm( &fsm_transitions, fsm_initial, @@ -89,13 +90,13 @@ pub fn walk_fsm_py( text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)" )] pub fn state_scan_tokens_py( - fsm_transitions: HashMap<(u32, u32), u32>, - fsm_initial: u32, - fsm_finals: HashSet, - vocabulary: Vec<(String, Vec)>, - vocabulary_transition_keys: Vec>, - start_state: u32, -) -> PyResult> { + fsm_transitions: HashMap<(State, TransitionKey), State>, + fsm_initial: State, + fsm_finals: HashSet, + vocabulary: Vec<(String, Vec)>, + vocabulary_transition_keys: Vec>, + start_state: State, +) -> PyResult> { Ok(state_scan_tokens( &fsm_transitions, fsm_initial, @@ -109,10 +110,10 @@ pub fn state_scan_tokens_py( #[pyfunction(name = "get_token_transition_keys")] #[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")] pub fn get_token_transition_keys_py( - alphabet_symbol_mapping: HashMap, - alphabet_anything_value: u32, + alphabet_symbol_mapping: HashMap, + alphabet_anything_value: TransitionKey, token_str: String, -) -> PyResult> { +) -> PyResult> { Ok(get_token_transition_keys( &alphabet_symbol_mapping, alphabet_anything_value, @@ -125,11 +126,11 @@ pub fn get_token_transition_keys_py( text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)" )] pub fn get_vocabulary_transition_keys_py( - alphabet_symbol_mapping: HashMap, - alphabet_anything_value: u32, - vocabulary: Vec<(String, Vec)>, + alphabet_symbol_mapping: HashMap, + alphabet_anything_value: TransitionKey, + vocabulary: Vec<(String, Vec)>, frozen_tokens: HashSet, -) -> PyResult>> { +) -> PyResult>> { Ok(get_vocabulary_transition_keys( &alphabet_symbol_mapping, alphabet_anything_value, @@ -143,12 +144,12 @@ pub fn get_vocabulary_transition_keys_py( pub fn create_fsm_index_end_to_end_py<'py>( py: Python<'py>, fsm_info: &FSMInfo, - vocabulary: Vec<(String, Vec)>, + vocabulary: Vec<(String, Vec)>, frozen_tokens: HashSet, ) -> PyResult> { let states_to_token_subsets = PyDict::new_bound(py); - let mut seen: HashSet = HashSet::new(); - let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); + let mut seen: HashSet = HashSet::new(); + let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); let vocabulary_transition_keys = get_vocabulary_transition_keys( &fsm_info.alphabet_symbol_mapping, diff --git a/src/regex.rs b/src/regex.rs index 1db920ac..314ff0cf 100644 --- a/src/regex.rs +++ b/src/regex.rs @@ -1,13 +1,14 @@ +use crate::*; use std::collections::{HashMap, HashSet}; pub fn walk_fsm( - fsm_transitions: &HashMap<(u32, u32), u32>, - _fsm_initial: u32, - fsm_finals: &HashSet, - token_transition_keys: &[u32], - start_state: u32, + fsm_transitions: &HashMap<(State, TransitionKey), State>, + _fsm_initial: State, + fsm_finals: &HashSet, + token_transition_keys: &[TransitionKey], + start_state: State, full_match: bool, -) -> Vec { +) -> Vec { let mut state = start_state; let mut accepted_states = Vec::new(); let mut last_final_idx = 0; @@ -38,19 +39,19 @@ pub fn walk_fsm( } pub fn state_scan_tokens( - fsm_transitions: &HashMap<(u32, u32), u32>, - fsm_initial: u32, - fsm_finals: &HashSet, - vocabulary: &[(String, Vec)], - vocabulary_transition_keys: &[Vec], - start_state: u32, -) -> HashSet<(u32, u32)> { + fsm_transitions: &HashMap<(State, TransitionKey), State>, + fsm_initial: State, + fsm_finals: &HashSet, + vocabulary: &[(String, Vec)], + vocabulary_transition_keys: &[Vec], + start_state: State, +) -> HashSet<(TokenId, State)> { let mut res = HashSet::new(); for (vocab_item, token_transition_keys) in vocabulary.iter().zip(vocabulary_transition_keys.iter()) { - let token_ids: Vec = vocab_item.1.clone(); + let token_ids: Vec = vocab_item.1.clone(); let state_seq = walk_fsm( fsm_transitions, @@ -74,10 +75,10 @@ pub fn state_scan_tokens( } pub fn get_token_transition_keys( - alphabet_symbol_mapping: &HashMap, - alphabet_anything_value: u32, + alphabet_symbol_mapping: &HashMap, + alphabet_anything_value: TransitionKey, token_str: &str, -) -> Vec { +) -> Vec { let mut token_transition_keys = Vec::new(); let mut i = 0; let chars: Vec = token_str.chars().collect(); @@ -107,12 +108,12 @@ pub fn get_token_transition_keys( } pub fn get_vocabulary_transition_keys( - alphabet_symbol_mapping: &HashMap, - alphabet_anything_value: u32, - vocabulary: &[(String, Vec)], + alphabet_symbol_mapping: &HashMap, + alphabet_anything_value: TransitionKey, + vocabulary: &[(String, Vec)], frozen_tokens: &HashSet, -) -> Vec> { - let mut vocab_transition_keys: Vec> = Vec::new(); +) -> Vec> { + let mut vocab_transition_keys: Vec> = Vec::new(); for item in vocabulary.iter() { let token_str = item.0.clone();