Skip to content

Commit

Permalink
Use type aliases to improve readability
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Sep 20, 2024
1 parent d85a8f9 commit e63edab
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 55 deletions.
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@ pub mod regex;

#[cfg(feature = "python-bindings")]
mod python_bindings;

mod primitives;

pub use crate::primitives::{State, TokenId, TransitionKey};
8 changes: 8 additions & 0 deletions src/primitives.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/// Interegular transition key.
///
/// Identifies the alphabet symbol consumed by a single FSM transition: it is
/// the second component of the `(State, TransitionKey)` keys of a transition
/// table and the value type of an alphabet symbol mapping.
///
/// This is a plain alias of `u32`, not a distinct type, so the compiler will
/// not reject a [`State`] or [`TokenId`] passed where a key is expected.
pub type TransitionKey = u32;

/// Token identifier.
///
/// Vocabulary entries pair a token string with a list of these identifiers
/// (`(String, Vec<TokenId>)`), and token scanning reports its matches as
/// `(TokenId, State)` pairs.
///
/// This is a plain alias of `u32`, not a distinct type, so it is freely
/// interchangeable with [`State`] and [`TransitionKey`] at compile time.
pub type TokenId = u32;

/// Interegular state.
///
/// Identifies a state of the FSM: used for the initial state, the set of
/// final states, and both the source and the target of every entry in a
/// `HashMap<(State, TransitionKey), State>` transition table.
///
/// This is a plain alias of `u32`, not a distinct type, so it is freely
/// interchangeable with [`TokenId`] and [`TransitionKey`] at compile time.
pub type State = u32;
67 changes: 34 additions & 33 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::regex::get_token_transition_keys;
use crate::regex::get_vocabulary_transition_keys;
use crate::regex::state_scan_tokens;
use crate::regex::walk_fsm;
use crate::*;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyDict;
Expand All @@ -13,26 +14,26 @@ use std::collections::{HashMap, HashSet};
#[pyclass]
pub struct FSMInfo {
#[pyo3(get)]
initial: u32,
initial: State,
#[pyo3(get)]
finals: HashSet<u32>,
finals: HashSet<State>,
#[pyo3(get)]
transitions: HashMap<(u32, u32), u32>,
transitions: HashMap<(State, TransitionKey), State>,
#[pyo3(get)]
alphabet_anything_value: u32,
alphabet_anything_value: TransitionKey,
#[pyo3(get)]
alphabet_symbol_mapping: HashMap<String, u32>,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
}

#[pymethods]
impl FSMInfo {
#[new]
fn new(
initial: u32,
finals: HashSet<u32>,
transitions: HashMap<(u32, u32), u32>,
alphabet_anything_value: u32,
alphabet_symbol_mapping: HashMap<String, u32>,
initial: State,
finals: HashSet<State>,
transitions: HashMap<(State, TransitionKey), State>,
alphabet_anything_value: TransitionKey,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
) -> Self {
Self {
initial,
Expand Down Expand Up @@ -67,13 +68,13 @@ pub fn to_regex_py(json: Bound<PyDict>, whitespace_pattern: Option<&str>) -> PyR
text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)"
)]
pub fn walk_fsm_py(
fsm_transitions: HashMap<(u32, u32), u32>,
fsm_initial: u32,
fsm_finals: HashSet<u32>,
token_transition_keys: Vec<u32>,
start_state: u32,
fsm_transitions: HashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: HashSet<State>,
token_transition_keys: Vec<TransitionKey>,
start_state: State,
full_match: bool,
) -> PyResult<Vec<u32>> {
) -> PyResult<Vec<State>> {
Ok(walk_fsm(
&fsm_transitions,
fsm_initial,
Expand All @@ -89,13 +90,13 @@ pub fn walk_fsm_py(
text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)"
)]
pub fn state_scan_tokens_py(
fsm_transitions: HashMap<(u32, u32), u32>,
fsm_initial: u32,
fsm_finals: HashSet<u32>,
vocabulary: Vec<(String, Vec<u32>)>,
vocabulary_transition_keys: Vec<Vec<u32>>,
start_state: u32,
) -> PyResult<HashSet<(u32, u32)>> {
fsm_transitions: HashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: HashSet<State>,
vocabulary: Vec<(String, Vec<TokenId>)>,
vocabulary_transition_keys: Vec<Vec<TransitionKey>>,
start_state: State,
) -> PyResult<HashSet<(TokenId, State)>> {
Ok(state_scan_tokens(
&fsm_transitions,
fsm_initial,
Expand All @@ -109,10 +110,10 @@ pub fn state_scan_tokens_py(
#[pyfunction(name = "get_token_transition_keys")]
#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")]
pub fn get_token_transition_keys_py(
alphabet_symbol_mapping: HashMap<String, u32>,
alphabet_anything_value: u32,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
token_str: String,
) -> PyResult<Vec<u32>> {
) -> PyResult<Vec<TransitionKey>> {
Ok(get_token_transition_keys(
&alphabet_symbol_mapping,
alphabet_anything_value,
Expand All @@ -125,11 +126,11 @@ pub fn get_token_transition_keys_py(
text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)"
)]
pub fn get_vocabulary_transition_keys_py(
alphabet_symbol_mapping: HashMap<String, u32>,
alphabet_anything_value: u32,
vocabulary: Vec<(String, Vec<u32>)>,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
vocabulary: Vec<(String, Vec<TokenId>)>,
frozen_tokens: HashSet<String>,
) -> PyResult<Vec<Vec<u32>>> {
) -> PyResult<Vec<Vec<TransitionKey>>> {
Ok(get_vocabulary_transition_keys(
&alphabet_symbol_mapping,
alphabet_anything_value,
Expand All @@ -143,12 +144,12 @@ pub fn get_vocabulary_transition_keys_py(
pub fn create_fsm_index_end_to_end_py<'py>(
py: Python<'py>,
fsm_info: &FSMInfo,
vocabulary: Vec<(String, Vec<u32>)>,
vocabulary: Vec<(String, Vec<TokenId>)>,
frozen_tokens: HashSet<String>,
) -> PyResult<Bound<'py, PyDict>> {
let states_to_token_subsets = PyDict::new_bound(py);
let mut seen: HashSet<u32> = HashSet::new();
let mut next_states: HashSet<u32> = HashSet::from_iter(vec![fsm_info.initial]);
let mut seen: HashSet<State> = HashSet::new();
let mut next_states: HashSet<State> = HashSet::from_iter(vec![fsm_info.initial]);

let vocabulary_transition_keys = get_vocabulary_transition_keys(
&fsm_info.alphabet_symbol_mapping,
Expand Down
45 changes: 23 additions & 22 deletions src/regex.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use crate::*;
use std::collections::{HashMap, HashSet};

pub fn walk_fsm(
fsm_transitions: &HashMap<(u32, u32), u32>,
_fsm_initial: u32,
fsm_finals: &HashSet<u32>,
token_transition_keys: &[u32],
start_state: u32,
fsm_transitions: &HashMap<(State, TransitionKey), State>,
_fsm_initial: State,
fsm_finals: &HashSet<State>,
token_transition_keys: &[TransitionKey],
start_state: State,
full_match: bool,
) -> Vec<u32> {
) -> Vec<State> {
let mut state = start_state;
let mut accepted_states = Vec::new();
let mut last_final_idx = 0;
Expand Down Expand Up @@ -38,19 +39,19 @@ pub fn walk_fsm(
}

pub fn state_scan_tokens(
fsm_transitions: &HashMap<(u32, u32), u32>,
fsm_initial: u32,
fsm_finals: &HashSet<u32>,
vocabulary: &[(String, Vec<u32>)],
vocabulary_transition_keys: &[Vec<u32>],
start_state: u32,
) -> HashSet<(u32, u32)> {
fsm_transitions: &HashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: &HashSet<State>,
vocabulary: &[(String, Vec<TokenId>)],
vocabulary_transition_keys: &[Vec<TransitionKey>],
start_state: State,
) -> HashSet<(TokenId, State)> {
let mut res = HashSet::new();

for (vocab_item, token_transition_keys) in
vocabulary.iter().zip(vocabulary_transition_keys.iter())
{
let token_ids: Vec<u32> = vocab_item.1.clone();
let token_ids: Vec<TokenId> = vocab_item.1.clone();

let state_seq = walk_fsm(
fsm_transitions,
Expand All @@ -74,10 +75,10 @@ pub fn state_scan_tokens(
}

pub fn get_token_transition_keys(
alphabet_symbol_mapping: &HashMap<String, u32>,
alphabet_anything_value: u32,
alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
token_str: &str,
) -> Vec<u32> {
) -> Vec<TransitionKey> {
let mut token_transition_keys = Vec::new();
let mut i = 0;
let chars: Vec<char> = token_str.chars().collect();
Expand Down Expand Up @@ -107,12 +108,12 @@ pub fn get_token_transition_keys(
}

pub fn get_vocabulary_transition_keys(
alphabet_symbol_mapping: &HashMap<String, u32>,
alphabet_anything_value: u32,
vocabulary: &[(String, Vec<u32>)],
alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
vocabulary: &[(String, Vec<TokenId>)],
frozen_tokens: &HashSet<String>,
) -> Vec<Vec<u32>> {
let mut vocab_transition_keys: Vec<Vec<u32>> = Vec::new();
) -> Vec<Vec<TransitionKey>> {
let mut vocab_transition_keys: Vec<Vec<TransitionKey>> = Vec::new();

for item in vocabulary.iter() {
let token_str = item.0.clone();
Expand Down

0 comments on commit e63edab

Please sign in to comment.