Introduce Rust based Vocabulary type #30

Merged · 3 commits · Sep 24, 2024
4 changes: 2 additions & 2 deletions python/outlines_core/fsm/outlines_core_rs.pyi
@@ -33,7 +33,7 @@ def state_scan_tokens(
fsm_initial: int,
fsm_finals: Set[int],
vocabulary: List[Tuple[str, List[int]]],
vocabulary_transition_keys: List[List[int]],
vocabulary_transition_keys: Dict[str, List[int]],
start_state: int,
) -> Set[Tuple[int, int]]: ...
def get_token_transition_keys(
@@ -46,7 +46,7 @@ def get_vocabulary_transition_keys(
alphabet_anything_value: int,
vocabulary: List[Tuple[str, List[int]]],
frozen_tokens: Set[str],
) -> List[List[int]]: ...
) -> Dict[str, List[int]]: ...
def create_fsm_index_end_to_end(
fsm_info: FSMInfo,
vocabulary: List[Tuple[str, List[int]]],
6 changes: 5 additions & 1 deletion src/lib.rs
@@ -5,5 +5,9 @@ pub mod regex;
mod python_bindings;

mod primitives;
pub use primitives::{State, Token, TokenId, TransitionKey};

pub use crate::primitives::{State, TokenId, TransitionKey};
mod vocabulary;
pub use vocabulary::Vocabulary;

pub(crate) use {std::collections::HashMap, std::ops::Deref};
3 changes: 3 additions & 0 deletions src/primitives.rs
@@ -1,6 +1,9 @@
/// Interegular transition key.
pub type TransitionKey = u32;

/// Token content.
pub type Token = String;

/// Token identifier.
pub type TokenId = u32;

8 changes: 6 additions & 2 deletions src/python_bindings/mod.rs
@@ -94,9 +94,10 @@ pub fn state_scan_tokens_py(
fsm_initial: State,
fsm_finals: HashSet<State>,
vocabulary: Vec<(String, Vec<TokenId>)>,
vocabulary_transition_keys: Vec<Vec<TransitionKey>>,
vocabulary_transition_keys: HashMap<String, Vec<TransitionKey>>,
start_state: State,
) -> PyResult<HashSet<(TokenId, State)>> {
let vocabulary = Vocabulary::from_iter(vocabulary);
Ok(state_scan_tokens(
&fsm_transitions,
fsm_initial,
@@ -130,7 +131,8 @@ pub fn get_vocabulary_transition_keys_py(
alphabet_anything_value: TransitionKey,
vocabulary: Vec<(String, Vec<TokenId>)>,
frozen_tokens: HashSet<String>,
) -> PyResult<Vec<Vec<TransitionKey>>> {
) -> PyResult<HashMap<String, Vec<TransitionKey>>> {
let vocabulary = Vocabulary::from_iter(vocabulary);
Ok(get_vocabulary_transition_keys(
&alphabet_symbol_mapping,
alphabet_anything_value,
@@ -147,6 +149,8 @@ pub fn create_fsm_index_end_to_end_py<'py>(
vocabulary: Vec<(String, Vec<TokenId>)>,
frozen_tokens: HashSet<String>,
) -> PyResult<Bound<'py, PyDict>> {
let vocabulary = Vocabulary::from_iter(vocabulary);

let states_to_token_subsets = PyDict::new_bound(py);
let mut seen: HashSet<State> = HashSet::new();
let mut next_states: HashSet<State> = HashSet::from_iter(vec![fsm_info.initial]);
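For orientation, the snippet below sketches the conversion these bindings now perform with `Vocabulary::from_iter`: the `Vec<(String, Vec<TokenId>)>` coming from Python becomes a map from token to ids. It is a standalone illustration with made-up tokens and ids, not code from this PR.

```rust
use outlines_core::Vocabulary;

fn main() {
    // Shape of the data the Python bindings receive: (token, token_ids) pairs.
    let pairs: Vec<(String, Vec<u32>)> = vec![
        ("ab".to_string(), vec![10]),
        ("ac".to_string(), vec![11, 12]),
    ];

    // The bindings construct the Rust-side type with `Vocabulary::from_iter`.
    let vocabulary = Vocabulary::from_iter(pairs);

    // `Vocabulary` derefs to `HashMap<Token, Vec<TokenId>>`, so token ids are
    // looked up by token string rather than by position.
    assert_eq!(vocabulary["ac"], vec![11, 12]);
    for (token, token_ids) in vocabulary.iter() {
        println!("{token}: {token_ids:?}");
    }
}
```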
21 changes: 9 additions & 12 deletions src/regex.rs
@@ -42,17 +42,14 @@ pub fn state_scan_tokens(
fsm_transitions: &HashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: &HashSet<State>,
vocabulary: &[(String, Vec<TokenId>)],
vocabulary_transition_keys: &[Vec<TransitionKey>],
vocabulary: &Vocabulary,
vocabulary_transition_keys: &HashMap<Token, Vec<TransitionKey>>,
start_state: State,
) -> HashSet<(TokenId, State)> {
let mut res = HashSet::new();

for (vocab_item, token_transition_keys) in
vocabulary.iter().zip(vocabulary_transition_keys.iter())
{
let token_ids: Vec<TokenId> = vocab_item.1.clone();

for (token, token_ids) in vocabulary.iter() {
let token_transition_keys = &vocabulary_transition_keys[token];
let state_seq = walk_fsm(
fsm_transitions,
fsm_initial,
Expand All @@ -66,7 +63,7 @@ pub fn state_scan_tokens(
continue;
}

for &token_id in &token_ids {
for &token_id in token_ids {
res.insert((token_id, *state_seq.last().unwrap()));
}
}
@@ -110,10 +107,10 @@ pub fn get_token_transition_keys(
pub fn get_vocabulary_transition_keys(
alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
vocabulary: &[(String, Vec<TokenId>)],
vocabulary: &Vocabulary,
frozen_tokens: &HashSet<String>,
) -> Vec<Vec<TransitionKey>> {
let mut vocab_transition_keys: Vec<Vec<TransitionKey>> = Vec::new();
) -> HashMap<Token, Vec<TransitionKey>> {
let mut vocab_transition_keys = HashMap::new();

for item in vocabulary.iter() {
let token_str = item.0.clone();
Expand All @@ -137,7 +134,7 @@ pub fn get_vocabulary_transition_keys(
);
}

vocab_transition_keys.push(token_transition_keys);
vocab_transition_keys.insert(token_str, token_transition_keys);
}

vocab_transition_keys
101 changes: 101 additions & 0 deletions src/vocabulary.rs
@@ -0,0 +1,101 @@
use crate::*;

/// Vocabulary of an LLM.
///
/// ## Examples
///
/// ```rust
/// # use outlines_core::*;
/// #
/// let vocabulary = Vocabulary::new()
/// .insert(0, "blah")
/// .insert(1, "1a")
/// .insert(2, "2")
/// .insert(3, "0");
/// ```
#[derive(Clone, Debug, Default)]
pub struct Vocabulary(HashMap<Token, Vec<TokenId>>);

impl Vocabulary {
/// Creates an empty vocabulary.
pub fn new() -> Vocabulary {
Vocabulary::default()
}
}

impl Vocabulary {
/// Inserts a token to the vocabulary with the specified identifier.
pub fn insert(mut self, id: TokenId, token: impl Into<Token>) -> Vocabulary {
let token = token.into();
self.0.entry(token).or_default().push(id);
self
}

/// Extends the vocabulary with tokens and their identifiers.
pub fn extend<T: Into<Token>, I: IntoIterator<Item = TokenId>>(
mut self,
tokens_and_ids: impl IntoIterator<Item = (T, I)>,
) -> Vocabulary {
for (token, ids) in tokens_and_ids.into_iter() {
let token = token.into();
for id in ids {
self = self.insert(id, token.clone());
}
}
self
}
}

impl Deref for Vocabulary {
type Target = HashMap<Token, Vec<TokenId>>;

fn deref(&self) -> &HashMap<Token, Vec<TokenId>> {
&self.0
}
}

impl<T, I> FromIterator<(T, I)> for Vocabulary
where
T: Into<Token>,
I: IntoIterator<Item = TokenId>,
{
fn from_iter<A: IntoIterator<Item = (T, I)>>(tokens_and_ids: A) -> Self {
Vocabulary::new().extend(tokens_and_ids)
}
}

#[cfg(test)]
mod tests {
use crate::*;

#[test]
fn insert() {
let vocabulary = Vocabulary::new()
.insert(0, "blah")
.insert(1, "1a")
.insert(2, "2")
.insert(3, "0");

assert_eq!(vocabulary.len(), 4);
assert_eq!(vocabulary["blah"], &[0]);
assert_eq!(vocabulary["1a"], &[1]);
assert_eq!(vocabulary["2"], &[2]);
assert_eq!(vocabulary["0"], &[3]);
}

#[test]
fn extend() {
let vocabulary = Vocabulary::new().extend([
("blah", vec![0]),
("1a", vec![1]),
("2", vec![2]),
("0", vec![3]),
]);

assert_eq!(vocabulary.len(), 4);
assert_eq!(vocabulary["blah"], &[0]);
assert_eq!(vocabulary["1a"], &[1]);
assert_eq!(vocabulary["2"], &[2]);
assert_eq!(vocabulary["0"], &[3]);
}
}
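As a rough end-to-end sketch of how the new type threads through the reworked `regex.rs` functions, the snippet below builds a toy vocabulary and FSM and runs them through `get_vocabulary_transition_keys` and `state_scan_tokens`. It is based only on the signatures in this diff; the tiny FSM is invented for illustration, and `State` is assumed to be an integer alias like `TokenId` and `TransitionKey`.

```rust
use std::collections::{HashMap, HashSet};

use outlines_core::regex::{get_vocabulary_transition_keys, state_scan_tokens};
use outlines_core::{State, TransitionKey, Vocabulary};

fn main() {
    // Hand-built FSM for the pattern `ab`: state 0 --a--> 1 --b--> 2 (final).
    let alphabet_symbol_mapping: HashMap<String, TransitionKey> =
        HashMap::from([("a".to_string(), 0), ("b".to_string(), 1)]);
    let alphabet_anything_value: TransitionKey = 2;
    let fsm_transitions: HashMap<(State, TransitionKey), State> =
        HashMap::from([((0, 0), 1), ((1, 1), 2)]);
    let fsm_initial: State = 0;
    let fsm_finals: HashSet<State> = HashSet::from([2]);

    // Arbitrary token ids; "ab" can reach the final state from the initial
    // state, "b" cannot.
    let vocabulary = Vocabulary::from_iter([("ab", vec![10]), ("b", vec![11])]);

    // Per-token transition keys now come back as a map keyed by token...
    let transition_keys = get_vocabulary_transition_keys(
        &alphabet_symbol_mapping,
        alphabet_anything_value,
        &vocabulary,
        &HashSet::new(),
    );

    // ...which `state_scan_tokens` looks up by token instead of zipping the
    // vocabulary with a positional list.
    let token_ids_to_end_states = state_scan_tokens(
        &fsm_transitions,
        fsm_initial,
        &fsm_finals,
        &vocabulary,
        &transition_keys,
        fsm_initial,
    );
    println!("{token_ids_to_end_states:?}"); // expected to contain (10, 2)
}
```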
15 changes: 3 additions & 12 deletions tests/fsm/test_regex.py
@@ -434,17 +434,13 @@ def convert_token_to_string(self, token):
interegular_fsm = regex_pattern.to_fsm().reduce()
regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
vocabulary, _ = reduced_vocabulary(tokenizer)
token_trans_keys = get_vocabulary_transition_keys(
token_str_to_tranition_keys = get_vocabulary_transition_keys(
regex_fsm.fsm_info.alphabet_symbol_mapping,
regex_fsm.fsm_info.alphabet_anything_value,
list(vocabulary.items()),
frozenset(),
)

token_str_to_tranition_keys = {
token_str: trans_key_seq
for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys)
}
# `a` and `b` both are workable, but `z` has distinct transition rules
assert interegular_fsm.accepts("zaz")
assert interegular_fsm.accepts("zbz")
@@ -470,22 +466,17 @@ def convert_token_to_string(self, token):
interegular_fsm = regex_pattern.to_fsm().reduce()
regex_fsm, _ = make_deterministic_fsm(interegular_fsm)
vocabulary, _ = reduced_vocabulary(tokenizer)
token_trans_keys = get_vocabulary_transition_keys(
token_str_to_tranition_keys = get_vocabulary_transition_keys(
regex_fsm.fsm_info.alphabet_symbol_mapping,
regex_fsm.fsm_info.alphabet_anything_value,
list(vocabulary.items()),
frozenset(),
)

token_str_trans_key_seq = {
token_str: trans_key_seq
for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys)
}

# verify initial state valid only for "ab" and "ac" using transition key seq
token_acceptance = {"ab": True, "ac": True, "az": False}
for token, should_accept in token_acceptance.items():
token_trans_key_seq = token_str_trans_key_seq[token]
token_trans_key_seq = token_str_to_tranition_keys[token]
state_seq = _walk_fsm(
regex_fsm.fsm_info.transitions,
regex_fsm.fsm_info.initial,