diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 81ec6956..6aeec3bf 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -1,8 +1,8 @@ use crate::json_schema; -use crate::regex::get_token_transition_keys_internal; -use crate::regex::get_vocabulary_transition_keys_internal; -use crate::regex::state_scan_tokens_internal; -use crate::regex::walk_fsm_internal; +use crate::regex::get_token_transition_keys; +use crate::regex::get_vocabulary_transition_keys; +use crate::regex::state_scan_tokens; +use crate::regex::walk_fsm; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -46,14 +46,17 @@ impl FSMInfo { #[pyfunction(name = "build_regex_from_schema")] #[pyo3(signature = (json, whitespace_pattern=None))] -pub fn build_regex_from_schema(json: String, whitespace_pattern: Option<&str>) -> PyResult { +pub fn build_regex_from_schema_py( + json: String, + whitespace_pattern: Option<&str>, +) -> PyResult { json_schema::build_regex_from_schema(&json, whitespace_pattern) .map_err(|e| PyValueError::new_err(e.to_string())) } #[pyfunction(name = "to_regex")] #[pyo3(signature = (json, whitespace_pattern=None))] -pub fn to_regex(json: Bound, whitespace_pattern: Option<&str>) -> PyResult { +pub fn to_regex_py(json: Bound, whitespace_pattern: Option<&str>) -> PyResult { let json_value: Value = serde_pyobject::from_pyobject(json).unwrap(); json_schema::to_regex(&json_value, whitespace_pattern, &json_value) .map_err(|e| PyValueError::new_err(e.to_string())) @@ -63,7 +66,7 @@ pub fn to_regex(json: Bound, whitespace_pattern: Option<&str>) -> PyResu #[pyo3( text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)" )] -pub fn _walk_fsm( +pub fn walk_fsm_py( fsm_transitions: HashMap<(u32, u32), u32>, fsm_initial: u32, fsm_finals: HashSet, @@ -71,7 +74,7 @@ pub fn _walk_fsm( start_state: u32, full_match: bool, ) -> PyResult> { - Ok(walk_fsm_internal( + Ok(walk_fsm( &fsm_transitions, fsm_initial, &fsm_finals, @@ -85,7 +88,7 @@ pub fn _walk_fsm( #[pyo3( text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)" )] -pub fn state_scan_tokens( +pub fn state_scan_tokens_py( fsm_transitions: HashMap<(u32, u32), u32>, fsm_initial: u32, fsm_finals: HashSet, @@ -93,7 +96,7 @@ pub fn state_scan_tokens( vocabulary_transition_keys: Vec>, start_state: u32, ) -> PyResult> { - Ok(state_scan_tokens_internal( + Ok(state_scan_tokens( &fsm_transitions, fsm_initial, &fsm_finals, @@ -105,12 +108,12 @@ pub fn state_scan_tokens( #[pyfunction(name = "get_token_transition_keys")] #[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")] -pub fn get_token_transition_keys( +pub fn get_token_transition_keys_py( alphabet_symbol_mapping: HashMap, alphabet_anything_value: u32, token_str: String, ) -> PyResult> { - Ok(get_token_transition_keys_internal( + Ok(get_token_transition_keys( &alphabet_symbol_mapping, alphabet_anything_value, &token_str, @@ -121,13 +124,13 @@ pub fn get_token_transition_keys( #[pyo3( text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)" )] -pub fn get_vocabulary_transition_keys( +pub fn get_vocabulary_transition_keys_py( alphabet_symbol_mapping: HashMap, alphabet_anything_value: u32, vocabulary: Vec<(String, Vec)>, frozen_tokens: HashSet, ) -> PyResult>> { - Ok(get_vocabulary_transition_keys_internal( + Ok(get_vocabulary_transition_keys( &alphabet_symbol_mapping, alphabet_anything_value, &vocabulary, @@ -138,7 +141,7 @@ pub fn get_vocabulary_transition_keys( #[allow(clippy::too_many_arguments)] #[pyfunction(name = "create_fsm_index_end_to_end")] #[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")] -pub fn create_fsm_index_end_to_end<'py>( +pub fn create_fsm_index_end_to_end_py<'py>( py: Python<'py>, fsm_info: &FSMInfo, vocabulary: Vec<(String, Vec)>, @@ -148,7 +151,7 @@ pub fn create_fsm_index_end_to_end<'py>( let mut seen: HashSet = HashSet::new(); let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); - let vocabulary_transition_keys = get_vocabulary_transition_keys_internal( + let vocabulary_transition_keys = get_vocabulary_transition_keys( &fsm_info.alphabet_symbol_mapping, fsm_info.alphabet_anything_value, &vocabulary, @@ -159,7 +162,7 @@ pub fn create_fsm_index_end_to_end<'py>( next_states.remove(&start_state); // TODO: Return Pydict directly at construction - let token_ids_end_states = state_scan_tokens_internal( + let token_ids_end_states = state_scan_tokens( &fsm_info.transitions, fsm_info.initial, &fsm_info.finals, @@ -192,11 +195,11 @@ pub fn create_fsm_index_end_to_end<'py>( #[pymodule] fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_function(wrap_pyfunction!(_walk_fsm, m)?)?; - m.add_function(wrap_pyfunction!(state_scan_tokens, m)?)?; - m.add_function(wrap_pyfunction!(get_token_transition_keys, m)?)?; - m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys, m)?)?; - m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end, m)?)?; + m.add_function(wrap_pyfunction!(walk_fsm_py, m)?)?; + m.add_function(wrap_pyfunction!(state_scan_tokens_py, m)?)?; + m.add_function(wrap_pyfunction!(get_token_transition_keys_py, m)?)?; + m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys_py, m)?)?; + m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end_py, m)?)?; m.add_class::()?; @@ -212,8 +215,8 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add("UUID", json_schema::UUID)?; m.add("WHITESPACE", json_schema::WHITESPACE)?; - m.add_function(wrap_pyfunction!(build_regex_from_schema, m)?)?; - m.add_function(wrap_pyfunction!(to_regex, m)?)?; + m.add_function(wrap_pyfunction!(build_regex_from_schema_py, m)?)?; + m.add_function(wrap_pyfunction!(to_regex_py, m)?)?; Ok(()) } diff --git a/src/regex.rs b/src/regex.rs index 71ef2565..e9cb68d2 100644 --- a/src/regex.rs +++ b/src/regex.rs @@ -1,7 +1,7 @@ use std::collections::{HashMap, HashSet}; #[allow(dead_code)] -pub fn walk_fsm_internal( +pub fn walk_fsm( fsm_transitions: &HashMap<(u32, u32), u32>, _fsm_initial: u32, fsm_finals: &HashSet, @@ -39,7 +39,7 @@ pub fn walk_fsm_internal( } #[allow(dead_code)] -pub fn state_scan_tokens_internal( +pub fn state_scan_tokens( fsm_transitions: &HashMap<(u32, u32), u32>, fsm_initial: u32, fsm_finals: &HashSet, @@ -54,7 +54,7 @@ pub fn state_scan_tokens_internal( { let token_ids: Vec = vocab_item.1.clone(); - let state_seq = walk_fsm_internal( + let state_seq = walk_fsm( fsm_transitions, fsm_initial, fsm_finals, @@ -76,7 +76,7 @@ pub fn state_scan_tokens_internal( } #[allow(dead_code)] -pub fn get_token_transition_keys_internal( +pub fn get_token_transition_keys( alphabet_symbol_mapping: &HashMap, alphabet_anything_value: u32, token_str: &str, @@ -110,7 +110,7 @@ pub fn get_token_transition_keys_internal( } #[allow(dead_code)] -pub fn get_vocabulary_transition_keys_internal( +pub fn get_vocabulary_transition_keys( alphabet_symbol_mapping: &HashMap, alphabet_anything_value: u32, vocabulary: &[(String, Vec)], @@ -133,7 +133,7 @@ pub fn get_vocabulary_transition_keys_internal( .unwrap_or(&alphabet_anything_value), ) } else { - token_transition_keys = get_token_transition_keys_internal( + token_transition_keys = get_token_transition_keys( alphabet_symbol_mapping, alphabet_anything_value, &token_str,