Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use type aliases to improve readability #29

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@ pub mod regex;

#[cfg(feature = "python-bindings")]
mod python_bindings;

mod primitives;

pub use crate::primitives::{State, TokenId, TransitionKey};
8 changes: 8 additions & 0 deletions src/primitives.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/// Interegular transition key.
pub type TransitionKey = u32;

/// Token identifier.
pub type TokenId = u32;

/// Interegular state.
pub type State = u32;
67 changes: 34 additions & 33 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::regex::get_token_transition_keys;
use crate::regex::get_vocabulary_transition_keys;
use crate::regex::state_scan_tokens;
use crate::regex::walk_fsm;
use crate::*;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyDict;
Expand All @@ -13,26 +14,26 @@ use std::collections::{HashMap, HashSet};
#[pyclass]
pub struct FSMInfo {
#[pyo3(get)]
initial: u32,
initial: State,
#[pyo3(get)]
finals: HashSet<u32>,
finals: HashSet<State>,
#[pyo3(get)]
transitions: HashMap<(u32, u32), u32>,
transitions: HashMap<(State, TransitionKey), State>,
#[pyo3(get)]
alphabet_anything_value: u32,
alphabet_anything_value: TransitionKey,
#[pyo3(get)]
alphabet_symbol_mapping: HashMap<String, u32>,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
}

#[pymethods]
impl FSMInfo {
#[new]
fn new(
initial: u32,
finals: HashSet<u32>,
transitions: HashMap<(u32, u32), u32>,
alphabet_anything_value: u32,
alphabet_symbol_mapping: HashMap<String, u32>,
initial: State,
finals: HashSet<State>,
transitions: HashMap<(State, TransitionKey), State>,
alphabet_anything_value: TransitionKey,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
) -> Self {
Self {
initial,
Expand Down Expand Up @@ -67,13 +68,13 @@ pub fn to_regex_py(json: Bound<PyDict>, whitespace_pattern: Option<&str>) -> PyR
text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)"
)]
pub fn walk_fsm_py(
fsm_transitions: HashMap<(u32, u32), u32>,
fsm_initial: u32,
fsm_finals: HashSet<u32>,
token_transition_keys: Vec<u32>,
start_state: u32,
fsm_transitions: HashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: HashSet<State>,
token_transition_keys: Vec<TransitionKey>,
start_state: State,
full_match: bool,
) -> PyResult<Vec<u32>> {
) -> PyResult<Vec<State>> {
Ok(walk_fsm(
&fsm_transitions,
fsm_initial,
Expand All @@ -89,13 +90,13 @@ pub fn walk_fsm_py(
text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)"
)]
pub fn state_scan_tokens_py(
fsm_transitions: HashMap<(u32, u32), u32>,
fsm_initial: u32,
fsm_finals: HashSet<u32>,
vocabulary: Vec<(String, Vec<u32>)>,
vocabulary_transition_keys: Vec<Vec<u32>>,
start_state: u32,
) -> PyResult<HashSet<(u32, u32)>> {
fsm_transitions: HashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: HashSet<State>,
vocabulary: Vec<(String, Vec<TokenId>)>,
vocabulary_transition_keys: Vec<Vec<TransitionKey>>,
start_state: State,
) -> PyResult<HashSet<(TokenId, State)>> {
Ok(state_scan_tokens(
&fsm_transitions,
fsm_initial,
Expand All @@ -109,10 +110,10 @@ pub fn state_scan_tokens_py(
#[pyfunction(name = "get_token_transition_keys")]
#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")]
pub fn get_token_transition_keys_py(
alphabet_symbol_mapping: HashMap<String, u32>,
alphabet_anything_value: u32,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
token_str: String,
) -> PyResult<Vec<u32>> {
) -> PyResult<Vec<TransitionKey>> {
Ok(get_token_transition_keys(
&alphabet_symbol_mapping,
alphabet_anything_value,
Expand All @@ -125,11 +126,11 @@ pub fn get_token_transition_keys_py(
text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)"
)]
pub fn get_vocabulary_transition_keys_py(
alphabet_symbol_mapping: HashMap<String, u32>,
alphabet_anything_value: u32,
vocabulary: Vec<(String, Vec<u32>)>,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
vocabulary: Vec<(String, Vec<TokenId>)>,
frozen_tokens: HashSet<String>,
) -> PyResult<Vec<Vec<u32>>> {
) -> PyResult<Vec<Vec<TransitionKey>>> {
Ok(get_vocabulary_transition_keys(
&alphabet_symbol_mapping,
alphabet_anything_value,
Expand All @@ -143,12 +144,12 @@ pub fn get_vocabulary_transition_keys_py(
pub fn create_fsm_index_end_to_end_py<'py>(
py: Python<'py>,
fsm_info: &FSMInfo,
vocabulary: Vec<(String, Vec<u32>)>,
vocabulary: Vec<(String, Vec<TokenId>)>,
frozen_tokens: HashSet<String>,
) -> PyResult<Bound<'py, PyDict>> {
let states_to_token_subsets = PyDict::new_bound(py);
let mut seen: HashSet<u32> = HashSet::new();
let mut next_states: HashSet<u32> = HashSet::from_iter(vec![fsm_info.initial]);
let mut seen: HashSet<State> = HashSet::new();
let mut next_states: HashSet<State> = HashSet::from_iter(vec![fsm_info.initial]);

let vocabulary_transition_keys = get_vocabulary_transition_keys(
&fsm_info.alphabet_symbol_mapping,
Expand Down
45 changes: 23 additions & 22 deletions src/regex.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use crate::*;
use std::collections::{HashMap, HashSet};

pub fn walk_fsm(
fsm_transitions: &HashMap<(u32, u32), u32>,
_fsm_initial: u32,
fsm_finals: &HashSet<u32>,
token_transition_keys: &[u32],
start_state: u32,
fsm_transitions: &HashMap<(State, TransitionKey), State>,
_fsm_initial: State,
fsm_finals: &HashSet<State>,
token_transition_keys: &[TransitionKey],
start_state: State,
full_match: bool,
) -> Vec<u32> {
) -> Vec<State> {
let mut state = start_state;
let mut accepted_states = Vec::new();
let mut last_final_idx = 0;
Expand Down Expand Up @@ -38,19 +39,19 @@ pub fn walk_fsm(
}

pub fn state_scan_tokens(
fsm_transitions: &HashMap<(u32, u32), u32>,
fsm_initial: u32,
fsm_finals: &HashSet<u32>,
vocabulary: &[(String, Vec<u32>)],
vocabulary_transition_keys: &[Vec<u32>],
start_state: u32,
) -> HashSet<(u32, u32)> {
fsm_transitions: &HashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: &HashSet<State>,
vocabulary: &[(String, Vec<TokenId>)],
vocabulary_transition_keys: &[Vec<TransitionKey>],
start_state: State,
) -> HashSet<(TokenId, State)> {
let mut res = HashSet::new();

for (vocab_item, token_transition_keys) in
vocabulary.iter().zip(vocabulary_transition_keys.iter())
{
let token_ids: Vec<u32> = vocab_item.1.clone();
let token_ids: Vec<TokenId> = vocab_item.1.clone();

let state_seq = walk_fsm(
fsm_transitions,
Expand All @@ -74,10 +75,10 @@ pub fn state_scan_tokens(
}

pub fn get_token_transition_keys(
alphabet_symbol_mapping: &HashMap<String, u32>,
alphabet_anything_value: u32,
alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
token_str: &str,
) -> Vec<u32> {
) -> Vec<TransitionKey> {
let mut token_transition_keys = Vec::new();
let mut i = 0;
let chars: Vec<char> = token_str.chars().collect();
Expand Down Expand Up @@ -107,12 +108,12 @@ pub fn get_token_transition_keys(
}

pub fn get_vocabulary_transition_keys(
alphabet_symbol_mapping: &HashMap<String, u32>,
alphabet_anything_value: u32,
vocabulary: &[(String, Vec<u32>)],
alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
vocabulary: &[(String, Vec<TokenId>)],
frozen_tokens: &HashSet<String>,
) -> Vec<Vec<u32>> {
let mut vocab_transition_keys: Vec<Vec<u32>> = Vec::new();
) -> Vec<Vec<TransitionKey>> {
let mut vocab_transition_keys: Vec<Vec<TransitionKey>> = Vec::new();

for item in vocabulary.iter() {
let token_str = item.0.clone();
Expand Down
Loading