diff --git a/Cargo.toml b/Cargo.toml index d23e55f2..d8af9eeb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ name = "outlines_core_rs" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.22.0", features = ["extension-module"] } +pyo3 = { version = "0.22.0", features = ["extension-module"], optional=true } [profile.release] opt-level = 3 @@ -16,3 +16,6 @@ lto = true codegen-units = 1 strip = true panic = 'abort' + +[features] +python-bindings = ["pyo3"] diff --git a/setup.py b/setup.py index 4c414e6b..d19e0715 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,7 @@ "outlines_core.fsm.outlines_core_rs", f"{CURRENT_DIR}/Cargo.toml", binding=Binding.PyO3, + features=["python-bindings"], rustc_flags=["--crate-type=cdylib"], ), ] diff --git a/src/lib.rs b/src/lib.rs index 534b0bb7..b3d73dd2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,23 +1,4 @@ mod regex; -use pyo3::prelude::*; -use pyo3::wrap_pyfunction; -use regex::_walk_fsm; -use regex::create_fsm_index_end_to_end; -use regex::get_token_transition_keys; -use regex::get_vocabulary_transition_keys; -use regex::state_scan_tokens; -use regex::FSMInfo; - -#[pymodule] -fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_function(wrap_pyfunction!(_walk_fsm, m)?)?; - m.add_function(wrap_pyfunction!(state_scan_tokens, m)?)?; - m.add_function(wrap_pyfunction!(get_token_transition_keys, m)?)?; - m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys, m)?)?; - m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end, m)?)?; - - m.add_class::()?; - - Ok(()) -} +#[cfg(feature = "python-bindings")] +mod python_bindings; diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs new file mode 100644 index 00000000..ab65e766 --- /dev/null +++ b/src/python_bindings/mod.rs @@ -0,0 +1,186 @@ +use crate::regex::get_token_transition_keys_internal; +use crate::regex::get_vocabulary_transition_keys_internal; +use crate::regex::state_scan_tokens_internal; +use crate::regex::walk_fsm_internal; +use pyo3::prelude::*; +use pyo3::types::PyDict; +use pyo3::wrap_pyfunction; +use std::collections::{HashMap, HashSet}; + +#[pyclass] +pub struct FSMInfo { + #[pyo3(get)] + initial: u32, + #[pyo3(get)] + finals: HashSet, + #[pyo3(get)] + transitions: HashMap<(u32, u32), u32>, + #[pyo3(get)] + alphabet_anything_value: u32, + #[pyo3(get)] + alphabet_symbol_mapping: HashMap, +} + +#[pymethods] +impl FSMInfo { + #[new] + fn new( + initial: u32, + finals: HashSet, + transitions: HashMap<(u32, u32), u32>, + alphabet_anything_value: u32, + alphabet_symbol_mapping: HashMap, + ) -> Self { + Self { + initial, + finals, + transitions, + alphabet_anything_value, + alphabet_symbol_mapping, + } + } +} + +#[pyfunction(name = "_walk_fsm")] +#[pyo3( + text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)" +)] +pub fn _walk_fsm( + fsm_transitions: HashMap<(u32, u32), u32>, + fsm_initial: u32, + fsm_finals: HashSet, + token_transition_keys: Vec, + start_state: u32, + full_match: bool, +) -> PyResult> { + Ok(walk_fsm_internal( + &fsm_transitions, + fsm_initial, + &fsm_finals, + &token_transition_keys, + start_state, + full_match, + )) +} + +#[pyfunction(name = "state_scan_tokens")] +#[pyo3( + text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)" +)] +pub fn state_scan_tokens( + fsm_transitions: HashMap<(u32, u32), u32>, + fsm_initial: u32, + fsm_finals: HashSet, + vocabulary: Vec<(String, Vec)>, + vocabulary_transition_keys: Vec>, + start_state: u32, +) -> PyResult> { + Ok(state_scan_tokens_internal( + &fsm_transitions, + fsm_initial, + &fsm_finals, + &vocabulary, + &vocabulary_transition_keys, + start_state, + )) +} + +#[pyfunction(name = "get_token_transition_keys")] +#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")] +pub fn get_token_transition_keys( + alphabet_symbol_mapping: HashMap, + alphabet_anything_value: u32, + token_str: String, +) -> PyResult> { + Ok(get_token_transition_keys_internal( + &alphabet_symbol_mapping, + alphabet_anything_value, + &token_str, + )) +} + +#[pyfunction(name = "get_vocabulary_transition_keys")] +#[pyo3( + text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)" +)] +pub fn get_vocabulary_transition_keys( + alphabet_symbol_mapping: HashMap, + alphabet_anything_value: u32, + vocabulary: Vec<(String, Vec)>, + frozen_tokens: HashSet, +) -> PyResult>> { + Ok(get_vocabulary_transition_keys_internal( + &alphabet_symbol_mapping, + alphabet_anything_value, + &vocabulary, + &frozen_tokens, + )) +} + +#[allow(clippy::too_many_arguments)] +#[pyfunction(name = "create_fsm_index_end_to_end")] +#[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")] +pub fn create_fsm_index_end_to_end<'py>( + py: Python<'py>, + fsm_info: &FSMInfo, + vocabulary: Vec<(String, Vec)>, + frozen_tokens: HashSet, +) -> PyResult> { + let states_to_token_subsets = PyDict::new_bound(py); + let mut seen: HashSet = HashSet::new(); + let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); + + let vocabulary_transition_keys = get_vocabulary_transition_keys_internal( + &fsm_info.alphabet_symbol_mapping, + fsm_info.alphabet_anything_value, + &vocabulary, + &frozen_tokens, + ); + + while let Some(start_state) = next_states.iter().cloned().next() { + next_states.remove(&start_state); + + // TODO: Return Pydict directly at construction + let token_ids_end_states = state_scan_tokens_internal( + &fsm_info.transitions, + fsm_info.initial, + &fsm_info.finals, + &vocabulary, + &vocabulary_transition_keys, + start_state, + ); + + for (token_id, end_state) in token_ids_end_states { + if let Ok(Some(existing_dict)) = states_to_token_subsets.get_item(start_state) { + existing_dict.set_item(token_id, end_state).unwrap(); + } else { + let new_dict = PyDict::new_bound(py); + new_dict.set_item(token_id, end_state).unwrap(); + states_to_token_subsets + .set_item(start_state, new_dict) + .unwrap(); + } + + if !seen.contains(&end_state) { + next_states.insert(end_state); + } + } + + seen.insert(start_state); + } + + Ok(states_to_token_subsets) +} + +#[pymodule] +fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(_walk_fsm, m)?)?; + m.add_function(wrap_pyfunction!(state_scan_tokens, m)?)?; + m.add_function(wrap_pyfunction!(get_token_transition_keys, m)?)?; + m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys, m)?)?; + m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end, m)?)?; + + m.add_class::()?; + + Ok(()) +} diff --git a/src/regex.rs b/src/regex.rs index df7d36f6..71ef2565 100644 --- a/src/regex.rs +++ b/src/regex.rs @@ -1,8 +1,6 @@ -use pyo3::prelude::*; - -use pyo3::types::PyDict; use std::collections::{HashMap, HashSet}; +#[allow(dead_code)] pub fn walk_fsm_internal( fsm_transitions: &HashMap<(u32, u32), u32>, _fsm_initial: u32, @@ -40,6 +38,7 @@ pub fn walk_fsm_internal( accepted_states } +#[allow(dead_code)] pub fn state_scan_tokens_internal( fsm_transitions: &HashMap<(u32, u32), u32>, fsm_initial: u32, @@ -76,6 +75,7 @@ pub fn state_scan_tokens_internal( res } +#[allow(dead_code)] pub fn get_token_transition_keys_internal( alphabet_symbol_mapping: &HashMap, alphabet_anything_value: u32, @@ -109,6 +109,7 @@ pub fn get_token_transition_keys_internal( token_transition_keys } +#[allow(dead_code)] pub fn get_vocabulary_transition_keys_internal( alphabet_symbol_mapping: &HashMap, alphabet_anything_value: u32, @@ -144,168 +145,3 @@ pub fn get_vocabulary_transition_keys_internal( vocab_transition_keys } - -#[pyclass] -pub struct FSMInfo { - #[pyo3(get)] - initial: u32, - #[pyo3(get)] - finals: HashSet, - #[pyo3(get)] - transitions: HashMap<(u32, u32), u32>, - #[pyo3(get)] - alphabet_anything_value: u32, - #[pyo3(get)] - alphabet_symbol_mapping: HashMap, -} - -#[pymethods] -impl FSMInfo { - #[new] - fn new( - initial: u32, - finals: HashSet, - transitions: HashMap<(u32, u32), u32>, - alphabet_anything_value: u32, - alphabet_symbol_mapping: HashMap, - ) -> Self { - Self { - initial, - finals, - transitions, - alphabet_anything_value, - alphabet_symbol_mapping, - } - } -} - -#[pyfunction(name = "_walk_fsm")] -#[pyo3( - text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)" -)] -pub fn _walk_fsm( - fsm_transitions: HashMap<(u32, u32), u32>, - fsm_initial: u32, - fsm_finals: HashSet, - token_transition_keys: Vec, - start_state: u32, - full_match: bool, -) -> PyResult> { - Ok(walk_fsm_internal( - &fsm_transitions, - fsm_initial, - &fsm_finals, - &token_transition_keys, - start_state, - full_match, - )) -} - -#[pyfunction(name = "state_scan_tokens")] -#[pyo3( - text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)" -)] -pub fn state_scan_tokens( - fsm_transitions: HashMap<(u32, u32), u32>, - fsm_initial: u32, - fsm_finals: HashSet, - vocabulary: Vec<(String, Vec)>, - vocabulary_transition_keys: Vec>, - start_state: u32, -) -> PyResult> { - Ok(state_scan_tokens_internal( - &fsm_transitions, - fsm_initial, - &fsm_finals, - &vocabulary, - &vocabulary_transition_keys, - start_state, - )) -} - -#[pyfunction(name = "get_token_transition_keys")] -#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")] -pub fn get_token_transition_keys( - alphabet_symbol_mapping: HashMap, - alphabet_anything_value: u32, - token_str: String, -) -> PyResult> { - Ok(get_token_transition_keys_internal( - &alphabet_symbol_mapping, - alphabet_anything_value, - &token_str, - )) -} - -#[pyfunction(name = "get_vocabulary_transition_keys")] -#[pyo3( - text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)" -)] -pub fn get_vocabulary_transition_keys( - alphabet_symbol_mapping: HashMap, - alphabet_anything_value: u32, - vocabulary: Vec<(String, Vec)>, - frozen_tokens: HashSet, -) -> PyResult>> { - Ok(get_vocabulary_transition_keys_internal( - &alphabet_symbol_mapping, - alphabet_anything_value, - &vocabulary, - &frozen_tokens, - )) -} - -#[allow(clippy::too_many_arguments)] -#[pyfunction(name = "create_fsm_index_end_to_end")] -#[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")] -pub fn create_fsm_index_end_to_end<'py>( - py: Python<'py>, - fsm_info: &FSMInfo, - vocabulary: Vec<(String, Vec)>, - frozen_tokens: HashSet, -) -> PyResult> { - let states_to_token_subsets = PyDict::new_bound(py); - let mut seen: HashSet = HashSet::new(); - let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); - - let vocabulary_transition_keys = get_vocabulary_transition_keys_internal( - &fsm_info.alphabet_symbol_mapping, - fsm_info.alphabet_anything_value, - &vocabulary, - &frozen_tokens, - ); - - while let Some(start_state) = next_states.iter().cloned().next() { - next_states.remove(&start_state); - - // TODO: Return Pydict directly at construction - let token_ids_end_states = state_scan_tokens_internal( - &fsm_info.transitions, - fsm_info.initial, - &fsm_info.finals, - &vocabulary, - &vocabulary_transition_keys, - start_state, - ); - - for (token_id, end_state) in token_ids_end_states { - if let Ok(Some(existing_dict)) = states_to_token_subsets.get_item(start_state) { - existing_dict.set_item(token_id, end_state).unwrap(); - } else { - let new_dict = PyDict::new_bound(py); - new_dict.set_item(token_id, end_state).unwrap(); - states_to_token_subsets - .set_item(start_state, new_dict) - .unwrap(); - } - - if !seen.contains(&end_state) { - next_states.insert(end_state); - } - } - - seen.insert(start_state); - } - - Ok(states_to_token_subsets) -}