Skip to content

Commit

Permalink
fix fallback transitions
Browse files — browse the repository at this point in the history
  • Loading branch information
unaidedelf8777 committed May 6, 2024
1 parent 5bb4bff commit 8d33bca
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 40 deletions.
58 changes: 20 additions & 38 deletions function_sampler/fsm/fsm_utils/src/tokenizer_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,53 +53,35 @@ fn walk_fsm(
) -> Vec<u32> {
let mut state = start_state;
let mut accepted_states = Vec::new();
let mut last_final_idx: Option<usize> = None;
let mut last_final_idx = 0;

let mut current_pos = 0;
let input_chars: Vec<char> = input_string.chars().collect();
for (i, symbol) in input_string.chars().enumerate() {
let trans_key = fsm_info
.alphabet_symbol_mapping
.get(&symbol.to_string())
.unwrap_or(&fsm_info.alphabet_anything_value);

while current_pos < input_chars.len() {
let mut found = false;
let new_state = fsm_info.transitions.get(&(state, *trans_key));

// Attempt to match longer substrings first, ensuring multi-character sequences are prioritized
for len in (1..=input_chars.len() - current_pos).rev() {
let possible_match: String =
input_chars[current_pos..current_pos + len].iter().collect();

if let Some(&trans_key) = fsm_info.alphabet_symbol_mapping.get(&possible_match) {
if let Some(&new_state) = fsm_info.transitions.get(&(state, trans_key)) {
state = new_state;
if fsm_info.finals.contains(&state) {
last_final_idx = Some(accepted_states.len() + 1);
}
accepted_states.push(state);
current_pos += len; // Move past the matched substring
found = true;
break;
}
if let Some(&new_state) = new_state {
state = new_state;
if fsm_info.finals.contains(&state) {
last_final_idx = i + 1; // Store the index of the last final state encountered
}
}

if !found {
if !full_match && last_final_idx.is_some() {
// Non-full match and we've previously encountered a final state
return accepted_states
.into_iter()
.take(last_final_idx.unwrap())
.collect();
} else {
// No match found, or a full match is required
return vec![];
accepted_states.push(state);
} else {
if !full_match && last_final_idx > 0 {
return accepted_states.into_iter().take(last_final_idx).collect();
}
return Vec::new();
}
}

// Full match checks
if full_match && last_final_idx.map_or(true, |idx| idx != accepted_states.len()) {
return vec![]; // Full match required but last character didn't result in a final state
if full_match && last_final_idx - 1 != input_string.chars().count() - 1 {
Vec::new() // If full match is required and last final state is not at the end, return empty
} else {
accepted_states
}

accepted_states
}

/// This function scans a set of tokens against an FSM to determine the resulting states from a given start state.
Expand Down
10 changes: 8 additions & 2 deletions function_sampler/fsm/fsm_utils/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ pub struct PyFSMInfo {
transitions: HashMap<(u32, u32), u32>,
//#[pyo3(item("trans_key_to_states"))]
//trans_key_to_states: HashMap<u32, Vec<u32>>,
//#[pyo3(item("alphabet_anything_value"))]
//alphabet_anything_value: u32,
#[pyo3(item("alphabet_anything_value"))]
alphabet_anything_value: u32,

#[pyo3(item("alphabet_symbol_mapping"))]
alphabet_symbol_mapping: HashMap<String, u32>,
}
Expand Down Expand Up @@ -80,11 +81,14 @@ impl TryFrom<&PyFSMInfo> for FSMInfo {
states.insert(*to);
}

let alphabet_anything_value = py_info.alphabet_anything_value;

Ok(FSMInfo {
initial,
finals,
transitions,
alphabet_symbol_mapping,
alphabet_anything_value,
states,
})
}
Expand All @@ -105,6 +109,8 @@ pub struct FSMInfo {
/// The alphabet mapping.
/// Each key is a `String` representing an input symbol; each value is that symbol's `u32` identifier, which is used as the transition key.
pub alphabet_symbol_mapping: BTreeMap<String, u32>,

pub alphabet_anything_value: u32,
pub states: BTreeSet<u32>,
}

Expand Down

0 comments on commit 8d33bca

Please sign in to comment.