Skip to content

Commit

Permalink
Switch from HashMap -> FxHashMap for better perf
Browse files Browse the repository at this point in the history
  • Loading branch information
unaidedelf8777 authored and torymur committed Dec 9, 2024
1 parent f07df79 commit 21d6bb6
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 58 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ bincode = "2.0.0-rc.3"
# Fragile dependencies, minor updates often break the code
hf-hub = "=0.3.2"
tokenizers = { version = "=0.20.3", features = ["http"] }
rustc-hash = "2.1.0"

[features]
python-bindings = ["pyo3"]
Expand Down
28 changes: 14 additions & 14 deletions src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,24 @@ use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
use crate::vocabulary::Vocabulary;
use crate::{Error, Result};
use bincode::{Decode, Encode};
use std::collections::{HashMap, HashSet};
use rustc_hash::{FxHashMap, FxHashSet};

#[derive(Debug)]
pub struct FSMInfo {
pub(crate) initial: State,
pub(crate) finals: HashSet<State>,
pub(crate) transitions: HashMap<(State, TransitionKey), State>,
pub(crate) finals: FxHashSet<State>,
pub(crate) transitions: FxHashMap<(State, TransitionKey), State>,
pub(crate) alphabet_anything_value: TransitionKey,
pub(crate) alphabet_symbol_mapping: HashMap<String, TransitionKey>,
pub(crate) alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
}

impl FSMInfo {
pub fn new(
initial: State,
finals: HashSet<State>,
transitions: HashMap<(State, TransitionKey), State>,
finals: FxHashSet<State>,
transitions: FxHashMap<(State, TransitionKey), State>,
alphabet_anything_value: TransitionKey,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
) -> Self {
Self {
initial,
Expand All @@ -36,8 +36,8 @@ impl FSMInfo {
#[derive(Debug, Encode, Decode)]
pub struct Index {
initial: u32,
finals: HashSet<u32>,
states_to_token_subsets: HashMap<u32, HashMap<u32, u32>>,
finals: FxHashSet<u32>,
states_to_token_subsets: FxHashMap<u32, FxHashMap<u32, u32>>,
eos_token_id: u32,
}

Expand All @@ -46,11 +46,11 @@ impl Index {
fsm_info: &FSMInfo,
vocabulary: &Vocabulary,
eos_token_id: u32,
frozen_tokens: HashSet<String>,
frozen_tokens: FxHashSet<String>,
) -> Result<Self> {
let mut states_to_token_subsets: HashMap<u32, HashMap<u32, u32>> = HashMap::new();
let mut seen: HashSet<State> = HashSet::new();
let mut next_states: HashSet<State> = HashSet::from([fsm_info.initial]);
let mut states_to_token_subsets: FxHashMap<u32, FxHashMap<u32, u32>> = FxHashMap::default();
let mut seen: FxHashSet<State> = FxHashSet::default();
let mut next_states: FxHashSet<State> = FxHashSet::from_iter([fsm_info.initial]);

let vocabulary_transition_keys = get_vocabulary_transition_keys(
&fsm_info.alphabet_symbol_mapping,
Expand Down Expand Up @@ -126,7 +126,7 @@ impl Index {
self.finals.contains(&state)
}

pub(crate) fn transitions(&self) -> &HashMap<u32, HashMap<u32, u32>> {
pub(crate) fn transitions(&self) -> &FxHashMap<u32, FxHashMap<u32, u32>> {
&self.states_to_token_subsets
}
}
46 changes: 23 additions & 23 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@ use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::wrap_pyfunction;
use rustc_hash::{FxHashMap, FxHashSet};
use serde_json::Value;
use std::collections::{HashMap, HashSet};

#[pyclass(name = "FSMInfo")]
pub struct PyFSMInfo {
#[pyo3(get)]
initial: State,
#[pyo3(get)]
finals: HashSet<State>,
finals: FxHashSet<State>,
#[pyo3(get)]
transitions: HashMap<(State, TransitionKey), State>,
transitions: FxHashMap<(State, TransitionKey), State>,
#[pyo3(get)]
alphabet_anything_value: TransitionKey,
#[pyo3(get)]
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
}

impl From<FSMInfo> for PyFSMInfo {
Expand Down Expand Up @@ -57,10 +57,10 @@ impl PyFSMInfo {
#[new]
fn new(
initial: State,
finals: HashSet<State>,
transitions: HashMap<(State, TransitionKey), State>,
finals: FxHashSet<State>,
transitions: FxHashMap<(State, TransitionKey), State>,
alphabet_anything_value: TransitionKey,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
) -> Self {
FSMInfo::new(
initial,
Expand All @@ -83,7 +83,7 @@ impl PyIndex {
fsm_info: &PyFSMInfo,
vocabulary: &PyVocabulary,
eos_token_id: u32,
frozen_tokens: HashSet<String>,
frozen_tokens: FxHashSet<String>,
) -> PyResult<Self> {
Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens)
.map(PyIndex)
Expand Down Expand Up @@ -123,7 +123,7 @@ impl PyIndex {
self.0.is_final(state)
}

fn get_transitions(&self) -> HashMap<u32, HashMap<u32, u32>> {
fn get_transitions(&self) -> FxHashMap<u32, FxHashMap<u32, u32>> {
self.0.transitions().clone()
}

Expand Down Expand Up @@ -155,9 +155,9 @@ pub fn to_regex_py(json: Bound<PyDict>, whitespace_pattern: Option<&str>) -> PyR
text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)"
)]
pub fn walk_fsm_py(
fsm_transitions: HashMap<(State, TransitionKey), State>,
fsm_transitions: FxHashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: HashSet<State>,
fsm_finals: FxHashSet<State>,
token_transition_keys: Vec<TransitionKey>,
start_state: State,
full_match: bool,
Expand All @@ -177,13 +177,13 @@ pub fn walk_fsm_py(
text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)"
)]
pub fn state_scan_tokens_py(
fsm_transitions: HashMap<(State, TransitionKey), State>,
fsm_transitions: FxHashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: HashSet<State>,
fsm_finals: FxHashSet<State>,
vocabulary: &PyVocabulary,
vocabulary_transition_keys: HashMap<String, Vec<TransitionKey>>,
vocabulary_transition_keys: FxHashMap<String, Vec<TransitionKey>>,
start_state: State,
) -> PyResult<HashSet<(TokenId, State)>> {
) -> PyResult<FxHashSet<(TokenId, State)>> {
Ok(state_scan_tokens(
&fsm_transitions,
fsm_initial,
Expand All @@ -197,7 +197,7 @@ pub fn state_scan_tokens_py(
#[pyfunction(name = "get_token_transition_keys")]
#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")]
pub fn get_token_transition_keys_py(
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
token_str: String,
) -> PyResult<Vec<TransitionKey>> {
Expand All @@ -213,11 +213,11 @@ pub fn get_token_transition_keys_py(
text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)"
)]
pub fn get_vocabulary_transition_keys_py(
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
alphabet_symbol_mapping: FxHashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
vocabulary: &PyVocabulary,
frozen_tokens: HashSet<String>,
) -> PyResult<HashMap<String, Vec<TransitionKey>>> {
frozen_tokens: FxHashSet<String>,
) -> PyResult<FxHashMap<String, Vec<TransitionKey>>> {
Ok(get_vocabulary_transition_keys(
&alphabet_symbol_mapping,
alphabet_anything_value,
Expand All @@ -232,11 +232,11 @@ pub fn create_fsm_index_end_to_end_py<'py>(
py: Python<'py>,
fsm_info: &PyFSMInfo,
vocabulary: &PyVocabulary,
frozen_tokens: HashSet<String>,
frozen_tokens: FxHashSet<String>,
) -> PyResult<Bound<'py, PyDict>> {
let states_to_token_subsets = PyDict::new_bound(py);
let mut seen: HashSet<State> = HashSet::new();
let mut next_states: HashSet<State> = HashSet::from_iter(vec![fsm_info.initial]);
let mut seen: FxHashSet<State> = FxHashSet::default();
let mut next_states: FxHashSet<State> = FxHashSet::from_iter(vec![fsm_info.initial]);

let vocabulary_transition_keys = get_vocabulary_transition_keys(
&fsm_info.alphabet_symbol_mapping,
Expand Down Expand Up @@ -284,7 +284,7 @@ pub struct PyVocabulary(Vocabulary);
#[pymethods]
impl PyVocabulary {
#[staticmethod]
fn from_dict(map: HashMap<Token, Vec<TokenId>>) -> PyVocabulary {
fn from_dict(map: FxHashMap<Token, Vec<TokenId>>) -> PyVocabulary {
PyVocabulary(Vocabulary::from(map))
}

Expand Down
26 changes: 13 additions & 13 deletions src/regex.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::prelude::*;
use std::collections::{HashMap, HashSet};
use rustc_hash::{FxHashMap, FxHashSet};

pub fn walk_fsm(
fsm_transitions: &HashMap<(State, TransitionKey), State>,
fsm_transitions: &FxHashMap<(State, TransitionKey), State>,
_fsm_initial: State,
fsm_finals: &HashSet<State>,
fsm_finals: &FxHashSet<State>,
token_transition_keys: &[TransitionKey],
start_state: State,
full_match: bool,
Expand Down Expand Up @@ -39,14 +39,14 @@ pub fn walk_fsm(
}

pub fn state_scan_tokens(
fsm_transitions: &HashMap<(State, TransitionKey), State>,
fsm_transitions: &FxHashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: &HashSet<State>,
fsm_finals: &FxHashSet<State>,
vocabulary: &Vocabulary,
vocabulary_transition_keys: &HashMap<Token, Vec<TransitionKey>>,
vocabulary_transition_keys: &FxHashMap<Token, Vec<TransitionKey>>,
start_state: State,
) -> HashSet<(TokenId, State)> {
let mut res = HashSet::new();
) -> FxHashSet<(TokenId, State)> {
let mut res = FxHashSet::default();

for (token, token_ids) in vocabulary.iter() {
let token_transition_keys = &vocabulary_transition_keys[token];
Expand All @@ -72,7 +72,7 @@ pub fn state_scan_tokens(
}

pub fn get_token_transition_keys(
alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
alphabet_symbol_mapping: &FxHashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
token_str: &str,
) -> Vec<TransitionKey> {
Expand Down Expand Up @@ -105,12 +105,12 @@ pub fn get_token_transition_keys(
}

pub fn get_vocabulary_transition_keys(
alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
alphabet_symbol_mapping: &FxHashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
vocabulary: &Vocabulary,
frozen_tokens: &HashSet<String>,
) -> HashMap<Token, Vec<TransitionKey>> {
let mut vocab_transition_keys = HashMap::new();
frozen_tokens: &FxHashSet<String>,
) -> FxHashMap<Token, Vec<TransitionKey>> {
let mut vocab_transition_keys = FxHashMap::default();

for item in vocabulary.iter() {
let token_str = item.0.clone();
Expand Down
16 changes: 8 additions & 8 deletions src/vocabulary/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use rustc_hash::FxHashMap;

use tokenizers::normalizers::Sequence;
use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};
Expand Down Expand Up @@ -29,15 +29,15 @@ mod processor;
pub struct Vocabulary {
// TODO: Option is temp for back compatibility
eos_token_id: Option<TokenId>,
tokens: HashMap<Token, Vec<TokenId>>,
tokens: FxHashMap<Token, Vec<TokenId>>,
}

impl Vocabulary {
/// Creates an empty vocabulary.
pub fn new(eos_token_id: Option<TokenId>) -> Self {
Self {
eos_token_id,
tokens: HashMap::new(),
tokens: FxHashMap::default(),
}
}

Expand Down Expand Up @@ -174,9 +174,9 @@ impl Vocabulary {
}

impl std::ops::Deref for Vocabulary {
type Target = HashMap<Token, Vec<TokenId>>;
type Target = FxHashMap<Token, Vec<TokenId>>;

fn deref(&self) -> &HashMap<Token, Vec<TokenId>> {
fn deref(&self) -> &FxHashMap<Token, Vec<TokenId>> {
&self.tokens
}
}
Expand All @@ -194,8 +194,8 @@ impl std::fmt::Display for Vocabulary {
}
}

impl From<HashMap<Token, Vec<TokenId>>> for Vocabulary {
fn from(tokens: HashMap<Token, Vec<TokenId>>) -> Vocabulary {
impl From<FxHashMap<Token, Vec<TokenId>>> for Vocabulary {
fn from(tokens: FxHashMap<Token, Vec<TokenId>>) -> Vocabulary {
Vocabulary {
eos_token_id: None,
tokens,
Expand Down Expand Up @@ -257,7 +257,7 @@ mod tests {

#[test]
fn new_empty_vocabulary_from_hashmap() {
let map = HashMap::new();
let map = FxHashMap::default();
let vocabulary = Vocabulary::from(map);
assert!(vocabulary.eos_token_id.is_none());
assert!(vocabulary.tokens.is_empty());
Expand Down

0 comments on commit 21d6bb6

Please sign in to comment.