Skip to content

Commit

Permalink
Introduce Rust base vocabulary type
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin committed Sep 20, 2024
1 parent 54580f3 commit 96af432
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 3 deletions.
7 changes: 6 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,9 @@ pub mod regex;
mod python_bindings;

mod primitives;
pub use primitives::{State, TokenId, TransitionKey};
pub use primitives::{State, Token, TokenId, TransitionKey};

mod vocabulary;
pub use vocabulary::Vocabulary;

pub(crate) use {std::collections::HashMap, std::ops::Deref};
3 changes: 3 additions & 0 deletions src/primitives.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
/// Interegular transition key.
///
/// Alias for `u32`.
pub type TransitionKey = u32;

/// Token content.
///
/// Alias for an owned `String`.
pub type Token = String;

/// Token identifier.
///
/// Alias for `u32`.
pub type TokenId = u32;

Expand Down
4 changes: 4 additions & 0 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ pub fn state_scan_tokens_py(
vocabulary_transition_keys: Vec<Vec<TransitionKey>>,
start_state: State,
) -> PyResult<HashSet<(TokenId, State)>> {
let vocabulary = Vocabulary::from_iter(vocabulary);
Ok(state_scan_tokens(
&fsm_transitions,
fsm_initial,
Expand Down Expand Up @@ -131,6 +132,7 @@ pub fn get_vocabulary_transition_keys_py(
vocabulary: Vec<(String, Vec<TokenId>)>,
frozen_tokens: HashSet<String>,
) -> PyResult<Vec<Vec<TransitionKey>>> {
let vocabulary = Vocabulary::from_iter(vocabulary);
Ok(get_vocabulary_transition_keys(
&alphabet_symbol_mapping,
alphabet_anything_value,
Expand All @@ -147,6 +149,8 @@ pub fn create_fsm_index_end_to_end_py<'py>(
vocabulary: Vec<(String, Vec<TokenId>)>,
frozen_tokens: HashSet<String>,
) -> PyResult<Bound<'py, PyDict>> {
let vocabulary = Vocabulary::from_iter(vocabulary);

let states_to_token_subsets = PyDict::new_bound(py);
let mut seen: HashSet<State> = HashSet::new();
let mut next_states: HashSet<State> = HashSet::from_iter(vec![fsm_info.initial]);
Expand Down
4 changes: 2 additions & 2 deletions src/regex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub fn state_scan_tokens(
fsm_transitions: &HashMap<(State, TransitionKey), State>,
fsm_initial: State,
fsm_finals: &HashSet<State>,
vocabulary: &[(String, Vec<TokenId>)],
vocabulary: &Vocabulary,
vocabulary_transition_keys: &[Vec<TransitionKey>],
start_state: State,
) -> HashSet<(TokenId, State)> {
Expand Down Expand Up @@ -110,7 +110,7 @@ pub fn get_token_transition_keys(
pub fn get_vocabulary_transition_keys(
alphabet_symbol_mapping: &HashMap<String, TransitionKey>,
alphabet_anything_value: TransitionKey,
vocabulary: &[(String, Vec<TokenId>)],
vocabulary: &Vocabulary,
frozen_tokens: &HashSet<String>,
) -> Vec<Vec<TransitionKey>> {
let mut vocab_transition_keys: Vec<Vec<TransitionKey>> = Vec::new();
Expand Down
101 changes: 101 additions & 0 deletions src/vocabulary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use crate::*;

/// Vocabulary of an LLM.
///
/// Internally this is a map from each token to the list of identifiers that
/// were registered for it, so one token may carry several identifiers.
/// Identifiers of a token are kept in the order they were added.
///
/// The type dereferences to the underlying `HashMap`, which gives read-only
/// access to it (`len`, `get`, indexing, iteration, ...). Modifications go
/// through the consuming builder methods [`Vocabulary::insert`] and
/// [`Vocabulary::extend`].
///
/// ## Examples
///
/// ```rust
/// # use outlines_core::*;
/// #
/// let vocabulary = Vocabulary::new()
///     .insert(0, "blah")
///     .insert(1, "1a")
///     .insert(2, "2")
///     .insert(3, "0");
/// ```
#[derive(Clone, Debug, Default)]
pub struct Vocabulary(HashMap<Token, Vec<TokenId>>);

impl Vocabulary {
    /// Creates an empty vocabulary.
    pub fn new() -> Vocabulary {
        Self::default()
    }
}

impl Vocabulary {
    /// Inserts a token to the vocabulary with the specified identifier.
    ///
    /// If the token is already present, `id` is appended to its existing
    /// identifiers; otherwise a new entry is created. The vocabulary is taken
    /// by value and returned so that calls can be chained.
    #[must_use = "`insert` consumes the vocabulary and returns the updated one"]
    pub fn insert(mut self, id: TokenId, token: impl Into<Token>) -> Vocabulary {
        self.0.entry(token.into()).or_default().push(id);
        self
    }

    /// Extends the vocabulary with tokens and their identifiers.
    ///
    /// Each item is a `(token, identifiers)` pair. The identifiers are appended,
    /// in iteration order, to whatever the token already has. A pair whose
    /// identifier sequence is empty leaves the vocabulary untouched: no empty
    /// entry is created for that token.
    #[must_use = "`extend` consumes the vocabulary and returns the updated one"]
    pub fn extend<T: Into<Token>, I: IntoIterator<Item = TokenId>>(
        mut self,
        tokens_and_ids: impl IntoIterator<Item = (T, I)>,
    ) -> Vocabulary {
        for (token, ids) in tokens_and_ids {
            let token = token.into();
            let mut ids = ids.into_iter().peekable();
            // Look the token up once and append all of its identifiers in one
            // go, instead of cloning and re-hashing the token per identifier.
            // The emptiness check keeps tokens without identifiers out of the map.
            if ids.peek().is_some() {
                self.0.entry(token).or_default().extend(ids);
            }
        }
        self
    }
}

/// Read-only access to the underlying token-to-identifiers map.
impl Deref for Vocabulary {
    type Target = HashMap<Token, Vec<TokenId>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Builds a vocabulary from `(token, identifiers)` pairs, with the same
/// semantics as calling [`Vocabulary::extend`] on an empty vocabulary.
impl<T, I> FromIterator<(T, I)> for Vocabulary
where
    T: Into<Token>,
    I: IntoIterator<Item = TokenId>,
{
    fn from_iter<A: IntoIterator<Item = (T, I)>>(tokens_and_ids: A) -> Self {
        Vocabulary::new().extend(tokens_and_ids)
    }
}

#[cfg(test)]
mod tests {
    use crate::*;

    /// Checks that `vocabulary` holds exactly the four sample tokens, each
    /// mapped to its single expected identifier.
    fn assert_sample_vocabulary(vocabulary: &Vocabulary) {
        let expected: [(&str, TokenId); 4] = [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)];

        assert_eq!(vocabulary.len(), expected.len());
        for (token, id) in expected {
            assert_eq!(vocabulary[token], &[id]);
        }
    }

    #[test]
    fn insert() {
        // Build the vocabulary one token at a time.
        let mut vocabulary = Vocabulary::new();
        vocabulary = vocabulary.insert(0, "blah");
        vocabulary = vocabulary.insert(1, "1a");
        vocabulary = vocabulary.insert(2, "2");
        vocabulary = vocabulary.insert(3, "0");

        assert_sample_vocabulary(&vocabulary);
    }

    #[test]
    fn extend() {
        // Build the same vocabulary in a single bulk call.
        let entries = [
            ("blah", vec![0]),
            ("1a", vec![1]),
            ("2", vec![2]),
            ("0", vec![3]),
        ];
        let vocabulary = Vocabulary::new().extend(entries);

        assert_sample_vocabulary(&vocabulary);
    }
}

0 comments on commit 96af432

Please sign in to comment.