From 26cf28e17a34b2aeaf59fb295d1aae6c9d3ebf28 Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 20 Sep 2024 22:52:21 +0300 Subject: [PATCH] Introduce Rust base vocabulary type --- src/lib.rs | 7 ++- src/primitives.rs | 3 ++ src/python_bindings/mod.rs | 4 ++ src/regex.rs | 4 +- src/vocabulary.rs | 101 +++++++++++++++++++++++++++++++++++++ 5 files changed, 116 insertions(+), 3 deletions(-) create mode 100644 src/vocabulary.rs diff --git a/src/lib.rs b/src/lib.rs index affc408..5c7b632 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,4 +5,9 @@ pub mod regex; mod python_bindings; mod primitives; -pub use primitives::{State, TokenId, TransitionKey}; +pub use primitives::{State, Token, TokenId, TransitionKey}; + +mod vocabulary; +pub use vocabulary::Vocabulary; + +pub(crate) use {std::collections::HashMap, std::ops::Deref}; diff --git a/src/primitives.rs b/src/primitives.rs index bbc7770..e12bf03 100644 --- a/src/primitives.rs +++ b/src/primitives.rs @@ -1,6 +1,9 @@ /// Interegular transition key. pub type TransitionKey = u32; +/// Token content. +pub type Token = String; + /// Token identifier. 
pub type TokenId = u32; diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 251300f..34cbf4f 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -97,6 +97,7 @@ pub fn state_scan_tokens_py( vocabulary_transition_keys: Vec<Vec<TransitionKey>>, start_state: State, ) -> PyResult<HashSet<(TokenId, State)>> { + let vocabulary = Vocabulary::from_iter(vocabulary); Ok(state_scan_tokens( &fsm_transitions, fsm_initial, @@ -131,6 +132,7 @@ pub fn get_vocabulary_transition_keys_py( vocabulary: Vec<(String, Vec<TokenId>)>, frozen_tokens: HashSet<String>, ) -> PyResult<Vec<Vec<TransitionKey>>> { + let vocabulary = Vocabulary::from_iter(vocabulary); Ok(get_vocabulary_transition_keys( &alphabet_symbol_mapping, alphabet_anything_value, @@ -147,6 +149,8 @@ pub fn create_fsm_index_end_to_end_py<'py>( vocabulary: Vec<(String, Vec<TokenId>)>, frozen_tokens: HashSet<String>, ) -> PyResult<Bound<'py, PyDict>> { + let vocabulary = Vocabulary::from_iter(vocabulary); + let states_to_token_subsets = PyDict::new_bound(py); let mut seen: HashSet<State> = HashSet::new(); let mut next_states: HashSet<State> = HashSet::from_iter(vec![fsm_info.initial]); diff --git a/src/regex.rs b/src/regex.rs index 314ff0c..c0eda76 100644 --- a/src/regex.rs +++ b/src/regex.rs @@ -42,7 +42,7 @@ pub fn state_scan_tokens( fsm_transitions: &HashMap<(State, TransitionKey), State>, fsm_initial: State, fsm_finals: &HashSet<State>, - vocabulary: &[(String, Vec<TokenId>)], + vocabulary: &Vocabulary, vocabulary_transition_keys: &[Vec<TransitionKey>], start_state: State, ) -> HashSet<(TokenId, State)> { @@ -110,7 +110,7 @@ pub fn get_token_transition_keys( pub fn get_vocabulary_transition_keys( alphabet_symbol_mapping: &HashMap<String, TransitionKey>, alphabet_anything_value: TransitionKey, - vocabulary: &[(String, Vec<TokenId>)], + vocabulary: &Vocabulary, frozen_tokens: &HashSet<String>, ) -> Vec<Vec<TransitionKey>> { let mut vocab_transition_keys: Vec<Vec<TransitionKey>> = Vec::new(); diff --git a/src/vocabulary.rs b/src/vocabulary.rs new file mode 100644 index 0000000..f926339 --- /dev/null +++ b/src/vocabulary.rs @@ -0,0 +1,101 @@ +use crate::*; + +/// Vocabulary of an LLM. 
+/// +/// ## Examples +/// +/// ```rust +/// # use outlines_core::*; +/// # +/// let vocabulary = Vocabulary::new() +/// .insert(0, "blah") +/// .insert(1, "1a") +/// .insert(2, "2") +/// .insert(3, "0"); +/// ``` +#[derive(Clone, Debug, Default)] +pub struct Vocabulary(HashMap<Token, Vec<TokenId>>); + +impl Vocabulary { + /// Creates an empty vocabulary. + pub fn new() -> Vocabulary { + Vocabulary::default() + } +} + +impl Vocabulary { + /// Inserts a token to the vocabulary with the specified identifier. + pub fn insert(mut self, id: TokenId, token: impl Into<Token>) -> Vocabulary { + let token = token.into(); + self.0.entry(token).or_default().push(id); + self + } + + /// Extends the vocabulary with tokens and their identifiers. + pub fn extend<T: Into<Token>, I: IntoIterator<Item = TokenId>>( + mut self, + tokens_and_ids: impl IntoIterator<Item = (T, I)>, + ) -> Vocabulary { + for (token, ids) in tokens_and_ids.into_iter() { + let token = token.into(); + for id in ids { + self = self.insert(id, token.clone()); + } + } + self + } +} + +impl Deref for Vocabulary { + type Target = HashMap<Token, Vec<TokenId>>; + + fn deref(&self) -> &HashMap<Token, Vec<TokenId>> { + &self.0 + } +} + +impl<T, I> FromIterator<(T, I)> for Vocabulary +where + T: Into<Token>, + I: IntoIterator<Item = TokenId>, +{ + fn from_iter<A: IntoIterator<Item = (T, I)>>(tokens_and_ids: A) -> Self { + Vocabulary::new().extend(tokens_and_ids) + } +} + +#[cfg(test)] +mod tests { + use crate::*; + + #[test] + fn insert() { + let vocabulary = Vocabulary::new() + .insert(0, "blah") + .insert(1, "1a") + .insert(2, "2") + .insert(3, "0"); + + assert_eq!(vocabulary.len(), 4); + assert_eq!(vocabulary["blah"], &[0]); + assert_eq!(vocabulary["1a"], &[1]); + assert_eq!(vocabulary["2"], &[2]); + assert_eq!(vocabulary["0"], &[3]); + } + + #[test] + fn extend() { + let vocabulary = Vocabulary::new().extend([ + ("blah", vec![0]), + ("1a", vec![1]), + ("2", vec![2]), + ("0", vec![3]), + ]); + + assert_eq!(vocabulary.len(), 4); + assert_eq!(vocabulary["blah"], &[0]); + assert_eq!(vocabulary["1a"], &[1]); + assert_eq!(vocabulary["2"], &[2]); + assert_eq!(vocabulary["0"], &[3]); + } +}