diff --git a/function_sampler/fsm/fsm_utils/src/tokenizer_index.rs b/function_sampler/fsm/fsm_utils/src/tokenizer_index.rs index 6c05be2..b502dfe 100644 --- a/function_sampler/fsm/fsm_utils/src/tokenizer_index.rs +++ b/function_sampler/fsm/fsm_utils/src/tokenizer_index.rs @@ -53,53 +53,35 @@ fn walk_fsm( ) -> Vec { let mut state = start_state; let mut accepted_states = Vec::new(); - let mut last_final_idx: Option = None; + let mut last_final_idx = 0; - let mut current_pos = 0; - let input_chars: Vec = input_string.chars().collect(); + for (i, symbol) in input_string.chars().enumerate() { + let trans_key = fsm_info + .alphabet_symbol_mapping + .get(&symbol.to_string()) + .unwrap_or(&fsm_info.alphabet_anything_value); - while current_pos < input_chars.len() { - let mut found = false; + let new_state = fsm_info.transitions.get(&(state, *trans_key)); - // Attempt to match longer substrings first, ensuring multi-character sequences are prioritized - for len in (1..=input_chars.len() - current_pos).rev() { - let possible_match: String = - input_chars[current_pos..current_pos + len].iter().collect(); - - if let Some(&trans_key) = fsm_info.alphabet_symbol_mapping.get(&possible_match) { - if let Some(&new_state) = fsm_info.transitions.get(&(state, trans_key)) { - state = new_state; - if fsm_info.finals.contains(&state) { - last_final_idx = Some(accepted_states.len() + 1); - } - accepted_states.push(state); - current_pos += len; // Move past the matched substring - found = true; - break; - } + if let Some(&new_state) = new_state { + state = new_state; + if fsm_info.finals.contains(&state) { + last_final_idx = i + 1; // Store the index of the last final state encountered } - } - - if !found { - if !full_match && last_final_idx.is_some() { - // Non-full match and we've previously encountered a final state - return accepted_states - .into_iter() - .take(last_final_idx.unwrap()) - .collect(); - } else { - // No match found, or a full match is required - return vec![]; + accepted_states.push(state); + } else { + if !full_match && last_final_idx > 0 { + return accepted_states.into_iter().take(last_final_idx).collect(); } + return Vec::new(); } } - // Full match checks - if full_match && last_final_idx.map_or(true, |idx| idx != accepted_states.len()) { - return vec![]; // Full match required but last character didn't result in a final state + if full_match && last_final_idx - 1 != input_string.chars().count() - 1 { + Vec::new() // If full match is required and last final state is not at the end, return empty + } else { + accepted_states } - - accepted_states } /// This function scans a set of tokens against an FSM to determine the resulting states from a given start state. diff --git a/function_sampler/fsm/fsm_utils/src/types.rs b/function_sampler/fsm/fsm_utils/src/types.rs index e3e7af3..16ade05 100644 --- a/function_sampler/fsm/fsm_utils/src/types.rs +++ b/function_sampler/fsm/fsm_utils/src/types.rs @@ -42,8 +42,9 @@ pub struct PyFSMInfo { transitions: HashMap<(u32, u32), u32>, //#[pyo3(item("trans_key_to_states"))] //trans_key_to_states: HashMap>, - //#[pyo3(item("alphabet_anything_value"))] - //alphabet_anything_value: u32, + #[pyo3(item("alphabet_anything_value"))] + alphabet_anything_value: u32, + #[pyo3(item("alphabet_symbol_mapping"))] alphabet_symbol_mapping: HashMap, } @@ -80,11 +81,14 @@ impl TryFrom<&PyFSMInfo> for FSMInfo { states.insert(*to); } + let alphabet_anything_value = py_info.alphabet_anything_value; + Ok(FSMInfo { initial, finals, transitions, alphabet_symbol_mapping, + alphabet_anything_value, states, }) } @@ -105,6 +109,8 @@ pub struct FSMInfo { /// The alphabet mapping. /// key is a String representing the input, value is its u32 identifier / transition key. pub alphabet_symbol_mapping: BTreeMap, + + pub alphabet_anything_value: u32, pub states: BTreeSet, }