From bbce7d7850f46071be5f3aacf889081e4a0730e0 Mon Sep 17 00:00:00 2001 From: kc611 Date: Thu, 22 Aug 2024 14:00:27 +0530 Subject: [PATCH 1/3] Make PyO3 bindings optional --- Cargo.toml | 7 +- setup.py | 1 + src/json_schema/types.rs | 1 + src/lib.rs | 60 +--------- src/python_bindings/mod.rs | 219 +++++++++++++++++++++++++++++++++++++ src/regex.rs | 172 +---------------------------- 6 files changed, 234 insertions(+), 226 deletions(-) create mode 100644 src/python_bindings/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 05160120..05fec5a8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,11 +5,11 @@ edition = "2021" [lib] name = "outlines_core_rs" -crate-type = ["cdylib"] +crate-type = ["cdylib", "rlib"] [dependencies] anyhow = "1.0.86" -pyo3 = { version = "0.22.0", features = ["extension-module"] } +pyo3 = { version = "0.22.0", features = ["extension-module"], optional=true } regex = "1.10.6" serde-pyobject = "0.4.0" serde_json = { version ="1.0.125", features = ["preserve_order"] } @@ -20,3 +20,6 @@ lto = true codegen-units = 1 strip = true panic = 'abort' + +[features] +python-bindings = ["pyo3"] diff --git a/setup.py b/setup.py index 4c414e6b..d19e0715 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,7 @@ "outlines_core.fsm.outlines_core_rs", f"{CURRENT_DIR}/Cargo.toml", binding=Binding.PyO3, + features=["python-bindings"], rustc_flags=["--crate-type=cdylib"], ), ] diff --git a/src/json_schema/types.rs b/src/json_schema/types.rs index 02c748f6..aff5e53b 100644 --- a/src/json_schema/types.rs +++ b/src/json_schema/types.rs @@ -53,6 +53,7 @@ impl FormatType { } } + #[allow(clippy::should_implement_trait)] pub fn from_str(s: &str) -> Option { match s { "date-time" => Some(FormatType::DateTime), diff --git a/src/lib.rs b/src/lib.rs index bd129f6d..0bc900d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,57 +1,5 @@ -mod json_schema; -mod regex; +pub mod json_schema; +pub mod regex; -use pyo3::exceptions::PyValueError; -use pyo3::prelude::*; -use pyo3::types::PyDict; -use pyo3::wrap_pyfunction; -use regex::_walk_fsm; -use regex::create_fsm_index_end_to_end; -use regex::get_token_transition_keys; -use regex::get_vocabulary_transition_keys; -use regex::state_scan_tokens; -use regex::FSMInfo; -use serde_json::Value; - -#[pymodule] -fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_function(wrap_pyfunction!(_walk_fsm, m)?)?; - m.add_function(wrap_pyfunction!(state_scan_tokens, m)?)?; - m.add_function(wrap_pyfunction!(get_token_transition_keys, m)?)?; - m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys, m)?)?; - m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end, m)?)?; - - m.add_class::()?; - - m.add("BOOLEAN", json_schema::BOOLEAN)?; - m.add("DATE", json_schema::DATE)?; - m.add("DATE_TIME", json_schema::DATE_TIME)?; - m.add("INTEGER", json_schema::INTEGER)?; - m.add("NULL", json_schema::NULL)?; - m.add("NUMBER", json_schema::NUMBER)?; - m.add("STRING", json_schema::STRING)?; - m.add("STRING_INNER", json_schema::STRING_INNER)?; - m.add("TIME", json_schema::TIME)?; - m.add("UUID", json_schema::UUID)?; - m.add("WHITESPACE", json_schema::WHITESPACE)?; - - m.add_function(wrap_pyfunction!(build_regex_from_schema, m)?)?; - m.add_function(wrap_pyfunction!(to_regex, m)?)?; - - Ok(()) -} - -#[pyfunction(name = "build_regex_from_schema")] -#[pyo3(signature = (json, whitespace_pattern=None))] -pub fn build_regex_from_schema(json: String, whitespace_pattern: Option<&str>) -> PyResult { - json_schema::build_regex_from_schema(&json, whitespace_pattern) - .map_err(|e| PyValueError::new_err(e.to_string())) -} - -#[pyfunction(name = "to_regex")] -#[pyo3(signature = (json, whitespace_pattern=None))] -pub fn to_regex(json: Bound, whitespace_pattern: Option<&str>) -> PyResult { - let json_value: Value = serde_pyobject::from_pyobject(json).unwrap(); - json_schema::to_regex(&json_value, whitespace_pattern, &json_value) - .map_err(|e| PyValueError::new_err(e.to_string())) -} +#[cfg(feature = "python-bindings")] +mod python_bindings; diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs new file mode 100644 index 00000000..81ec6956 --- /dev/null +++ b/src/python_bindings/mod.rs @@ -0,0 +1,219 @@ +use crate::json_schema; +use crate::regex::get_token_transition_keys_internal; +use crate::regex::get_vocabulary_transition_keys_internal; +use crate::regex::state_scan_tokens_internal; +use crate::regex::walk_fsm_internal; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::PyDict; +use pyo3::wrap_pyfunction; +use serde_json::Value; +use std::collections::{HashMap, HashSet}; + +#[pyclass] +pub struct FSMInfo { + #[pyo3(get)] + initial: u32, + #[pyo3(get)] + finals: HashSet, + #[pyo3(get)] + transitions: HashMap<(u32, u32), u32>, + #[pyo3(get)] + alphabet_anything_value: u32, + #[pyo3(get)] + alphabet_symbol_mapping: HashMap, +} + +#[pymethods] +impl FSMInfo { + #[new] + fn new( + initial: u32, + finals: HashSet, + transitions: HashMap<(u32, u32), u32>, + alphabet_anything_value: u32, + alphabet_symbol_mapping: HashMap, + ) -> Self { + Self { + initial, + finals, + transitions, + alphabet_anything_value, + alphabet_symbol_mapping, + } + } +} + +#[pyfunction(name = "build_regex_from_schema")] +#[pyo3(signature = (json, whitespace_pattern=None))] +pub fn build_regex_from_schema(json: String, whitespace_pattern: Option<&str>) -> PyResult { + json_schema::build_regex_from_schema(&json, whitespace_pattern) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +#[pyfunction(name = "to_regex")] +#[pyo3(signature = (json, whitespace_pattern=None))] +pub fn to_regex(json: Bound, whitespace_pattern: Option<&str>) -> PyResult { + let json_value: Value = serde_pyobject::from_pyobject(json).unwrap(); + json_schema::to_regex(&json_value, whitespace_pattern, &json_value) + .map_err(|e| PyValueError::new_err(e.to_string())) +} + +#[pyfunction(name = "_walk_fsm")] +#[pyo3( + text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)" +)] +pub fn _walk_fsm( + fsm_transitions: HashMap<(u32, u32), u32>, + fsm_initial: u32, + fsm_finals: HashSet, + token_transition_keys: Vec, + start_state: u32, + full_match: bool, +) -> PyResult> { + Ok(walk_fsm_internal( + &fsm_transitions, + fsm_initial, + &fsm_finals, + &token_transition_keys, + start_state, + full_match, + )) +} + +#[pyfunction(name = "state_scan_tokens")] +#[pyo3( + text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)" +)] +pub fn state_scan_tokens( + fsm_transitions: HashMap<(u32, u32), u32>, + fsm_initial: u32, + fsm_finals: HashSet, + vocabulary: Vec<(String, Vec)>, + vocabulary_transition_keys: Vec>, + start_state: u32, +) -> PyResult> { + Ok(state_scan_tokens_internal( + &fsm_transitions, + fsm_initial, + &fsm_finals, + &vocabulary, + &vocabulary_transition_keys, + start_state, + )) +} + +#[pyfunction(name = "get_token_transition_keys")] +#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")] +pub fn get_token_transition_keys( + alphabet_symbol_mapping: HashMap, + alphabet_anything_value: u32, + token_str: String, +) -> PyResult> { + Ok(get_token_transition_keys_internal( + &alphabet_symbol_mapping, + alphabet_anything_value, + &token_str, + )) +} + +#[pyfunction(name = "get_vocabulary_transition_keys")] +#[pyo3( + text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)" +)] +pub fn get_vocabulary_transition_keys( + alphabet_symbol_mapping: HashMap, + alphabet_anything_value: u32, + vocabulary: Vec<(String, Vec)>, + frozen_tokens: HashSet, +) -> PyResult>> { + Ok(get_vocabulary_transition_keys_internal( + &alphabet_symbol_mapping, + alphabet_anything_value, + &vocabulary, + &frozen_tokens, + )) +} + +#[allow(clippy::too_many_arguments)] +#[pyfunction(name = "create_fsm_index_end_to_end")] +#[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")] +pub fn create_fsm_index_end_to_end<'py>( + py: Python<'py>, + fsm_info: &FSMInfo, + vocabulary: Vec<(String, Vec)>, + frozen_tokens: HashSet, +) -> PyResult> { + let states_to_token_subsets = PyDict::new_bound(py); + let mut seen: HashSet = HashSet::new(); + let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); + + let vocabulary_transition_keys = get_vocabulary_transition_keys_internal( + &fsm_info.alphabet_symbol_mapping, + fsm_info.alphabet_anything_value, + &vocabulary, + &frozen_tokens, + ); + + while let Some(start_state) = next_states.iter().cloned().next() { + next_states.remove(&start_state); + + // TODO: Return Pydict directly at construction + let token_ids_end_states = state_scan_tokens_internal( + &fsm_info.transitions, + fsm_info.initial, + &fsm_info.finals, + &vocabulary, + &vocabulary_transition_keys, + start_state, + ); + + for (token_id, end_state) in token_ids_end_states { + if let Ok(Some(existing_dict)) = states_to_token_subsets.get_item(start_state) { + existing_dict.set_item(token_id, end_state).unwrap(); + } else { + let new_dict = PyDict::new_bound(py); + new_dict.set_item(token_id, end_state).unwrap(); + states_to_token_subsets + .set_item(start_state, new_dict) + .unwrap(); + } + + if !seen.contains(&end_state) { + next_states.insert(end_state); + } + } + + seen.insert(start_state); + } + + Ok(states_to_token_subsets) +} + +#[pymodule] +fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(_walk_fsm, m)?)?; + m.add_function(wrap_pyfunction!(state_scan_tokens, m)?)?; + m.add_function(wrap_pyfunction!(get_token_transition_keys, m)?)?; + m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys, m)?)?; + m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end, m)?)?; + + m.add_class::()?; + + m.add("BOOLEAN", json_schema::BOOLEAN)?; + m.add("DATE", json_schema::DATE)?; + m.add("DATE_TIME", json_schema::DATE_TIME)?; + m.add("INTEGER", json_schema::INTEGER)?; + m.add("NULL", json_schema::NULL)?; + m.add("NUMBER", json_schema::NUMBER)?; + m.add("STRING", json_schema::STRING)?; + m.add("STRING_INNER", json_schema::STRING_INNER)?; + m.add("TIME", json_schema::TIME)?; + m.add("UUID", json_schema::UUID)?; + m.add("WHITESPACE", json_schema::WHITESPACE)?; + + m.add_function(wrap_pyfunction!(build_regex_from_schema, m)?)?; + m.add_function(wrap_pyfunction!(to_regex, m)?)?; + + Ok(()) +} diff --git a/src/regex.rs b/src/regex.rs index df7d36f6..71ef2565 100644 --- a/src/regex.rs +++ b/src/regex.rs @@ -1,8 +1,6 @@ -use pyo3::prelude::*; - -use pyo3::types::PyDict; use std::collections::{HashMap, HashSet}; +#[allow(dead_code)] pub fn walk_fsm_internal( fsm_transitions: &HashMap<(u32, u32), u32>, _fsm_initial: u32, @@ -40,6 +38,7 @@ pub fn walk_fsm_internal( accepted_states } +#[allow(dead_code)] pub fn state_scan_tokens_internal( fsm_transitions: &HashMap<(u32, u32), u32>, fsm_initial: u32, @@ -76,6 +75,7 @@ pub fn state_scan_tokens_internal( res } +#[allow(dead_code)] pub fn get_token_transition_keys_internal( alphabet_symbol_mapping: &HashMap, alphabet_anything_value: u32, @@ -109,6 +109,7 @@ pub fn get_token_transition_keys_internal( token_transition_keys } +#[allow(dead_code)] pub fn get_vocabulary_transition_keys_internal( alphabet_symbol_mapping: &HashMap, alphabet_anything_value: u32, @@ -144,168 +145,3 @@ pub fn get_vocabulary_transition_keys_internal( vocab_transition_keys } - -#[pyclass] -pub struct FSMInfo { - #[pyo3(get)] - initial: u32, - #[pyo3(get)] - finals: HashSet, - #[pyo3(get)] - transitions: HashMap<(u32, u32), u32>, - #[pyo3(get)] - alphabet_anything_value: u32, - #[pyo3(get)] - alphabet_symbol_mapping: HashMap, -} - -#[pymethods] -impl FSMInfo { - #[new] - fn new( - initial: u32, - finals: HashSet, - transitions: HashMap<(u32, u32), u32>, - alphabet_anything_value: u32, - alphabet_symbol_mapping: HashMap, - ) -> Self { - Self { - initial, - finals, - transitions, - alphabet_anything_value, - alphabet_symbol_mapping, - } - } -} - -#[pyfunction(name = "_walk_fsm")] -#[pyo3( - text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)" -)] -pub fn _walk_fsm( - fsm_transitions: HashMap<(u32, u32), u32>, - fsm_initial: u32, - fsm_finals: HashSet, - token_transition_keys: Vec, - start_state: u32, - full_match: bool, -) -> PyResult> { - Ok(walk_fsm_internal( - &fsm_transitions, - fsm_initial, - &fsm_finals, - &token_transition_keys, - start_state, - full_match, - )) -} - -#[pyfunction(name = "state_scan_tokens")] -#[pyo3( - text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)" -)] -pub fn state_scan_tokens( - fsm_transitions: HashMap<(u32, u32), u32>, - fsm_initial: u32, - fsm_finals: HashSet, - vocabulary: Vec<(String, Vec)>, - vocabulary_transition_keys: Vec>, - start_state: u32, -) -> PyResult> { - Ok(state_scan_tokens_internal( - &fsm_transitions, - fsm_initial, - &fsm_finals, - &vocabulary, - &vocabulary_transition_keys, - start_state, - )) -} - -#[pyfunction(name = "get_token_transition_keys")] -#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")] -pub fn get_token_transition_keys( - alphabet_symbol_mapping: HashMap, - alphabet_anything_value: u32, - token_str: String, -) -> PyResult> { - Ok(get_token_transition_keys_internal( - &alphabet_symbol_mapping, - alphabet_anything_value, - &token_str, - )) -} - -#[pyfunction(name = "get_vocabulary_transition_keys")] -#[pyo3( - text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)" -)] -pub fn get_vocabulary_transition_keys( - alphabet_symbol_mapping: HashMap, - alphabet_anything_value: u32, - vocabulary: Vec<(String, Vec)>, - frozen_tokens: HashSet, -) -> PyResult>> { - Ok(get_vocabulary_transition_keys_internal( - &alphabet_symbol_mapping, - alphabet_anything_value, - &vocabulary, - &frozen_tokens, - )) -} - -#[allow(clippy::too_many_arguments)] -#[pyfunction(name = "create_fsm_index_end_to_end")] -#[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")] -pub fn create_fsm_index_end_to_end<'py>( - py: Python<'py>, - fsm_info: &FSMInfo, - vocabulary: Vec<(String, Vec)>, - frozen_tokens: HashSet, -) -> PyResult> { - let states_to_token_subsets = PyDict::new_bound(py); - let mut seen: HashSet = HashSet::new(); - let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); - - let vocabulary_transition_keys = get_vocabulary_transition_keys_internal( - &fsm_info.alphabet_symbol_mapping, - fsm_info.alphabet_anything_value, - &vocabulary, - &frozen_tokens, - ); - - while let Some(start_state) = next_states.iter().cloned().next() { - next_states.remove(&start_state); - - // TODO: Return Pydict directly at construction - let token_ids_end_states = state_scan_tokens_internal( - &fsm_info.transitions, - fsm_info.initial, - &fsm_info.finals, - &vocabulary, - &vocabulary_transition_keys, - start_state, - ); - - for (token_id, end_state) in token_ids_end_states { - if let Ok(Some(existing_dict)) = states_to_token_subsets.get_item(start_state) { - existing_dict.set_item(token_id, end_state).unwrap(); - } else { - let new_dict = PyDict::new_bound(py); - new_dict.set_item(token_id, end_state).unwrap(); - states_to_token_subsets - .set_item(start_state, new_dict) - .unwrap(); - } - - if !seen.contains(&end_state) { - next_states.insert(end_state); - } - } - - seen.insert(start_state); - } - - Ok(states_to_token_subsets) -} From c54983b618ab4190e63b28adf2b235224a9a9c5a Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 29 Aug 2024 12:32:55 -0500 Subject: [PATCH 2/3] Use _py suffix naming for Python bindings --- src/python_bindings/mod.rs | 51 ++++++++++++++++++++------------------ src/regex.rs | 12 ++++----- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 81ec6956..6aeec3bf 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -1,8 +1,8 @@ use crate::json_schema; -use crate::regex::get_token_transition_keys_internal; -use crate::regex::get_vocabulary_transition_keys_internal; -use crate::regex::state_scan_tokens_internal; -use crate::regex::walk_fsm_internal; +use crate::regex::get_token_transition_keys; +use crate::regex::get_vocabulary_transition_keys; +use crate::regex::state_scan_tokens; +use crate::regex::walk_fsm; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::PyDict; @@ -46,14 +46,17 @@ impl FSMInfo { #[pyfunction(name = "build_regex_from_schema")] #[pyo3(signature = (json, whitespace_pattern=None))] -pub fn build_regex_from_schema(json: String, whitespace_pattern: Option<&str>) -> PyResult { +pub fn build_regex_from_schema_py( + json: String, + whitespace_pattern: Option<&str>, +) -> PyResult { json_schema::build_regex_from_schema(&json, whitespace_pattern) .map_err(|e| PyValueError::new_err(e.to_string())) } #[pyfunction(name = "to_regex")] #[pyo3(signature = (json, whitespace_pattern=None))] -pub fn to_regex(json: Bound, whitespace_pattern: Option<&str>) -> PyResult { +pub fn to_regex_py(json: Bound, whitespace_pattern: Option<&str>) -> PyResult { let json_value: Value = serde_pyobject::from_pyobject(json).unwrap(); json_schema::to_regex(&json_value, whitespace_pattern, &json_value) .map_err(|e| PyValueError::new_err(e.to_string())) @@ -63,7 +66,7 @@ pub fn to_regex(json: Bound, whitespace_pattern: Option<&str>) -> PyResu #[pyo3( text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)" )] -pub fn _walk_fsm( +pub fn walk_fsm_py( fsm_transitions: HashMap<(u32, u32), u32>, fsm_initial: u32, fsm_finals: HashSet, @@ -71,7 +74,7 @@ pub fn _walk_fsm( start_state: u32, full_match: bool, ) -> PyResult> { - Ok(walk_fsm_internal( + Ok(walk_fsm( &fsm_transitions, fsm_initial, &fsm_finals, @@ -85,7 +88,7 @@ pub fn _walk_fsm( #[pyo3( text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)" )] -pub fn state_scan_tokens( +pub fn state_scan_tokens_py( fsm_transitions: HashMap<(u32, u32), u32>, fsm_initial: u32, fsm_finals: HashSet, @@ -93,7 +96,7 @@ pub fn state_scan_tokens( vocabulary_transition_keys: Vec>, start_state: u32, ) -> PyResult> { - Ok(state_scan_tokens_internal( + Ok(state_scan_tokens( &fsm_transitions, fsm_initial, &fsm_finals, @@ -105,12 +108,12 @@ pub fn state_scan_tokens( #[pyfunction(name = "get_token_transition_keys")] #[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")] -pub fn get_token_transition_keys( +pub fn get_token_transition_keys_py( alphabet_symbol_mapping: HashMap, alphabet_anything_value: u32, token_str: String, ) -> PyResult> { - Ok(get_token_transition_keys_internal( + Ok(get_token_transition_keys( &alphabet_symbol_mapping, alphabet_anything_value, &token_str, @@ -121,13 +124,13 @@ pub fn get_token_transition_keys( #[pyo3( text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)" )] -pub fn get_vocabulary_transition_keys( +pub fn get_vocabulary_transition_keys_py( alphabet_symbol_mapping: HashMap, alphabet_anything_value: u32, vocabulary: Vec<(String, Vec)>, frozen_tokens: HashSet, ) -> PyResult>> { - Ok(get_vocabulary_transition_keys_internal( + Ok(get_vocabulary_transition_keys( &alphabet_symbol_mapping, alphabet_anything_value, &vocabulary, @@ -138,7 +141,7 @@ pub fn get_vocabulary_transition_keys( #[allow(clippy::too_many_arguments)] #[pyfunction(name = "create_fsm_index_end_to_end")] #[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")] -pub fn create_fsm_index_end_to_end<'py>( +pub fn create_fsm_index_end_to_end_py<'py>( py: Python<'py>, fsm_info: &FSMInfo, vocabulary: Vec<(String, Vec)>, @@ -148,7 +151,7 @@ pub fn create_fsm_index_end_to_end<'py>( let mut seen: HashSet = HashSet::new(); let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); - let vocabulary_transition_keys = get_vocabulary_transition_keys_internal( + let vocabulary_transition_keys = get_vocabulary_transition_keys( &fsm_info.alphabet_symbol_mapping, fsm_info.alphabet_anything_value, &vocabulary, @@ -159,7 +162,7 @@ pub fn create_fsm_index_end_to_end<'py>( next_states.remove(&start_state); // TODO: Return Pydict directly at construction - let token_ids_end_states = state_scan_tokens_internal( + let token_ids_end_states = state_scan_tokens( &fsm_info.transitions, fsm_info.initial, &fsm_info.finals, @@ -192,11 +195,11 @@ pub fn create_fsm_index_end_to_end<'py>( #[pymodule] fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_function(wrap_pyfunction!(_walk_fsm, m)?)?; - m.add_function(wrap_pyfunction!(state_scan_tokens, m)?)?; - m.add_function(wrap_pyfunction!(get_token_transition_keys, m)?)?; - m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys, m)?)?; - m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end, m)?)?; + m.add_function(wrap_pyfunction!(walk_fsm_py, m)?)?; + m.add_function(wrap_pyfunction!(state_scan_tokens_py, m)?)?; + m.add_function(wrap_pyfunction!(get_token_transition_keys_py, m)?)?; + m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys_py, m)?)?; + m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end_py, m)?)?; m.add_class::()?; @@ -212,8 +215,8 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add("UUID", json_schema::UUID)?; m.add("WHITESPACE", json_schema::WHITESPACE)?; - m.add_function(wrap_pyfunction!(build_regex_from_schema, m)?)?; - m.add_function(wrap_pyfunction!(to_regex, m)?)?; + m.add_function(wrap_pyfunction!(build_regex_from_schema_py, m)?)?; + m.add_function(wrap_pyfunction!(to_regex_py, m)?)?; Ok(()) } diff --git a/src/regex.rs b/src/regex.rs index 71ef2565..e9cb68d2 100644 --- a/src/regex.rs +++ b/src/regex.rs @@ -1,7 +1,7 @@ use std::collections::{HashMap, HashSet}; #[allow(dead_code)] -pub fn walk_fsm_internal( +pub fn walk_fsm( fsm_transitions: &HashMap<(u32, u32), u32>, _fsm_initial: u32, fsm_finals: &HashSet, @@ -39,7 +39,7 @@ pub fn walk_fsm_internal( } #[allow(dead_code)] -pub fn state_scan_tokens_internal( +pub fn state_scan_tokens( fsm_transitions: &HashMap<(u32, u32), u32>, fsm_initial: u32, fsm_finals: &HashSet, @@ -54,7 +54,7 @@ pub fn state_scan_tokens_internal( { let token_ids: Vec = vocab_item.1.clone(); - let state_seq = walk_fsm_internal( + let state_seq = walk_fsm( fsm_transitions, fsm_initial, fsm_finals, @@ -76,7 +76,7 @@ pub fn state_scan_tokens_internal( } #[allow(dead_code)] -pub fn get_token_transition_keys_internal( +pub fn get_token_transition_keys( alphabet_symbol_mapping: &HashMap, alphabet_anything_value: u32, token_str: &str, @@ -110,7 +110,7 @@ pub fn get_token_transition_keys_internal( } #[allow(dead_code)] -pub fn get_vocabulary_transition_keys_internal( +pub fn get_vocabulary_transition_keys( alphabet_symbol_mapping: &HashMap, alphabet_anything_value: u32, vocabulary: &[(String, Vec)], @@ -133,7 +133,7 @@ pub fn get_vocabulary_transition_keys_internal( .unwrap_or(&alphabet_anything_value), ) } else { - token_transition_keys = get_token_transition_keys_internal( + token_transition_keys = get_token_transition_keys( alphabet_symbol_mapping, alphabet_anything_value, &token_str, From b8dfa3b135a79fda879d81dae8f763cf376fb09a Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 29 Aug 2024 12:37:15 -0500 Subject: [PATCH 3/3] Remove unnecessary linter allowances --- src/python_bindings/mod.rs | 1 - src/regex.rs | 4 ---- 2 files changed, 5 deletions(-) diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 6aeec3bf..22bebe62 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -138,7 +138,6 @@ pub fn get_vocabulary_transition_keys_py( )) } -#[allow(clippy::too_many_arguments)] #[pyfunction(name = "create_fsm_index_end_to_end")] #[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")] pub fn create_fsm_index_end_to_end_py<'py>( diff --git a/src/regex.rs b/src/regex.rs index e9cb68d2..1db920ac 100644 --- a/src/regex.rs +++ b/src/regex.rs @@ -1,6 +1,5 @@ use std::collections::{HashMap, HashSet}; -#[allow(dead_code)] pub fn walk_fsm( fsm_transitions: &HashMap<(u32, u32), u32>, _fsm_initial: u32, @@ -38,7 +37,6 @@ pub fn walk_fsm( accepted_states } -#[allow(dead_code)] pub fn state_scan_tokens( fsm_transitions: &HashMap<(u32, u32), u32>, fsm_initial: u32, @@ -75,7 +73,6 @@ pub fn state_scan_tokens( res } -#[allow(dead_code)] pub fn get_token_transition_keys( alphabet_symbol_mapping: &HashMap, alphabet_anything_value: u32, @@ -109,7 +106,6 @@ pub fn get_token_transition_keys( token_transition_keys } -#[allow(dead_code)] pub fn get_vocabulary_transition_keys( alphabet_symbol_mapping: &HashMap, alphabet_anything_value: u32,