Added PyVocabIndex: PyO3 bindings to an index object

kc611 committed Oct 21, 2024
1 parent b3b9658 · commit c8e03cc

Showing 6 changed files with 172 additions and 84 deletions.
85 changes: 28 additions & 57 deletions python/outlines_core/fsm/guide.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
from typing import Any, Callable, List, Optional, Protocol, Set, Tuple, Union

import interegular
import torch
@@ -9,6 +9,8 @@
make_deterministic_fsm,
)

from .outlines_core_rs import PyVocabIndex


@dataclass(frozen=True)
class Write:
@@ -107,7 +109,7 @@ def create_states_mapping(
tokenizer,
regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern,
frozen_tokens: List[str] = [],
) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
) -> Tuple[PyVocabIndex, Set[int], Set[int]]:
"""Create the variables related to the mapping between states and tokens from a regex string.
The parameters of the function are used for caching purposes.
@@ -143,7 +145,7 @@ def create_states_mapping_from_fsm(
fsm: interegular.fsm.FSM,
tokenizer,
frozen_tokens: List[str] = [],
) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
) -> Tuple[PyVocabIndex, Set[int], Set[int]]:
"""Create the variables related to the mapping between states and tokens from an FSM.
The parameters of the function are used for caching purposes.
@@ -177,16 +179,6 @@ def create_states_mapping_from_fsm(
regex_fsm, tokenizer
)

# We make sure that it is possible to generate strings in the language
# of the regular expression with the tokens present in the model's
# vocabulary.
if not any(
regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values()
):
raise ValueError(
"The vocabulary does not allow us to build a sequence that matches the input regex"
)

return states_to_token_maps, empty_token_ids, regex_fsm.finals


@@ -196,18 +188,12 @@ class RegexGuide(Guide):
initial_state = 0

def __init__(
self,
states_to_token_maps,
empty_token_ids,
fsm_finals,
eos_token_id,
states_to_token_mask,
self, states_to_token_maps, empty_token_ids, eos_tensor, initial_state
):
self.states_to_token_maps = states_to_token_maps
self.empty_token_ids = empty_token_ids
self.eos_token_id = eos_token_id
self.final_states = fsm_finals | {-1}
self.states_to_token_mask = states_to_token_mask
self.eos_tensor = eos_tensor
self.initial_state = initial_state

@classmethod
def from_regex(
@@ -229,17 +215,9 @@ def from_regex(
regex_parser=regex_parser,
frozen_tokens=frozen_tokens,
)
states_to_token_mask = {
state: torch.tensor(list(next_tokens_to_end_states.keys()), device=device)
for state, next_tokens_to_end_states in states_to_token_maps.items()
}
return cls(
states_to_token_maps,
empty_token_ids,
fsm_finals,
tokenizer.eos_token_id,
states_to_token_mask,
)
eos_tensor = torch.tensor([tokenizer.eos_token_id], device=device)
initial_state = states_to_token_maps.get_initial_state()
return cls(states_to_token_maps, empty_token_ids, eos_tensor, initial_state)

@classmethod
def from_interegular_fsm(
@@ -257,17 +235,9 @@ def from_interegular_fsm(
) = _create_states_mapping_from_fsm(
interegular_fsm, tokenizer, frozen_tokens=frozen_tokens
)
states_to_token_mask = {
state: torch.tensor(list(next_tokens_to_end_states.keys()), device=device)
for state, next_tokens_to_end_states in states_to_token_maps.items()
}
return cls(
states_to_token_maps,
empty_token_ids,
fsm_finals,
tokenizer.eos_token_id,
states_to_token_mask,
)
eos_tensor = torch.tensor([tokenizer.eos_token_id], device=device)
initial_state = states_to_token_maps.get_initial_state()
return cls(states_to_token_maps, empty_token_ids, eos_tensor, initial_state)

def get_next_instruction(self, state: int) -> Instruction:
"""Return the next instruction for guided generation.
@@ -292,11 +262,14 @@ def get_next_instruction(self, state: int) -> Instruction:
A `Generate` instance that contains the model and the allowed token ids.
"""
next_tokens_mask = self.states_to_token_mask.get(state)
if next_tokens_mask is None:
return Write(torch.tensor([self.eos_token_id]))
if state == -1:
return Write(self.eos_tensor)
next_tokens_mask = self.states_to_token_maps.get_next_instruction(state)
# TODO: Create the Write and Generate objects within Rust instead?
if len(next_tokens_mask) == 0:
return Write(self.eos_tensor)

return Generate(next_tokens_mask)
return Generate(torch.tensor(next_tokens_mask))

def get_next_state(self, state: int, token_id: int) -> int:
"""Update the state of the guide.
@@ -316,19 +289,17 @@ def get_next_state(self, state: int, token_id: int) -> int:
The new state of the guide.
"""
if token_id == self.eos_token_id or state not in self.states_to_token_maps:
if state == -1:
return -1

last_token_to_end_state = self.states_to_token_maps[state]
next_state = last_token_to_end_state.get(token_id)
if next_state is None:
next_state = -1

return next_state
return self.states_to_token_maps.get_next_state(state, token_id)

def is_final_state(self, state: int) -> bool:
"""Determine whether the current state of the guide is a final state."""
return state in self.final_states
return state == -1 or self.states_to_token_maps.is_final_state(state)

def copy(self):
return self

def get_index_dict(self):
"""Returns the Index as a Python Dict object."""
return self.states_to_token_maps.get_index_dict()
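With state bookkeeping delegated to the Rust index, the guide now keeps only an index handle, an EOS tensor, and the -1 dead-state sentinel. Below is a hedged sketch of the resulting generation loop; `tokenizer` is assumed to satisfy the outlines_core tokenizer interface, and the `tokens` attribute on `Write`/`Generate` follows the dataclass definitions at the top of this file:

```python
# Hedged usage sketch (not part of the commit).
from outlines_core.fsm.guide import RegexGuide, Write

guide = RegexGuide.from_regex(r"[0-9]{3}", tokenizer)  # `tokenizer` assumed
state = guide.initial_state

while not guide.is_final_state(state):
    instruction = guide.get_next_instruction(state)
    if isinstance(instruction, Write):   # only EOS remains to be emitted
        break
    allowed = instruction.tokens          # tensor of allowed token ids
    token_id = int(allowed[0])            # a real decoder would sample here
    state = guide.get_next_state(state, token_id)  # -1 marks a dead end
```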
21 changes: 21 additions & 0 deletions python/outlines_core/fsm/outlines_core_rs.pyi
@@ -86,3 +86,24 @@ class Vocabulary:
Gets the string representation of the vocabulary.
"""
...

class PyVocabIndex:
def get_next_instruction(self, state: int):
"""
Return the next instruction for guided generation.
"""
...
def get_next_state(self, state: int, token_id: int):
"""
Update the state of the guide.
"""
...
def is_final_state(self, state: int):
"""Determine whether the current state of the guide is a final state."""
...
def get_index_dict(self):
"""Returns the Index as a Python Dict object."""
...
def get_initial_state(self):
"""Returns the ID of the initial state of the input FSM automata."""
...
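The stubs above mirror the methods exported from src/python_bindings/mod.rs below. A hedged sketch of exercising the interface directly, where `index` stands in for a PyVocabIndex produced by create_fsm_index_tokenizer (a hypothetical instance, not part of this commit):

```python
# `index` is assumed to be an existing PyVocabIndex (hypothetical).
initial = index.get_initial_state()
allowed = index.get_next_instruction(initial)    # plain list of token ids
nxt = index.get_next_state(initial, allowed[0])  # -1 when no transition exists
done = index.is_final_state(nxt)                 # True once a match can end
table = index.get_index_dict()                   # {state: {token_id: next_state}}
```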
14 changes: 4 additions & 10 deletions python/outlines_core/fsm/regex.py
@@ -24,6 +24,7 @@

from .outlines_core_rs import ( # noqa: F401
FSMInfo,
PyVocabIndex,
Vocabulary,
_walk_fsm,
create_fsm_index_end_to_end,
@@ -438,7 +439,7 @@ def create_fsm_index_tokenizer(
fsm: BetterFSM,
tokenizer,
frozen_tokens: Optional[Iterable[str]] = None,
) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
) -> Tuple[PyVocabIndex, Set[int]]:
"""Construct an FMS index from a tokenizer.
This uses the end-to-end approach of `create_fsm_index_end_to_end`.
@@ -469,18 +470,11 @@
"""
tokens_to_token_ids, empty_token_ids = reduced_vocabulary(tokenizer)

states_to_token_subsets = create_fsm_index_end_to_end(
states_to_token_subsets = PyVocabIndex( # type: ignore
fsm.fsm_info,
Vocabulary.from_dict(tokens_to_token_ids),
tokenizer.eos_token_id,
frozenset(frozen_tokens) if frozen_tokens is not None else frozenset(),
)

# Allow transitions to EOS from all terminals FSM states that are
# reachable
# TODO: Do we really need this anymore?
for state in fsm.fsm_info.finals:
subset = states_to_token_subsets.get(state)
if subset is not None:
subset[tokenizer.eos_token_id] = state

return states_to_token_subsets, empty_token_ids
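With the EOS back-patching loop removed, create_fsm_index_tokenizer reduces to a single PyVocabIndex constructor call; EOS is now special-cased in RegexGuide, which emits Write(self.eos_tensor) for final or dead states instead of splicing EOS into the transition table. A hedged construction sketch, assuming a `tokenizer` compatible with reduced_vocabulary:

```python
# Hedged sketch (not part of the commit); `tokenizer` is an assumption.
import interegular

from outlines_core.fsm.regex import (
    create_fsm_index_tokenizer,
    make_deterministic_fsm,
)

fsm, _ = make_deterministic_fsm(
    interegular.parse_pattern(r"ab?c").to_fsm().reduce()
)
index, empty_token_ids = create_fsm_index_tokenizer(fsm, tokenizer)
print(index.get_initial_state(), sorted(index.get_index_dict()))
```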
114 changes: 114 additions & 0 deletions src/python_bindings/mod.rs
@@ -45,6 +45,119 @@ impl FSMInfo {
}
}

#[pyclass]
pub struct PyVocabIndex {
initial: u32,
finals: HashSet<u32>,
states_to_token_subsets: HashMap<u32, HashMap<u32, u32>>,
#[allow(dead_code)]
eos_token_id: u32,
}

#[pymethods]
impl PyVocabIndex {
#[new]
fn new(
fsm_info: &FSMInfo,
vocabulary: &PyVocabulary,
eos_token_id: u32,
frozen_tokens: HashSet<String>,
) -> PyResult<Self> {
let mut states_to_token_subsets: HashMap<u32, HashMap<u32, u32>> = HashMap::new();
let mut seen: HashSet<State> = HashSet::new();
let mut next_states: HashSet<State> = HashSet::from_iter(vec![fsm_info.initial]);

let vocabulary_transition_keys = get_vocabulary_transition_keys(
&fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
&vocabulary.0,
&frozen_tokens,
);

while let Some(start_state) = next_states.iter().cloned().next() {
next_states.remove(&start_state);

// TODO: Return Pydict directly at construction
let token_ids_end_states = state_scan_tokens(
&fsm_info.transitions,
fsm_info.initial,
&fsm_info.finals,
&vocabulary.0,
&vocabulary_transition_keys,
start_state,
);

for (token_id, end_state) in token_ids_end_states {
let inner_map = states_to_token_subsets.entry(start_state).or_default();
inner_map.insert(token_id, end_state);

if !seen.contains(&end_state) {
next_states.insert(end_state);
}
}

seen.insert(start_state);
}

let mut is_valid = false;
for token_id_end_states in states_to_token_subsets.values() {
for end_state in token_id_end_states.values() {
if fsm_info.finals.contains(end_state) {
is_valid = true;
break;
}
}
if is_valid {
break;
}
}

if is_valid {
Ok(Self {
initial: fsm_info.initial,
finals: fsm_info.finals.clone(),
states_to_token_subsets,
eos_token_id,
})
} else {
Err(PyErr::new::<PyValueError, _>(
"The vocabulary does not allow us to build a sequence that matches the input",
))
}
}

fn get_next_instruction(&mut self, state: u32) -> Vec<u32> {
let default = HashMap::new();
let res = self.states_to_token_subsets.get(&state).unwrap_or(&default);
res.keys().cloned().collect()
}

fn get_next_state(&mut self, state: u32, token_id: u32) -> i32 {
let res = if let Some(token_id_end_states) = self.states_to_token_subsets.get(&state) {
if let Some(&end_state) = token_id_end_states.get(&token_id) {
end_state.try_into().unwrap()
} else {
-1
}
} else {
-1
};
res
}

fn is_final_state(&mut self, state: u32) -> bool {
self.finals.contains(&state)
}

fn get_index_dict(&mut self) -> HashMap<u32, HashMap<u32, u32>> {
self.states_to_token_subsets.clone()
}

fn get_initial_state(&mut self) -> u32 {
self.initial
}
}

#[pyfunction(name = "build_regex_from_schema")]
#[pyo3(signature = (json, whitespace_pattern=None))]
pub fn build_regex_from_schema_py(
@@ -236,6 +349,7 @@ fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(to_regex_py, m)?)?;

m.add_class::<PyVocabulary>()?;
m.add_class::<PyVocabIndex>()?;

Ok(())
}
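For orientation, PyVocabIndex::new above is a breadth-first walk over the FSM states reachable through vocabulary tokens, followed by a validity check that some transition reaches a final state. A hedged pure-Python rendering of the same idea; `scan` stands in for the Rust state_scan_tokens helper and is an assumption of this sketch:

```python
from typing import Callable, Dict, Iterable, Set, Tuple

def build_index(
    initial: int,
    finals: Set[int],
    scan: Callable[[int], Iterable[Tuple[int, int]]],
) -> Dict[int, Dict[int, int]]:
    """BFS over FSM states. scan(state) yields (token_id, end_state)
    pairs, playing the role of the Rust state_scan_tokens helper."""
    index: Dict[int, Dict[int, int]] = {}
    seen: Set[int] = set()
    frontier = {initial}
    while frontier:
        state = frontier.pop()
        for token_id, end_state in scan(state):
            index.setdefault(state, {})[token_id] = end_state
            if end_state not in seen:
                frontier.add(end_state)
        seen.add(state)
    # Mirror the constructor's check: some reachable transition must land
    # on a final state, otherwise no matching sequence can be generated.
    if not any(s in finals for m in index.values() for s in m.values()):
        raise ValueError("The vocabulary cannot produce a matching sequence")
    return index

# Toy run: chain 0 -> 1 -> 2 where token 7 advances one state.
print(build_index(0, {2}, lambda s: [(7, s + 1)] if s < 2 else []))
```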