Build Index from regex
torymur committed Dec 12, 2024
1 parent 31ab9f1 commit 97c598e
Showing 4 changed files with 147 additions and 18 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -19,6 +19,7 @@ bincode = "2.0.0-rc.3"
hf-hub = "=0.3.2"
tokenizers = { version = "=0.20.3", features = ["http"] }
rustc-hash = "2.1.0"
regex-automata = "0.4.9"

[features]
python-bindings = ["pyo3"]
23 changes: 10 additions & 13 deletions src/error.rs
@@ -3,21 +3,18 @@ use thiserror::Error;
pub type Result<T, E = crate::Error> = std::result::Result<T, E>;

#[derive(Error, Debug)]
#[error("{0}")]
pub struct TokenizersError(pub tokenizers::Error);

impl PartialEq for TokenizersError {
    fn eq(&self, other: &Self) -> bool {
        self.0.to_string() == other.0.to_string()
    }
}

#[derive(Error, Debug, PartialEq)]
pub enum Error {
    #[error("The vocabulary does not allow us to build a sequence that matches the input")]
    IndexError,
    #[error("The vocabulary does not allow us to build an index that matches the input")]
    InsufficientVocabulary,
    // TODO: this error will be removed once eos_token_id for vocabulary is no longer optional
    #[error("Index failed since vocabulary doesn't provide eos token id")]
    IndexEosTokenIdNotAvailable,
    #[error("Failed to build DFA {0}")]
    IndexDfaError(#[from] Box<regex_automata::dfa::dense::BuildError>),
    #[error("Index failed since anchored universal start state doesn't exist")]
    IndexNoAnchoredUniversalStartState,
    #[error(transparent)]
    TokenizersError(#[from] TokenizersError),
    TokenizersError(#[from] tokenizers::Error),
    #[error("Unsupported tokenizer for {model}: {reason}, please open an issue with the full error message: https://github.com/dottxt-ai/outlines-core/issues")]
    UnsupportedTokenizer { model: String, reason: String },
    #[error("Unable to locate EOS token for {model}")]
124 changes: 122 additions & 2 deletions src/index.rs
@@ -1,9 +1,12 @@
/// Construct an Index.
use crate::prelude::{State, TransitionKey};
use crate::prelude::*;
use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
use crate::vocabulary::Vocabulary;
use crate::{Error, Result};
use bincode::{Decode, Encode};
use regex_automata::dfa::{dense::DFA, Automaton};
use regex_automata::util::primitives::StateID as AutomataStateId;
use regex_automata::Anchored;
use rustc_hash::{FxHashMap, FxHashSet};

#[derive(Debug)]
@@ -101,7 +104,96 @@ impl Index {
                eos_token_id,
            })
        } else {
            Err(Error::IndexError)
            Err(Error::InsufficientVocabulary)
        }
    }

    pub(crate) fn from_regex(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
        let eos_token_id = match vocabulary.eos_token_id() {
            Some(s) => s,
            None => return Err(Error::IndexEosTokenIdNotAvailable),
        };

        let dfa = DFA::builder().build(regex).map_err(Box::new)?;
        let start_state = match dfa.universal_start_state(Anchored::Yes) {
            Some(s) => s,
            None => return Err(Error::IndexNoAnchoredUniversalStartState),
        };

        let mut index: FxHashMap<State, FxHashMap<TokenId, State>> = FxHashMap::default();
        let mut seen: FxHashSet<AutomataStateId> = FxHashSet::default();
        let mut final_states: FxHashSet<State> = FxHashSet::default();
        let mut next_states: FxHashSet<AutomataStateId> = FxHashSet::from_iter([start_state]);

        while let Some(start_state) = next_states.iter().cloned().next() {
            next_states.remove(&start_state);
            seen.insert(start_state);

            if dfa.is_match_state(dfa.next_eoi_state(start_state)) {
                final_states.insert(start_state.as_u32());
            }

            'token_loop: for (token, ids) in vocabulary.tokens_to_ids().iter() {
                if ids.contains(&eos_token_id) {
                    continue;
                }

                let mut next_state = start_state;
                for transition_byte in token.as_bytes() {
                    next_state = dfa.next_state(next_state, *transition_byte);
                    if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) {
                        continue 'token_loop;
                    }
                }

                if dfa.is_match_state(next_state) {
                    // Because dense DFAs report matches with a one-byte delay, reaching a
                    // match state here means the token matched either fully or up to its
                    // last character. Check what happens if the input ends at this state:
                    // if the automaton still reports a match, the token matched exactly;
                    // otherwise it did not.
                    let next_eoi_state = dfa.next_eoi_state(next_state);
                    let token_matched = dfa.is_match_state(next_eoi_state);
                    if !token_matched {
                        continue;
                    }
                }

                for token_id in ids {
                    let mapping = index.entry(start_state.as_u32()).or_default();
                    mapping.insert(*token_id, next_state.as_u32());

                    if !seen.contains(&next_state) {
                        next_states.insert(next_state);
                    }
                }
            }
        }

        let start_state = start_state.as_u32();

        // Populate `index` with mappings from `final_states` to `eos_token_id`
        for &final_state in &final_states {
            index
                .entry(final_state)
                .or_default()
                .insert(eos_token_id, final_state);
        }
        // Check if there is at least one valid mapping
        let is_valid = index.values().any(|mapping| {
            mapping
                .values()
                .any(|end_state| final_states.contains(end_state))
        });

        if is_valid {
            Ok(Self {
                initial: start_state,
                finals: final_states,
                states_to_token_subsets: index,
                eos_token_id,
            })
        } else {
            Err(Error::InsufficientVocabulary)
        }
    }

@@ -126,7 +218,35 @@ impl Index {
        self.finals.contains(&state)
    }

    pub(crate) fn final_states(&self) -> &FxHashSet<State> {
        &self.finals
    }

    pub(crate) fn transitions(&self) -> &FxHashMap<u32, FxHashMap<u32, u32>> {
        &self.states_to_token_subsets
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn index_from_regex() {
        let regex = "0|[1-9][0-9]*";
        let vocabulary = Vocabulary::new(Some(4))
            .insert("blah", 0)
            .insert("1a", 1)
            .insert("2", 2)
            .insert("0", 3)
            .insert("<eos>", 4);

        let index = Index::from_regex(regex, &vocabulary).expect("Index failed");
        assert_eq!(index.initial(), 40);
        assert_eq!(index.final_states(), &FxHashSet::from_iter([24, 48, 56]));
        assert_eq!(
            "{24: {3: 24, 4: 24, 2: 24}, 48: {4: 48}, 40: {3: 48, 2: 56}, 56: {3: 24, 4: 56, 2: 24}}",
            format!("{:?}", index.transitions())
        );
    }
}
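
For context on the end-of-input check in `from_regex` above: dense DFAs in regex-automata report a match one transition late, so a state is confirmed as matching by stepping to the end-of-input state and testing that. Below is a minimal standalone sketch of the same walk, using only regex-automata 0.4 APIs that already appear in this diff (`DFA::builder`, `universal_start_state`, `next_state`, `next_eoi_state`, `is_match_state`); the `token_exactly_matches` helper and the sample tokens are illustrative, not part of the crate.

use regex_automata::dfa::{dense::DFA, Automaton};
use regex_automata::Anchored;

/// Walks the anchored DFA over a token's bytes and reports whether the
/// whole token is an exact match for the pattern (illustrative helper).
fn token_exactly_matches(dfa: &DFA<Vec<u32>>, token: &str) -> bool {
    let Some(mut state) = dfa.universal_start_state(Anchored::Yes) else {
        return false;
    };
    for &byte in token.as_bytes() {
        state = dfa.next_state(state, byte);
        if dfa.is_dead_state(state) || dfa.is_quit_state(state) {
            return false;
        }
    }
    // Matches are reported with a one-transition delay, so feed the
    // end-of-input marker before asking whether this is a match state.
    dfa.is_match_state(dfa.next_eoi_state(state))
}

fn main() {
    let dfa = DFA::builder().build("0|[1-9][0-9]*").expect("valid regex");
    assert!(token_exactly_matches(&dfa, "12")); // exact match
    assert!(!token_exactly_matches(&dfa, "1a")); // dead-ends on 'a'
    assert!(!token_exactly_matches(&dfa, "")); // empty input never matches this pattern
}
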
17 changes: 14 additions & 3 deletions src/vocabulary/mod.rs
@@ -3,7 +3,7 @@ use rustc_hash::FxHashMap;
use tokenizers::normalizers::Sequence;
use tokenizers::{FromPretrainedParameters, NormalizerWrapper, Tokenizer};

use crate::{error, prelude::*};
use crate::prelude::*;
use crate::{Error, Result};

use locator::{HFLocator, Locator};
@@ -41,6 +41,13 @@ impl Vocabulary {
        }
    }

    pub fn with_eos_token_id(self, eos_token_id: Option<TokenId>) -> Self {
        Self {
            eos_token_id,
            ..self
        }
    }

    /// Creates the vocabulary of a pre-trained model from the Hugging Face Hub.
    pub fn from_pretrained(
        model: &str,
@@ -55,8 +62,7 @@
        model: &str,
        parameters: Option<FromPretrainedParameters>,
    ) -> Result<Self> {
        let mut tokenizer = Tokenizer::from_pretrained(model, parameters.clone())
            .map_err(|e| Error::TokenizersError(error::TokenizersError(e)))?;
        let mut tokenizer = Tokenizer::from_pretrained(model, parameters.clone())?;
        Self::filter_prepend_normalizers(&mut tokenizer);

        // Locate eos_token_id in defined locations.
@@ -95,6 +101,11 @@ impl Vocabulary {
        Ok(vocabulary)
    }

    /// Returns all tokens with their token ids in the vocabulary.
    pub fn tokens_to_ids(&self) -> &FxHashMap<Token, Vec<TokenId>> {
        &self.tokens
    }

    /// Returns the vector of `TokenId`s for the provided token, if available in the vocabulary.
    pub fn token_to_ids(&self, token: &str) -> Option<&Vec<TokenId>> {
        self.tokens.get(token)
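
Tying the two halves of this commit together, here is a hedged, crate-internal usage sketch: it reuses the toy tokens from the `index_from_regex` test in src/index.rs and exercises the new `with_eos_token_id` setter and `tokens_to_ids` accessor. Because `Index::from_regex` and `tokens_to_ids` are `pub(crate)`, this would have to live inside the crate (for example in a test module); the function name and the `Vocabulary::new(None)` starting point are illustrative assumptions rather than code from this commit.

use crate::index::Index;
use crate::vocabulary::Vocabulary;
use crate::Result;

/// Builds a toy vocabulary and compiles an integer regex into an Index
/// over it (crate-internal, illustrative).
fn build_integer_index() -> Result<Index> {
    let vocabulary = Vocabulary::new(None)
        .insert("blah", 0)
        .insert("1a", 1)
        .insert("2", 2)
        .insert("0", 3)
        .insert("<eos>", 4)
        // The builder-style setter added in this diff.
        .with_eos_token_id(Some(4));

    // The new accessor exposes the token -> ids map that `Index::from_regex`
    // iterates over; here it should hold the five toy tokens.
    assert_eq!(vocabulary.tokens_to_ids().len(), 5);

    Index::from_regex("0|[1-9][0-9]*", &vocabulary)
}
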
