diff --git a/Cargo.toml b/Cargo.toml
index 399118a4..6bcad241 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,6 +18,7 @@ bincode = "2.0.0-rc.3"
 # Fragile dependencies, minor updates often break the code
 hf-hub = "=0.3.2"
 tokenizers = { version = "=0.20.3", features = ["http"] }
+rustc-hash = "2.1.0"
 
 [features]
 python-bindings = ["pyo3"]
diff --git a/src/index.rs b/src/index.rs
index a756445c..5fcc3e93 100644
--- a/src/index.rs
+++ b/src/index.rs
@@ -4,24 +4,24 @@ use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
 use crate::vocabulary::Vocabulary;
 use crate::{Error, Result};
 use bincode::{Decode, Encode};
-use std::collections::{HashMap, HashSet};
+use rustc_hash::{FxHashMap, FxHashSet};
 
 #[derive(Debug)]
 pub struct FSMInfo {
     pub(crate) initial: State,
-    pub(crate) finals: HashSet<State>,
-    pub(crate) transitions: HashMap<(State, TransitionKey), State>,
+    pub(crate) finals: FxHashSet<State>,
+    pub(crate) transitions: FxHashMap<(State, TransitionKey), State>,
     pub(crate) alphabet_anything_value: TransitionKey,
-    pub(crate) alphabet_symbol_mapping: HashMap<String, TransitionKey>,
+    pub(crate) alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
 }
 
 impl FSMInfo {
     pub fn new(
         initial: State,
-        finals: HashSet<State>,
-        transitions: HashMap<(State, TransitionKey), State>,
+        finals: FxHashSet<State>,
+        transitions: FxHashMap<(State, TransitionKey), State>,
         alphabet_anything_value: TransitionKey,
-        alphabet_symbol_mapping: HashMap<String, TransitionKey>,
+        alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
     ) -> Self {
         Self {
             initial,
@@ -36,8 +36,8 @@ impl FSMInfo {
 #[derive(Debug, Encode, Decode)]
 pub struct Index {
     initial: u32,
-    finals: HashSet<u32>,
-    states_to_token_subsets: HashMap<u32, HashMap<u32, u32>>,
+    finals: FxHashSet<u32>,
+    states_to_token_subsets: FxHashMap<u32, FxHashMap<u32, u32>>,
     eos_token_id: u32,
 }
 
@@ -46,11 +46,11 @@ impl Index {
         fsm_info: &FSMInfo,
         vocabulary: &Vocabulary,
         eos_token_id: u32,
-        frozen_tokens: HashSet<String>,
+        frozen_tokens: FxHashSet<String>,
     ) -> Result<Self> {
-        let mut states_to_token_subsets: HashMap<u32, HashMap<u32, u32>> = HashMap::new();
-        let mut seen: HashSet<State> = HashSet::new();
-        let mut next_states: HashSet<State> = HashSet::from([fsm_info.initial]);
+        let mut states_to_token_subsets: FxHashMap<u32, FxHashMap<u32, u32>> = FxHashMap::default();
+        let mut seen: FxHashSet<State> = FxHashSet::default();
+        let mut next_states: FxHashSet<State> = FxHashSet::from_iter([fsm_info.initial]);
 
         let vocabulary_transition_keys = get_vocabulary_transition_keys(
             &fsm_info.alphabet_symbol_mapping,
@@ -126,7 +126,7 @@ impl Index {
         self.finals.contains(&state)
     }
 
-    pub(crate) fn transitions(&self) -> &HashMap<u32, HashMap<u32, u32>> {
+    pub(crate) fn transitions(&self) -> &FxHashMap<u32, FxHashMap<u32, u32>> {
         &self.states_to_token_subsets
     }
 }
diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs
index 6dc7d544..76a9a1ec 100644
--- a/src/python_bindings/mod.rs
+++ b/src/python_bindings/mod.rs
@@ -10,21 +10,21 @@ use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::types::PyDict;
 use pyo3::wrap_pyfunction;
+use rustc_hash::{FxHashMap, FxHashSet};
 use serde_json::Value;
-use std::collections::{HashMap, HashSet};
 
 #[pyclass(name = "FSMInfo")]
 pub struct PyFSMInfo {
     #[pyo3(get)]
     initial: State,
     #[pyo3(get)]
-    finals: HashSet<State>,
+    finals: FxHashSet<State>,
     #[pyo3(get)]
-    transitions: HashMap<(State, TransitionKey), State>,
+    transitions: FxHashMap<(State, TransitionKey), State>,
     #[pyo3(get)]
     alphabet_anything_value: TransitionKey,
     #[pyo3(get)]
-    alphabet_symbol_mapping: HashMap<String, TransitionKey>,
+    alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
 }
 
 impl From<FSMInfo> for PyFSMInfo {
@@ -57,10 +57,10 @@ impl PyFSMInfo {
     #[new]
     fn new(
         initial: State,
-        finals: HashSet<State>,
-        transitions: HashMap<(State, TransitionKey), State>,
+        finals: FxHashSet<State>,
+        transitions: FxHashMap<(State, TransitionKey), State>,
         alphabet_anything_value: TransitionKey,
-        alphabet_symbol_mapping: HashMap<String, TransitionKey>,
+        alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
     ) -> Self {
         FSMInfo::new(
             initial,
@@ -83,7 +83,7 @@ impl PyIndex {
         fsm_info: &PyFSMInfo,
         vocabulary: &PyVocabulary,
         eos_token_id: u32,
-        frozen_tokens: HashSet<String>,
+        frozen_tokens: FxHashSet<String>,
     ) -> PyResult<Self> {
         Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens)
             .map(PyIndex)
@@ -123,7 +123,7 @@ impl PyIndex {
         self.0.is_final(state)
     }
 
-    fn get_transitions(&self) -> HashMap<u32, HashMap<u32, u32>> {
+    fn get_transitions(&self) -> FxHashMap<u32, FxHashMap<u32, u32>> {
         self.0.transitions().clone()
     }
 
@@ -155,9 +155,9 @@ pub fn to_regex_py(json: Bound<PyDict>, whitespace_pattern: Option<&str>) -> PyR
     text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)"
 )]
 pub fn walk_fsm_py(
-    fsm_transitions: HashMap<(State, TransitionKey), State>,
+    fsm_transitions: FxHashMap<(State, TransitionKey), State>,
     fsm_initial: State,
-    fsm_finals: HashSet<State>,
+    fsm_finals: FxHashSet<State>,
     token_transition_keys: Vec<TransitionKey>,
     start_state: State,
     full_match: bool,
@@ -177,13 +177,13 @@ pub fn walk_fsm_py(
     text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)"
 )]
 pub fn state_scan_tokens_py(
-    fsm_transitions: HashMap<(State, TransitionKey), State>,
+    fsm_transitions: FxHashMap<(State, TransitionKey), State>,
     fsm_initial: State,
-    fsm_finals: HashSet<State>,
+    fsm_finals: FxHashSet<State>,
     vocabulary: &PyVocabulary,
-    vocabulary_transition_keys: HashMap<String, Vec<TransitionKey>>,
+    vocabulary_transition_keys: FxHashMap<String, Vec<TransitionKey>>,
     start_state: State,
-) -> PyResult<HashSet<(TokenId, State)>> {
+) -> PyResult<FxHashSet<(TokenId, State)>> {
     Ok(state_scan_tokens(
         &fsm_transitions,
         fsm_initial,
@@ -197,7 +197,7 @@ pub fn state_scan_tokens_py(
 #[pyfunction(name = "get_token_transition_keys")]
 #[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")]
 pub fn get_token_transition_keys_py(
-    alphabet_symbol_mapping: HashMap<String, TransitionKey>,
+    alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
     alphabet_anything_value: TransitionKey,
     token_str: String,
 ) -> PyResult<Vec<TransitionKey>> {
@@ -213,11 +213,11 @@ pub fn get_token_transition_keys_py(
     text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)"
 )]
 pub fn get_vocabulary_transition_keys_py(
-    alphabet_symbol_mapping: HashMap<String, TransitionKey>,
+    alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
     alphabet_anything_value: TransitionKey,
     vocabulary: &PyVocabulary,
-    frozen_tokens: HashSet<String>,
-) -> PyResult<HashMap<String, Vec<TransitionKey>>> {
+    frozen_tokens: FxHashSet<String>,
+) -> PyResult<FxHashMap<String, Vec<TransitionKey>>> {
     Ok(get_vocabulary_transition_keys(
         &alphabet_symbol_mapping,
         alphabet_anything_value,
@@ -232,11 +232,11 @@ pub fn create_fsm_index_end_to_end_py<'py>(
     py: Python<'py>,
     fsm_info: &PyFSMInfo,
     vocabulary: &PyVocabulary,
-    frozen_tokens: HashSet<String>,
+    frozen_tokens: FxHashSet<String>,
 ) -> PyResult<Bound<'py, PyDict>> {
     let states_to_token_subsets = PyDict::new_bound(py);
-    let mut seen: HashSet<State> = HashSet::new();
-    let mut next_states: HashSet<State> = HashSet::from_iter(vec![fsm_info.initial]);
+    let mut seen: FxHashSet<State> = FxHashSet::default();
+    let mut next_states: FxHashSet<State> = FxHashSet::from_iter(vec![fsm_info.initial]);
 
     let vocabulary_transition_keys = get_vocabulary_transition_keys(
         &fsm_info.alphabet_symbol_mapping,
@@ -284,7 +284,7 @@ pub struct PyVocabulary(Vocabulary);
 #[pymethods]
 impl PyVocabulary {
     #[staticmethod]
-    fn from_dict(map: HashMap<String, Vec<TokenId>>) -> PyVocabulary {
+    fn from_dict(map: FxHashMap<String, Vec<TokenId>>) -> PyVocabulary {
         PyVocabulary(Vocabulary::from(map))
     }
 
diff --git a/src/regex.rs b/src/regex.rs
index b5658191..24687f1e 100644
--- a/src/regex.rs
+++ b/src/regex.rs
@@ -1,10 +1,10 @@
 use crate::prelude::*;
-use std::collections::{HashMap, HashSet};
+use rustc_hash::{FxHashMap, FxHashSet};
 
 pub fn walk_fsm(
-    fsm_transitions: &HashMap<(State, TransitionKey), State>,
+    fsm_transitions: &FxHashMap<(State, TransitionKey), State>,
     _fsm_initial: State,
-    fsm_finals: &HashSet<State>,
+    fsm_finals: &FxHashSet<State>,
     token_transition_keys: &[TransitionKey],
     start_state: State,
     full_match: bool,
@@ -39,14 +39,14 @@ pub fn walk_fsm(
 }
 
 pub fn state_scan_tokens(
-    fsm_transitions: &HashMap<(State, TransitionKey), State>,
+    fsm_transitions: &FxHashMap<(State, TransitionKey), State>,
     fsm_initial: State,
-    fsm_finals: &HashSet<State>,
+    fsm_finals: &FxHashSet<State>,
     vocabulary: &Vocabulary,
-    vocabulary_transition_keys: &HashMap<String, Vec<TransitionKey>>,
+    vocabulary_transition_keys: &FxHashMap<String, Vec<TransitionKey>>,
     start_state: State,
-) -> HashSet<(TokenId, State)> {
-    let mut res = HashSet::new();
+) -> FxHashSet<(TokenId, State)> {
+    let mut res = FxHashSet::default();
 
     for (token, token_ids) in vocabulary.iter() {
         let token_transition_keys = &vocabulary_transition_keys[token];
@@ -72,7 +72,7 @@ pub fn state_scan_tokens(
 }
 
 pub fn get_token_transition_keys(
-    alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
+    alphabet_symbol_mapping: &FxHashMap<String, TransitionKey>,
     alphabet_anything_value: TransitionKey,
     token_str: &str,
 ) -> Vec<TransitionKey> {
@@ -105,12 +105,12 @@ pub fn get_token_transition_keys(
 }
 
 pub fn get_vocabulary_transition_keys(
-    alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
+    alphabet_symbol_mapping: &FxHashMap<String, TransitionKey>,
     alphabet_anything_value: TransitionKey,
     vocabulary: &Vocabulary,
-    frozen_tokens: &HashSet<String>,
-) -> HashMap<String, Vec<TransitionKey>> {
-    let mut vocab_transition_keys = HashMap::new();
+    frozen_tokens: &FxHashSet<String>,
+) -> FxHashMap<String, Vec<TransitionKey>> {
+    let mut vocab_transition_keys = FxHashMap::default();
 
     for item in vocabulary.iter() {
         let token_str = item.0.clone();
diff --git a/src/vocabulary/mod.rs b/src/vocabulary/mod.rs
index 719c9040..0b3eaa2e 100644
--- a/src/vocabulary/mod.rs
+++ b/src/vocabulary/mod.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use rustc_hash::FxHashMap;
 
 use tokenizers::normalizers::Sequence;
 use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
@@ -29,7 +29,7 @@ mod processor;
 pub struct Vocabulary {
     // TODO: Option is temp for back compatibility
     eos_token_id: Option<TokenId>,
-    tokens: HashMap<Token, Vec<TokenId>>,
+    tokens: FxHashMap<Token, Vec<TokenId>>,
 }
 
 impl Vocabulary {
@@ -37,7 +37,7 @@ impl Vocabulary {
     pub fn new(eos_token_id: Option<TokenId>) -> Self {
         Self {
             eos_token_id,
-            tokens: HashMap::new(),
+            tokens: FxHashMap::default(),
         }
     }
 
@@ -174,9 +174,9 @@ impl Vocabulary {
 }
 
 impl std::ops::Deref for Vocabulary {
-    type Target = HashMap<Token, Vec<TokenId>>;
+    type Target = FxHashMap<Token, Vec<TokenId>>;
 
-    fn deref(&self) -> &HashMap<Token, Vec<TokenId>> {
+    fn deref(&self) -> &FxHashMap<Token, Vec<TokenId>> {
         &self.tokens
     }
 }
@@ -194,8 +194,8 @@ impl std::fmt::Display for Vocabulary {
     }
 }
 
-impl From<HashMap<Token, Vec<TokenId>>> for Vocabulary {
-    fn from(tokens: HashMap<Token, Vec<TokenId>>) -> Vocabulary {
+impl From<FxHashMap<Token, Vec<TokenId>>> for Vocabulary {
+    fn from(tokens: FxHashMap<Token, Vec<TokenId>>) -> Vocabulary {
         Vocabulary {
             eos_token_id: None,
             tokens,
@@ -257,7 +257,7 @@ mod tests {
 
     #[test]
     fn new_empty_vocabulary_from_hashmap() {
-        let map = HashMap::new();
+        let map = FxHashMap::default();
        let vocabulary = Vocabulary::from(map);
         assert!(vocabulary.eos_token_id.is_none());
         assert!(vocabulary.tokens.is_empty());
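
Note on the pattern above (an illustrative sketch, not part of the patch): rustc-hash's
FxHashMap and FxHashSet are type aliases for std's HashMap/HashSet with the Fx
build-hasher filled in, which is why the diff is almost entirely a type-name
substitution. The one construction-site difference is that HashMap::new() and
HashSet::from([..]) are only defined for the default RandomState hasher, so the
aliases are built with default() or from_iter(..) instead, as the diff does throughout:

    use rustc_hash::{FxHashMap, FxHashSet};

    fn main() {
        // new() is only implemented for HashMap<K, V, RandomState>,
        // so the Fx aliases are constructed via Default ...
        let mut transitions: FxHashMap<(u32, u32), u32> = FxHashMap::default();
        transitions.insert((0, 1), 2);

        // ... and built from literals via FromIterator, which is generic over
        // the hasher; hence HashSet::from([fsm_info.initial]) becoming
        // FxHashSet::from_iter([fsm_info.initial]) in src/index.rs.
        let finals: FxHashSet<u32> = FxHashSet::from_iter([2, 3]);
        assert!(finals.contains(&2));
    }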