Skip to content

Commit

Permalink
Test Index from regex in Guide
Browse files Browse the repository at this point in the history
  • Loading branch information
torymur committed Dec 12, 2024
1 parent 97c598e commit ead67fb
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 3 deletions.
16 changes: 13 additions & 3 deletions python/outlines_core/fsm/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
create_fsm_index_tokenizer,
make_byte_level_fsm,
make_deterministic_fsm,
reduced_vocabulary,
)

from .outlines_core_rs import Index
from .outlines_core_rs import Index, Vocabulary


@dataclass(frozen=True)
Expand Down Expand Up @@ -137,8 +138,17 @@ def create_states_mapping(
final_states:
A set of final states in the FSM.
"""
regex_fsm = regex_parser(regex_string).to_fsm()
return create_states_mapping_from_fsm(regex_fsm, tokenizer, frozen_tokens)
# regex_fsm = regex_parser(regex_string).to_fsm()
# return create_states_mapping_from_fsm(regex_fsm, tokenizer, frozen_tokens)

# inlining logic of create_fsm_index_tokenizer
tokens_to_token_ids, empty_token_ids = reduced_vocabulary(tokenizer)
vocabulary = Vocabulary.from_dict_with_eos_token_id(
tokens_to_token_ids, tokenizer.eos_token_id
)
index = Index.from_regex(regex_string, vocabulary)

return index, empty_token_ids, set(index.final_states())


def create_states_mapping_from_fsm(
Expand Down
17 changes: 17 additions & 0 deletions python/outlines_core/fsm/outlines_core_rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ class Vocabulary:
Creates a vocabulary from a dictionary of tokens to token IDs.
"""
...
@staticmethod
def from_dict_with_eos_token_id(
map: Dict[str, List[int]], eos_token_id: int
) -> "Vocabulary":
"""
Creates a vocabulary from a dictionary mapping each token string to its
list of token IDs, and sets `eos_token_id` as the vocabulary's
end-of-sequence token ID.
"""
...
def __repr__(self) -> str:
"""
Gets the debug string representation of the vocabulary.
Expand All @@ -88,6 +96,12 @@ class Vocabulary:
...

class Index:
@staticmethod
def from_regex(regex: str, vocabulary: "Vocabulary") -> "Index":
"""
Creates an index from a regex pattern string and a vocabulary.

Raises an exception if the index cannot be built from the given inputs.
"""
...
def get_allowed_tokens(self, state: int) -> Optional[List[int]]:
"""Returns the token IDs allowed in `state`, or None (presumably when the index has no entry for `state` -- confirm against the Rust binding)."""
...
Expand All @@ -97,6 +111,9 @@ class Index:
def is_final_state(self, state: int) -> bool:
"""Returns True if `state` is a final state of this index."""
...
def final_states(self) -> set[int]:
    """Returns all final states of this index as a set of state IDs.

    The native binding hands back a hash set, which arrives in Python as a
    `set`, not a `list`; the result is therefore unordered and cannot be
    indexed.
    """
    ...
def get_index_dict(self) -> Dict[int, Dict[int, int]]:
"""Returns the Index as a nested Python dict of ints (presumably state -> {token ID -> next state}; confirm against the Rust binding)."""
...
Expand Down
22 changes: 22 additions & 0 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ impl PyIndex {
})
}

#[staticmethod]
fn from_regex(py: Python<'_>, regex: &str, vocabulary: &PyVocabulary) -> PyResult<Self> {
py.allow_threads(|| {
Index::from_regex(regex, &vocabulary.0)
.map(PyIndex)
.map_err(Into::into)
})
}

fn __reduce__(&self) -> PyResult<(PyObject, (Vec<u8>,))> {
Python::with_gil(|py| {
let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")?
Expand Down Expand Up @@ -126,6 +135,10 @@ impl PyIndex {
self.0.is_final(state)
}

/// Returns a copy of the wrapped index's set of final states.
///
/// The set is cloned so the caller receives an owned value that is
/// independent of the index.
fn final_states(&self) -> FxHashSet<State> {
    let states = self.0.final_states();
    states.clone()
}

/// Returns a copy of the wrapped index's transition table as a nested map.
///
/// The table is cloned so the caller receives an owned value that is
/// independent of the index.
fn get_transitions(&self) -> FxHashMap<u32, FxHashMap<u32, u32>> {
    let transitions = self.0.transitions();
    transitions.clone()
}
Expand Down Expand Up @@ -291,6 +304,15 @@ impl PyVocabulary {
PyVocabulary(Vocabulary::from(map))
}

/// Builds a `PyVocabulary` from a token -> token-IDs map and attaches the
/// given end-of-sequence token ID to it.
#[staticmethod]
fn from_dict_with_eos_token_id(
    map: FxHashMap<Token, Vec<TokenId>>,
    eos_token_id: TokenId,
) -> PyVocabulary {
    PyVocabulary(Vocabulary::from(map).with_eos_token_id(Some(eos_token_id)))
}

fn __repr__(&self) -> String {
format!("{:#?}", self.0)
}
Expand Down

0 comments on commit ead67fb

Please sign in to comment.