From 34268d67ce669b73c9db27cba33ce32705616348 Mon Sep 17 00:00:00 2001 From: Kaustubh Date: Tue, 29 Oct 2024 21:31:22 +0530 Subject: [PATCH] Add a Rust-first `Index` object --- .github/workflows/tests.yml | 2 +- Cargo.toml | 1 + python/outlines_core/fsm/guide.py | 85 ++++------- python/outlines_core/fsm/outlines_core_rs.pyi | 17 +++ python/outlines_core/fsm/regex.py | 14 +- src/index.rs | 132 ++++++++++++++++++ src/lib.rs | 21 ++- src/python_bindings/mod.rs | 81 +++++++++-- src/vocabulary.rs | 2 +- tests/fsm/test_guide.py | 28 ++-- tests/fsm/test_regex.py | 2 +- 11 files changed, 290 insertions(+), 95 deletions(-) create mode 100644 src/index.rs diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d50a5578..25eb1496 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,7 +44,7 @@ jobs: echo "::set-output name=id::$MATRIX_ID" - name: Run tests run: | - pytest --cov=outlines_core + pytest --cov=outlines_core -vv env: COVERAGE_FILE: .coverage.${{ steps.matrix-id.outputs.id }} - name: Upload coverage data diff --git a/Cargo.toml b/Cargo.toml index bd153aff..0e83d020 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ repository = "https://github.com/dottxt-ai/outlines-core" [dependencies] anyhow = "1.0.86" +thiserror = "1.0" pyo3 = { version = "0.22.0", features = ["extension-module"], optional = true } regex = "1.10.6" serde-pyobject = "0.4.0" diff --git a/python/outlines_core/fsm/guide.py b/python/outlines_core/fsm/guide.py index 27605402..c7aa0012 100644 --- a/python/outlines_core/fsm/guide.py +++ b/python/outlines_core/fsm/guide.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union +from typing import Any, Callable, List, Optional, Protocol, Set, Tuple, Union import interegular import torch @@ -9,6 +9,8 @@ make_deterministic_fsm, ) +from .outlines_core_rs import Index + @dataclass(frozen=True) class Write: @@ -107,7 +109,7 @@ def create_states_mapping( tokenizer, regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern, frozen_tokens: List[str] = [], -) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]: +) -> Tuple[Index, Set[int], Set[int]]: """Create the variables related to the mapping between states and tokens from a regex string. The parameters of the function are used for caching purpose. @@ -143,7 +145,7 @@ def create_states_mapping_from_fsm( fsm: interegular.fsm.FSM, tokenizer, frozen_tokens: List[str] = [], -) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]: +) -> Tuple[Index, Set[int], Set[int]]: """Create the variables related to the mapping between states and tokens from an FSM. The parameters of the function are used for caching purpose. @@ -177,16 +179,6 @@ def create_states_mapping_from_fsm( regex_fsm, tokenizer ) - # We make sure that it is possible to generate strings in the language - # of the regular expression with the tokens present in the model's - # vocabulary. 
- if not any( - regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values() - ): - raise ValueError( - "The vocabulary does not allow us to build a sequence that matches the input regex" - ) - return states_to_token_maps, empty_token_ids, regex_fsm.finals @@ -196,18 +188,12 @@ class RegexGuide(Guide): initial_state = 0 def __init__( - self, - states_to_token_maps, - empty_token_ids, - fsm_finals, - eos_token_id, - states_to_token_mask, + self, states_to_token_maps, empty_token_ids, eos_tensor, initial_state ): self.states_to_token_maps = states_to_token_maps self.empty_token_ids = empty_token_ids - self.eos_token_id = eos_token_id - self.final_states = fsm_finals | {-1} - self.states_to_token_mask = states_to_token_mask + self.eos_tensor = eos_tensor + self.initial_state = initial_state @classmethod def from_regex( @@ -229,17 +215,9 @@ def from_regex( regex_parser=regex_parser, frozen_tokens=frozen_tokens, ) - states_to_token_mask = { - state: torch.tensor(list(next_tokens_to_end_states.keys()), device=device) - for state, next_tokens_to_end_states in states_to_token_maps.items() - } - return cls( - states_to_token_maps, - empty_token_ids, - fsm_finals, - tokenizer.eos_token_id, - states_to_token_mask, - ) + eos_tensor = torch.tensor([tokenizer.eos_token_id], device=device) + initial_state = states_to_token_maps.get_initial_state() + return cls(states_to_token_maps, empty_token_ids, eos_tensor, initial_state) @classmethod def from_interegular_fsm( @@ -257,17 +235,9 @@ def from_interegular_fsm( ) = _create_states_mapping_from_fsm( interegular_fsm, tokenizer, frozen_tokens=frozen_tokens ) - states_to_token_mask = { - state: torch.tensor(list(next_tokens_to_end_states.keys()), device=device) - for state, next_tokens_to_end_states in states_to_token_maps.items() - } - return cls( - states_to_token_maps, - empty_token_ids, - fsm_finals, - tokenizer.eos_token_id, - states_to_token_mask, - ) + eos_tensor = torch.tensor([tokenizer.eos_token_id], device=device) + initial_state = states_to_token_maps.get_initial_state() + return cls(states_to_token_maps, empty_token_ids, eos_tensor, initial_state) def get_next_instruction(self, state: int) -> Instruction: """Return the next instruction for guided generation. @@ -292,11 +262,14 @@ def get_next_instruction(self, state: int) -> Instruction: A `Generate` instance that contains the model and the allowed token ids. """ - next_tokens_mask = self.states_to_token_mask.get(state) + if state == -1: + return Write(self.eos_tensor) + next_tokens_mask = self.states_to_token_maps.get_allowed_tokens(state) + # TODO: Create the Write and Generate objects within Rust instead? if next_tokens_mask is None: - return Write(torch.tensor([self.eos_token_id])) + return Write(self.eos_tensor) - return Generate(next_tokens_mask) + return Generate(torch.tensor(next_tokens_mask)) def get_next_state(self, state: int, token_id: int) -> int: """Update the state of the guide. @@ -316,19 +289,21 @@ def get_next_state(self, state: int, token_id: int) -> int: The new state of the guide. 
""" - if token_id == self.eos_token_id or state not in self.states_to_token_maps: + if state == -1: return -1 - - last_token_to_end_state = self.states_to_token_maps[state] - next_state = last_token_to_end_state.get(token_id) + next_state = self.states_to_token_maps.get_next_state(state, token_id) if next_state is None: - next_state = -1 - - return next_state + return -1 + else: + return next_state def is_final_state(self, state: int) -> bool: """Determine whether the current state of the guide is a final state.""" - return state in self.final_states + return state == -1 or self.states_to_token_maps.is_final_state(state) def copy(self): return self + + def get_index_dict(self): + """Returns the Index as a Python Dict object.""" + return self.states_to_token_maps.get_index_dict() diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi index f2fb3de2..2bd6c82b 100644 --- a/python/outlines_core/fsm/outlines_core_rs.pyi +++ b/python/outlines_core/fsm/outlines_core_rs.pyi @@ -86,3 +86,20 @@ class Vocabulary: Gets the string representation of the vocabulary. """ ... + +class Index: + def get_allowed_tokens(self, state: int) -> Optional[List[int]]: + """Returns allowed tokens in this state.""" + ... + def get_next_state(self, state: int, token_id: int) -> Optional[int]: + """Updates the state.""" + ... + def is_final_state(self, state: int) -> bool: + """Determines whether the current state is a final state.""" + ... + def get_index_dict(self) -> Dict[int, Dict[int, int]]: + """Returns the Index as a Python Dict object.""" + ... + def get_initial_state(self) -> int: + """Returns the ID of the initial state of the input FSM automata.""" + ... diff --git a/python/outlines_core/fsm/regex.py b/python/outlines_core/fsm/regex.py index d0935c8f..af337e34 100644 --- a/python/outlines_core/fsm/regex.py +++ b/python/outlines_core/fsm/regex.py @@ -24,6 +24,7 @@ from .outlines_core_rs import ( # noqa: F401 FSMInfo, + Index, Vocabulary, _walk_fsm, create_fsm_index_end_to_end, @@ -438,7 +439,7 @@ def create_fsm_index_tokenizer( fsm: BetterFSM, tokenizer, frozen_tokens: Optional[Iterable[str]] = None, -) -> Tuple[Dict[int, Dict[int, int]], Set[int]]: +) -> Tuple[Index, Set[int]]: """Construct an FMS index from a tokenizer. This uses the end-to-end approach of `create_fsm_index_end_to_end`. @@ -469,18 +470,11 @@ def create_fsm_index_tokenizer( """ tokens_to_token_ids, empty_token_ids = reduced_vocabulary(tokenizer) - states_to_token_subsets = create_fsm_index_end_to_end( + states_to_token_subsets = Index( # type: ignore fsm.fsm_info, Vocabulary.from_dict(tokens_to_token_ids), + tokenizer.eos_token_id, frozenset(frozen_tokens) if frozen_tokens is not None else frozenset(), ) - # Allow transitions to EOS from all terminals FSM states that are - # reachable - # TODO: Do we really need this anymore? - for state in fsm.fsm_info.finals: - subset = states_to_token_subsets.get(state) - if subset is not None: - subset[tokenizer.eos_token_id] = state - return states_to_token_subsets, empty_token_ids diff --git a/src/index.rs b/src/index.rs new file mode 100644 index 00000000..727061c1 --- /dev/null +++ b/src/index.rs @@ -0,0 +1,132 @@ +/// Construct an Index. 
diff --git a/src/index.rs b/src/index.rs
new file mode 100644
index 00000000..727061c1
--- /dev/null
+++ b/src/index.rs
@@ -0,0 +1,132 @@
+/// Construct an Index.
+use crate::prelude::{State, TransitionKey};
+use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
+use crate::vocabulary::Vocabulary;
+use std::collections::{HashMap, HashSet};
+
+pub type Result<T> = std::result::Result<T, crate::Error>;
+
+#[derive(Debug)]
+pub struct FSMInfo {
+    pub(crate) initial: State,
+    pub(crate) finals: HashSet<State>,
+    pub(crate) transitions: HashMap<(State, TransitionKey), State>,
+    pub(crate) alphabet_anything_value: TransitionKey,
+    pub(crate) alphabet_symbol_mapping: HashMap<String, TransitionKey>,
+}
+
+impl FSMInfo {
+    pub fn new(
+        initial: State,
+        finals: HashSet<State>,
+        transitions: HashMap<(State, TransitionKey), State>,
+        alphabet_anything_value: TransitionKey,
+        alphabet_symbol_mapping: HashMap<String, TransitionKey>,
+    ) -> Self {
+        Self {
+            initial,
+            finals,
+            transitions,
+            alphabet_anything_value,
+            alphabet_symbol_mapping,
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct Index {
+    pub(crate) initial: u32,
+    finals: HashSet<u32>,
+    states_to_token_subsets: HashMap<u32, HashMap<u32, u32>>,
+    eos_token_id: u32,
+}
+
+impl Index {
+    pub fn new(
+        fsm_info: &FSMInfo,
+        vocabulary: &Vocabulary,
+        eos_token_id: u32,
+        frozen_tokens: HashSet<String>,
+    ) -> Result<Self> {
+        let mut states_to_token_subsets: HashMap<u32, HashMap<u32, u32>> = HashMap::new();
+        let mut seen: HashSet<State> = HashSet::new();
+        let mut next_states: HashSet<State> = HashSet::from([fsm_info.initial]);
+
+        let vocabulary_transition_keys = get_vocabulary_transition_keys(
+            &fsm_info.alphabet_symbol_mapping,
+            fsm_info.alphabet_anything_value,
+            vocabulary,
+            &frozen_tokens,
+        );
+
+        while let Some(start_state) = next_states.iter().cloned().next() {
+            next_states.remove(&start_state);
+
+            let token_ids_end_states = state_scan_tokens(
+                &fsm_info.transitions,
+                fsm_info.initial,
+                &fsm_info.finals,
+                vocabulary,
+                &vocabulary_transition_keys,
+                start_state,
+            );
+
+            for (token_id, end_state) in &token_ids_end_states {
+                let inner_map = states_to_token_subsets.entry(start_state).or_default();
+                inner_map.insert(*token_id, *end_state);
+
+                if !seen.contains(end_state) {
+                    next_states.insert(*end_state);
+                }
+            }
+
+            if fsm_info.finals.contains(&start_state) && !token_ids_end_states.is_empty() {
+                let inner_map = states_to_token_subsets.entry(start_state).or_default();
+                inner_map.insert(eos_token_id, start_state);
+            }
+
+            seen.insert(start_state);
+        }
+
+        let is_valid = states_to_token_subsets
+            .values()
+            .flat_map(|token_id_end_states| token_id_end_states.values())
+            .any(|end_state| fsm_info.finals.contains(end_state));
+
+        if is_valid {
+            Ok(Self {
+                initial: fsm_info.initial,
+                finals: fsm_info.finals.clone(),
+                states_to_token_subsets,
+                eos_token_id,
+            })
+        } else {
+            Err(crate::Error::IndexError)
+        }
+    }
+
+    pub(crate) fn allowed_tokens(&self, state: u32) -> Option<Vec<u32>> {
+        self.states_to_token_subsets
+            .get(&state)
+            .map_or_else(|| None, |res| Some(res.keys().cloned().collect()))
+    }
+
+    pub(crate) fn next_state(&self, state: u32, token_id: u32) -> Option<u32> {
+        if token_id == self.eos_token_id {
+            return None;
+        }
+        Some(*self.states_to_token_subsets.get(&state)?.get(&token_id)?)
+ } + + pub(crate) fn initial(&self) -> u32 { + self.initial + } + + pub(crate) fn is_final(&self, state: u32) -> bool { + self.finals.contains(&state) + } + + pub(crate) fn index(&self) -> &HashMap> { + &self.states_to_token_subsets + } +} diff --git a/src/lib.rs b/src/lib.rs index 43a1f05e..71787e2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,25 @@ +pub mod index; pub mod json_schema; +pub mod prelude; +pub mod primitives; pub mod regex; +pub mod vocabulary; #[cfg(feature = "python-bindings")] mod python_bindings; -pub mod prelude; +use thiserror::Error; -pub mod primitives; -pub mod vocabulary; +#[derive(Error, Debug)] +pub enum Error { + #[error("The vocabulary does not allow us to build a sequence that matches the input")] + IndexError, +} + +#[cfg(feature = "python-bindings")] +impl From for pyo3::PyErr { + fn from(e: Error) -> Self { + use pyo3::{exceptions::PyValueError, PyErr}; + PyErr::new::(e.to_string()) + } +} diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 214d206d..368e8ced 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -1,3 +1,4 @@ +use crate::index::{FSMInfo, Index}; use crate::json_schema; use crate::prelude::*; use crate::regex::get_token_transition_keys; @@ -11,8 +12,8 @@ use pyo3::wrap_pyfunction; use serde_json::Value; use std::collections::{HashMap, HashSet}; -#[pyclass] -pub struct FSMInfo { +#[pyclass(name = "FSMInfo")] +pub struct PyFSMInfo { #[pyo3(get)] initial: State, #[pyo3(get)] @@ -25,8 +26,33 @@ pub struct FSMInfo { alphabet_symbol_mapping: HashMap, } +impl From for PyFSMInfo { + fn from(fsm_info: FSMInfo) -> Self { + PyFSMInfo { + initial: fsm_info.initial, + finals: fsm_info.finals, + transitions: fsm_info.transitions, + alphabet_anything_value: fsm_info.alphabet_anything_value, + alphabet_symbol_mapping: fsm_info.alphabet_symbol_mapping, + } + } +} + +// FIXME: could be costly, confirm if FSMInfo will actually be part of the interface +impl From<&PyFSMInfo> for FSMInfo { + fn from(fsm_info: &PyFSMInfo) -> Self { + FSMInfo { + initial: fsm_info.initial, + finals: fsm_info.finals.clone(), + transitions: fsm_info.transitions.clone(), + alphabet_anything_value: fsm_info.alphabet_anything_value, + alphabet_symbol_mapping: fsm_info.alphabet_symbol_mapping.clone(), + } + } +} + #[pymethods] -impl FSMInfo { +impl PyFSMInfo { #[new] fn new( initial: State, @@ -35,13 +61,52 @@ impl FSMInfo { alphabet_anything_value: TransitionKey, alphabet_symbol_mapping: HashMap, ) -> Self { - Self { + FSMInfo::new( initial, finals, transitions, alphabet_anything_value, alphabet_symbol_mapping, - } + ) + .into() + } +} + +#[pyclass(name = "Index")] +pub struct PyIndex(Index); + +#[pymethods] +impl PyIndex { + #[new] + fn new( + fsm_info: &PyFSMInfo, + vocabulary: &PyVocabulary, + eos_token_id: u32, + frozen_tokens: HashSet, + ) -> PyResult { + Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens) + .map(PyIndex) + .map_err(Into::into) + } + + fn get_allowed_tokens(&self, state: u32) -> Option> { + self.0.allowed_tokens(state) + } + + fn get_next_state(&self, state: u32, token_id: u32) -> Option { + self.0.next_state(state, token_id) + } + + fn is_final_state(&self, state: u32) -> bool { + self.0.is_final(state) + } + + fn get_index_dict(&self) -> HashMap> { + self.0.index().clone() + } + + fn get_initial_state(&self) -> u32 { + self.0.initial() } } @@ -143,7 +208,7 @@ pub fn get_vocabulary_transition_keys_py( #[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")] pub fn 
diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs
index 214d206d..368e8ced 100644
--- a/src/python_bindings/mod.rs
+++ b/src/python_bindings/mod.rs
@@ -1,3 +1,4 @@
+use crate::index::{FSMInfo, Index};
 use crate::json_schema;
 use crate::prelude::*;
 use crate::regex::get_token_transition_keys;
@@ -11,8 +12,8 @@
 use pyo3::wrap_pyfunction;
 use serde_json::Value;
 use std::collections::{HashMap, HashSet};
 
-#[pyclass]
-pub struct FSMInfo {
+#[pyclass(name = "FSMInfo")]
+pub struct PyFSMInfo {
     #[pyo3(get)]
     initial: State,
     #[pyo3(get)]
     finals: HashSet<State>,
     #[pyo3(get)]
     transitions: HashMap<(State, TransitionKey), State>,
     #[pyo3(get)]
     alphabet_anything_value: TransitionKey,
     #[pyo3(get)]
     alphabet_symbol_mapping: HashMap<String, TransitionKey>,
 }
 
+impl From<FSMInfo> for PyFSMInfo {
+    fn from(fsm_info: FSMInfo) -> Self {
+        PyFSMInfo {
+            initial: fsm_info.initial,
+            finals: fsm_info.finals,
+            transitions: fsm_info.transitions,
+            alphabet_anything_value: fsm_info.alphabet_anything_value,
+            alphabet_symbol_mapping: fsm_info.alphabet_symbol_mapping,
+        }
+    }
+}
+
+// FIXME: could be costly, confirm if FSMInfo will actually be part of the interface
+impl From<&PyFSMInfo> for FSMInfo {
+    fn from(fsm_info: &PyFSMInfo) -> Self {
+        FSMInfo {
+            initial: fsm_info.initial,
+            finals: fsm_info.finals.clone(),
+            transitions: fsm_info.transitions.clone(),
+            alphabet_anything_value: fsm_info.alphabet_anything_value,
+            alphabet_symbol_mapping: fsm_info.alphabet_symbol_mapping.clone(),
+        }
+    }
+}
+
 #[pymethods]
-impl FSMInfo {
+impl PyFSMInfo {
     #[new]
     fn new(
         initial: State,
         finals: HashSet<State>,
         transitions: HashMap<(State, TransitionKey), State>,
         alphabet_anything_value: TransitionKey,
         alphabet_symbol_mapping: HashMap<String, TransitionKey>,
     ) -> Self {
-        Self {
+        FSMInfo::new(
             initial,
             finals,
             transitions,
             alphabet_anything_value,
             alphabet_symbol_mapping,
-        }
+        )
+        .into()
     }
 }
+
+#[pyclass(name = "Index")]
+pub struct PyIndex(Index);
+
+#[pymethods]
+impl PyIndex {
+    #[new]
+    fn new(
+        fsm_info: &PyFSMInfo,
+        vocabulary: &PyVocabulary,
+        eos_token_id: u32,
+        frozen_tokens: HashSet<String>,
+    ) -> PyResult<Self> {
+        Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens)
+            .map(PyIndex)
+            .map_err(Into::into)
+    }
+
+    fn get_allowed_tokens(&self, state: u32) -> Option<Vec<u32>> {
+        self.0.allowed_tokens(state)
+    }
+
+    fn get_next_state(&self, state: u32, token_id: u32) -> Option<u32> {
+        self.0.next_state(state, token_id)
+    }
+
+    fn is_final_state(&self, state: u32) -> bool {
+        self.0.is_final(state)
+    }
+
+    fn get_index_dict(&self) -> HashMap<u32, HashMap<u32, u32>> {
+        self.0.index().clone()
+    }
+
+    fn get_initial_state(&self) -> u32 {
+        self.0.initial()
+    }
+}
@@ -143,7 +208,7 @@
 #[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")]
 pub fn create_fsm_index_end_to_end_py<'py>(
     py: Python<'py>,
-    fsm_info: &FSMInfo,
+    fsm_info: &PyFSMInfo,
     vocabulary: &PyVocabulary,
     frozen_tokens: HashSet<String>,
 ) -> PyResult<Bound<'py, PyDict>> {
@@ -218,8 +283,6 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys_py, m)?)?;
     m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end_py, m)?)?;
 
-    m.add_class::<FSMInfo>()?;
-
     m.add("BOOLEAN", json_schema::BOOLEAN)?;
     m.add("DATE", json_schema::DATE)?;
     m.add("DATE_TIME", json_schema::DATE_TIME)?;
@@ -235,7 +298,9 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(build_regex_from_schema_py, m)?)?;
     m.add_function(wrap_pyfunction!(to_regex_py, m)?)?;
 
+    m.add_class::<PyFSMInfo>()?;
     m.add_class::<PyVocabulary>()?;
+    m.add_class::<PyIndex>()?;
 
     Ok(())
 }
diff --git a/src/vocabulary.rs b/src/vocabulary.rs
index 49579e0f..f03df8f7 100644
--- a/src/vocabulary.rs
+++ b/src/vocabulary.rs
@@ -14,7 +14,7 @@ use crate::prelude::*;
 /// .insert("0", 3);
 /// ```
 #[derive(Clone, Debug, Default)]
-pub struct Vocabulary(HashMap<Token, Vec<TokenId>>);
+pub struct Vocabulary(pub(crate) HashMap<Token, Vec<TokenId>>);
 
 impl Vocabulary {
     /// Creates an empty vocabulary.
diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py
index 174063b0..905bfded 100644
--- a/tests/fsm/test_guide.py
+++ b/tests/fsm/test_guide.py
@@ -59,7 +59,15 @@ def convert_token_to_string(self, token):
     tokenizer = MockTokenizer()
     fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
-    assert fsm.states_to_token_maps == {0: {1: 1}}
+    assert fsm.get_index_dict() == {0: {1: 1}}
+
+    instruction = fsm.get_next_instruction(-1)
+    assert isinstance(instruction, Write)
+    assert_expected_tensor_ids(instruction.tokens, [3])
+
+    instruction = fsm.get_next_instruction(3)
+    assert isinstance(instruction, Write)
+    assert_expected_tensor_ids(instruction.tokens, [3])
 
     instruction = fsm.get_next_instruction(0)
     assert isinstance(instruction, Generate)
@@ -70,9 +78,6 @@ def convert_token_to_string(self, token):
 
     assert fsm.is_final_state(0) is False
 
-    for state in fsm.final_states:
-        assert fsm.is_final_state(state) is True
-
 
 def test_from_fsm():
     class MockTokenizer:
@@ -89,7 +94,7 @@ def convert_token_to_string(self, token):
         interegular.parse_pattern(regex_str).to_fsm(), tokenizer
     )
 
-    assert fsm.states_to_token_maps == {0: {1: 1}}
+    assert fsm.get_index_dict() == {0: {1: 1}}
 
     instruction = fsm.get_next_instruction(0)
     assert isinstance(instruction, Generate)
@@ -100,9 +105,6 @@ def convert_token_to_string(self, token):
 
     assert fsm.is_final_state(0) is False
 
-    for state in fsm.final_states:
-        assert fsm.is_final_state(state) is True
-
 
 def test_regex_multi_byte_llama_like():
     class MockTokenizer:
@@ -130,7 +132,7 @@ def convert_token_to_string(self, token):
     tokenizer = MockTokenizer()
     fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
-    assert fsm.states_to_token_maps == {
+    assert fsm.get_index_dict() == {
         0: {5: 1, 4: 2},
         1: {6: 3},
         3: {7: 4},
@@ -146,9 +148,6 @@ def convert_token_to_string(self, token):
 
     assert fsm.is_final_state(0) is False
 
-    for state in fsm.final_states:
-        assert fsm.is_final_state(state) is True
-
 
 def test_regex_multi_byte_gpt2_like():
     class MockTokenizer:
@@ -177,7 +176,7 @@ def convert_token_to_string(self, token):
     tokenizer = MockTokenizer()
     fsm = RegexGuide.from_regex(regex_str, tokenizer)
 
-    assert fsm.states_to_token_maps == {
+    assert fsm.get_index_dict() == {
         0: {5: 1, 10: 2},
         1: {8: 5, 4: 3},
         2: {11: 3},
@@ -193,9 +192,6 @@ def convert_token_to_string(self, token):
 
     assert fsm.is_final_state(0) is False
 
-    for state in fsm.final_states:
-        assert fsm.is_final_state(state) is True
-
 
 def test_regex_final_state():
     """Make sure that the FSM stays in the final state as we keep generating"""
diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py
index 15b7ed05..cdac64d4 100644
--- a/tests/fsm/test_regex.py
+++ b/tests/fsm/test_regex.py
@@ -372,7 +372,7 @@ def test_create_fsm_index_tokenizer(hf_tokenizer_uri, revision):
     )
 
     assert not empty_token_ids
-    assert len(states_to_token_subsets) / num_fsm_states > 0.94
+    assert len(states_to_token_subsets.get_index_dict()) / num_fsm_states > 0.94
 
 
 @pytest.mark.parametrize(