diff --git a/python/outlines_core/fsm/guide.py b/python/outlines_core/fsm/guide.py
index 18b4523c..bb7561a7 100644
--- a/python/outlines_core/fsm/guide.py
+++ b/python/outlines_core/fsm/guide.py
@@ -7,9 +7,10 @@
     create_fsm_index_tokenizer,
     make_byte_level_fsm,
     make_deterministic_fsm,
+    reduced_vocabulary,
 )
 
-from .outlines_core_rs import Index
+from .outlines_core_rs import Index, Vocabulary
 
 
 @dataclass(frozen=True)
@@ -137,8 +138,21 @@ def create_states_mapping(
     final_states:
         A set of final states in the FSM.
     """
-    regex_fsm = regex_parser(regex_string).to_fsm()
-    return create_states_mapping_from_fsm(regex_fsm, tokenizer, frozen_tokens)
+    if frozen_tokens:
+        # Index.from_regex is only given a regex and a vocabulary, so requests
+        # that rely on frozen tokens keep going through the FSM-based path.
+        regex_fsm = regex_parser(regex_string).to_fsm()
+        return create_states_mapping_from_fsm(regex_fsm, tokenizer, frozen_tokens)
+
+    # Inlined logic of create_fsm_index_tokenizer, with the index built
+    # straight from the regex string instead of from a Python-side FSM.
+    tokens_to_token_ids, empty_token_ids = reduced_vocabulary(tokenizer)
+    vocabulary = Vocabulary.from_dict_with_eos_token_id(
+        tokens_to_token_ids, tokenizer.eos_token_id
+    )
+    index = Index.from_regex(regex_string, vocabulary)
+
+    return index, empty_token_ids, set(index.final_states())
 
 
 def create_states_mapping_from_fsm(
diff --git a/python/outlines_core/fsm/outlines_core_rs.pyi b/python/outlines_core/fsm/outlines_core_rs.pyi
index 2bd6c82b..e43a595f 100644
--- a/python/outlines_core/fsm/outlines_core_rs.pyi
+++ b/python/outlines_core/fsm/outlines_core_rs.pyi
@@ -76,6 +76,14 @@ class Vocabulary:
         Creates a vocabulary from a dictionary of tokens to token IDs.
         """
         ...
+    @staticmethod
+    def from_dict_with_eos_token_id(
+        map: Dict[str, List[int]], eos_token_id: int
+    ) -> "Vocabulary":
+        """
+        Creates a vocabulary from a dictionary of tokens to token IDs and eos token id.
+        """
+        ...
     def __repr__(self) -> str:
         """
         Gets the debug string representation of the vocabulary.
@@ -88,6 +96,12 @@ class Vocabulary:
         ...
 
class Index: + @staticmethod + def from_regex(regex: str, vocabulary: "Vocabulary") -> "Index": + """ + Creates an index from a regex and vocabulary. + """ + ... def get_allowed_tokens(self, state: int) -> Optional[List[int]]: """Returns allowed tokens in this state.""" ... @@ -97,6 +111,9 @@ class Index: def is_final_state(self, state: int) -> bool: """Determines whether the current state is a final state.""" ... + def final_states(self) -> List[int]: + """Get all final states.""" + ... def get_index_dict(self) -> Dict[int, Dict[int, int]]: """Returns the Index as a Python Dict object.""" ... diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 84b8746d..944ca150 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -93,6 +93,15 @@ impl PyIndex { }) } + #[staticmethod] + fn from_regex(py: Python<'_>, regex: &str, vocabulary: &PyVocabulary) -> PyResult { + py.allow_threads(|| { + Index::from_regex(regex, &vocabulary.0) + .map(PyIndex) + .map_err(Into::into) + }) + } + fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> { Python::with_gil(|py| { let cls = PyModule::import_bound(py, "outlines_core.fsm.outlines_core_rs")? @@ -126,6 +135,10 @@ impl PyIndex { self.0.is_final(state) } + fn final_states(&self) -> FxHashSet { + self.0.final_states().clone() + } + fn get_transitions(&self) -> FxHashMap> { self.0.transitions().clone() } @@ -291,6 +304,15 @@ impl PyVocabulary { PyVocabulary(Vocabulary::from(map)) } + #[staticmethod] + fn from_dict_with_eos_token_id( + map: FxHashMap>, + eos_token_id: TokenId, + ) -> PyVocabulary { + let v = Vocabulary::from(map).with_eos_token_id(Some(eos_token_id)); + PyVocabulary(v) + } + fn __repr__(&self) -> String { format!("{:#?}", self.0) }