Skip to content

Commit

Permalink
Use _py suffix naming for Python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Aug 29, 2024
1 parent bbce7d7 commit c54983b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 30 deletions.
51 changes: 27 additions & 24 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::json_schema;
use crate::regex::get_token_transition_keys_internal;
use crate::regex::get_vocabulary_transition_keys_internal;
use crate::regex::state_scan_tokens_internal;
use crate::regex::walk_fsm_internal;
use crate::regex::get_token_transition_keys;
use crate::regex::get_vocabulary_transition_keys;
use crate::regex::state_scan_tokens;
use crate::regex::walk_fsm;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyDict;
Expand Down Expand Up @@ -46,14 +46,17 @@ impl FSMInfo {

#[pyfunction(name = "build_regex_from_schema")]
#[pyo3(signature = (json, whitespace_pattern=None))]
pub fn build_regex_from_schema(json: String, whitespace_pattern: Option<&str>) -> PyResult<String> {
pub fn build_regex_from_schema_py(
json: String,
whitespace_pattern: Option<&str>,
) -> PyResult<String> {
json_schema::build_regex_from_schema(&json, whitespace_pattern)
.map_err(|e| PyValueError::new_err(e.to_string()))
}

#[pyfunction(name = "to_regex")]
#[pyo3(signature = (json, whitespace_pattern=None))]
pub fn to_regex(json: Bound<PyDict>, whitespace_pattern: Option<&str>) -> PyResult<String> {
pub fn to_regex_py(json: Bound<PyDict>, whitespace_pattern: Option<&str>) -> PyResult<String> {
let json_value: Value = serde_pyobject::from_pyobject(json).unwrap();
json_schema::to_regex(&json_value, whitespace_pattern, &json_value)
.map_err(|e| PyValueError::new_err(e.to_string()))
Expand All @@ -63,15 +66,15 @@ pub fn to_regex(json: Bound<PyDict>, whitespace_pattern: Option<&str>) -> PyResu
#[pyo3(
text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)"
)]
pub fn _walk_fsm(
pub fn walk_fsm_py(
fsm_transitions: HashMap<(u32, u32), u32>,
fsm_initial: u32,
fsm_finals: HashSet<u32>,
token_transition_keys: Vec<u32>,
start_state: u32,
full_match: bool,
) -> PyResult<Vec<u32>> {
Ok(walk_fsm_internal(
Ok(walk_fsm(
&fsm_transitions,
fsm_initial,
&fsm_finals,
Expand All @@ -85,15 +88,15 @@ pub fn _walk_fsm(
#[pyo3(
text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)"
)]
pub fn state_scan_tokens(
pub fn state_scan_tokens_py(
fsm_transitions: HashMap<(u32, u32), u32>,
fsm_initial: u32,
fsm_finals: HashSet<u32>,
vocabulary: Vec<(String, Vec<u32>)>,
vocabulary_transition_keys: Vec<Vec<u32>>,
start_state: u32,
) -> PyResult<HashSet<(u32, u32)>> {
Ok(state_scan_tokens_internal(
Ok(state_scan_tokens(
&fsm_transitions,
fsm_initial,
&fsm_finals,
Expand All @@ -105,12 +108,12 @@ pub fn state_scan_tokens(

#[pyfunction(name = "get_token_transition_keys")]
#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")]
pub fn get_token_transition_keys(
pub fn get_token_transition_keys_py(
alphabet_symbol_mapping: HashMap<String, u32>,
alphabet_anything_value: u32,
token_str: String,
) -> PyResult<Vec<u32>> {
Ok(get_token_transition_keys_internal(
Ok(get_token_transition_keys(
&alphabet_symbol_mapping,
alphabet_anything_value,
&token_str,
Expand All @@ -121,13 +124,13 @@ pub fn get_token_transition_keys(
#[pyo3(
text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)"
)]
pub fn get_vocabulary_transition_keys(
pub fn get_vocabulary_transition_keys_py(
alphabet_symbol_mapping: HashMap<String, u32>,
alphabet_anything_value: u32,
vocabulary: Vec<(String, Vec<u32>)>,
frozen_tokens: HashSet<String>,
) -> PyResult<Vec<Vec<u32>>> {
Ok(get_vocabulary_transition_keys_internal(
Ok(get_vocabulary_transition_keys(
&alphabet_symbol_mapping,
alphabet_anything_value,
&vocabulary,
Expand All @@ -138,7 +141,7 @@ pub fn get_vocabulary_transition_keys(
#[allow(clippy::too_many_arguments)]
#[pyfunction(name = "create_fsm_index_end_to_end")]
#[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")]
pub fn create_fsm_index_end_to_end<'py>(
pub fn create_fsm_index_end_to_end_py<'py>(
py: Python<'py>,
fsm_info: &FSMInfo,
vocabulary: Vec<(String, Vec<u32>)>,
Expand All @@ -148,7 +151,7 @@ pub fn create_fsm_index_end_to_end<'py>(
let mut seen: HashSet<u32> = HashSet::new();
let mut next_states: HashSet<u32> = HashSet::from_iter(vec![fsm_info.initial]);

let vocabulary_transition_keys = get_vocabulary_transition_keys_internal(
let vocabulary_transition_keys = get_vocabulary_transition_keys(
&fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
&vocabulary,
Expand All @@ -159,7 +162,7 @@ pub fn create_fsm_index_end_to_end<'py>(
next_states.remove(&start_state);

// TODO: Return Pydict directly at construction
let token_ids_end_states = state_scan_tokens_internal(
let token_ids_end_states = state_scan_tokens(
&fsm_info.transitions,
fsm_info.initial,
&fsm_info.finals,
Expand Down Expand Up @@ -192,11 +195,11 @@ pub fn create_fsm_index_end_to_end<'py>(

#[pymodule]
fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(_walk_fsm, m)?)?;
m.add_function(wrap_pyfunction!(state_scan_tokens, m)?)?;
m.add_function(wrap_pyfunction!(get_token_transition_keys, m)?)?;
m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys, m)?)?;
m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end, m)?)?;
m.add_function(wrap_pyfunction!(walk_fsm_py, m)?)?;
m.add_function(wrap_pyfunction!(state_scan_tokens_py, m)?)?;
m.add_function(wrap_pyfunction!(get_token_transition_keys_py, m)?)?;
m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys_py, m)?)?;
m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end_py, m)?)?;

m.add_class::<FSMInfo>()?;

Expand All @@ -212,8 +215,8 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add("UUID", json_schema::UUID)?;
m.add("WHITESPACE", json_schema::WHITESPACE)?;

m.add_function(wrap_pyfunction!(build_regex_from_schema, m)?)?;
m.add_function(wrap_pyfunction!(to_regex, m)?)?;
m.add_function(wrap_pyfunction!(build_regex_from_schema_py, m)?)?;
m.add_function(wrap_pyfunction!(to_regex_py, m)?)?;

Ok(())
}
12 changes: 6 additions & 6 deletions src/regex.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::{HashMap, HashSet};

#[allow(dead_code)]
pub fn walk_fsm_internal(
pub fn walk_fsm(
fsm_transitions: &HashMap<(u32, u32), u32>,
_fsm_initial: u32,
fsm_finals: &HashSet<u32>,
Expand Down Expand Up @@ -39,7 +39,7 @@ pub fn walk_fsm_internal(
}

#[allow(dead_code)]
pub fn state_scan_tokens_internal(
pub fn state_scan_tokens(
fsm_transitions: &HashMap<(u32, u32), u32>,
fsm_initial: u32,
fsm_finals: &HashSet<u32>,
Expand All @@ -54,7 +54,7 @@ pub fn state_scan_tokens_internal(
{
let token_ids: Vec<u32> = vocab_item.1.clone();

let state_seq = walk_fsm_internal(
let state_seq = walk_fsm(
fsm_transitions,
fsm_initial,
fsm_finals,
Expand All @@ -76,7 +76,7 @@ pub fn state_scan_tokens_internal(
}

#[allow(dead_code)]
pub fn get_token_transition_keys_internal(
pub fn get_token_transition_keys(
alphabet_symbol_mapping: &HashMap<String, u32>,
alphabet_anything_value: u32,
token_str: &str,
Expand Down Expand Up @@ -110,7 +110,7 @@ pub fn get_token_transition_keys_internal(
}

#[allow(dead_code)]
pub fn get_vocabulary_transition_keys_internal(
pub fn get_vocabulary_transition_keys(
alphabet_symbol_mapping: &HashMap<String, u32>,
alphabet_anything_value: u32,
vocabulary: &[(String, Vec<u32>)],
Expand All @@ -133,7 +133,7 @@ pub fn get_vocabulary_transition_keys_internal(
.unwrap_or(&alphabet_anything_value),
)
} else {
token_transition_keys = get_token_transition_keys_internal(
token_transition_keys = get_token_transition_keys(
alphabet_symbol_mapping,
alphabet_anything_value,
&token_str,
Expand Down

0 comments on commit c54983b

Please sign in to comment.