Add a Rust-first Index object
kc611 authored Oct 29, 2024
1 parent 9db9927 commit 34268d6
Showing 11 changed files with 290 additions and 95 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -44,7 +44,7 @@ jobs:
echo "::set-output name=id::$MATRIX_ID"
- name: Run tests
run: |
-pytest --cov=outlines_core
+pytest --cov=outlines_core -vv
env:
COVERAGE_FILE: .coverage.${{ steps.matrix-id.outputs.id }}
- name: Upload coverage data
1 change: 1 addition & 0 deletions Cargo.toml
@@ -8,6 +8,7 @@ repository = "https://github.com/dottxt-ai/outlines-core"

[dependencies]
anyhow = "1.0.86"
+thiserror = "1.0"
pyo3 = { version = "0.22.0", features = ["extension-module"], optional = true }
regex = "1.10.6"
serde-pyobject = "0.4.0"
85 changes: 30 additions & 55 deletions python/outlines_core/fsm/guide.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
+from typing import Any, Callable, List, Optional, Protocol, Set, Tuple, Union

import interegular
import torch
@@ -9,6 +9,8 @@
make_deterministic_fsm,
)

+from .outlines_core_rs import Index


@dataclass(frozen=True)
class Write:
@@ -107,7 +109,7 @@ def create_states_mapping(
tokenizer,
regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern,
frozen_tokens: List[str] = [],
-) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
+) -> Tuple[Index, Set[int], Set[int]]:
"""Create the variables related to the mapping between states and tokens from a regex string.
The parameters of the function are used for caching purposes.
@@ -143,7 +145,7 @@ def create_states_mapping_from_fsm(
fsm: interegular.fsm.FSM,
tokenizer,
frozen_tokens: List[str] = [],
-) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
+) -> Tuple[Index, Set[int], Set[int]]:
"""Create the variables related to the mapping between states and tokens from an FSM.
The parameters of the function are used for caching purposes.
@@ -177,16 +179,6 @@ def create_states_mapping_from_fsm(
regex_fsm, tokenizer
)

-# We make sure that it is possible to generate strings in the language
-# of the regular expression with the tokens present in the model's
-# vocabulary.
-if not any(
-regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values()
-):
-raise ValueError(
-"The vocabulary does not allow us to build a sequence that matches the input regex"
-)

return states_to_token_maps, empty_token_ids, regex_fsm.finals


@@ -196,18 +188,12 @@ class RegexGuide(Guide):
initial_state = 0

def __init__(
-self,
-states_to_token_maps,
-empty_token_ids,
-fsm_finals,
-eos_token_id,
-states_to_token_mask,
+self, states_to_token_maps, empty_token_ids, eos_tensor, initial_state
):
self.states_to_token_maps = states_to_token_maps
self.empty_token_ids = empty_token_ids
-self.eos_token_id = eos_token_id
-self.final_states = fsm_finals | {-1}
-self.states_to_token_mask = states_to_token_mask
+self.eos_tensor = eos_tensor
+self.initial_state = initial_state

@classmethod
def from_regex(
@@ -229,17 +215,9 @@ def from_regex(
regex_parser=regex_parser,
frozen_tokens=frozen_tokens,
)
-states_to_token_mask = {
-state: torch.tensor(list(next_tokens_to_end_states.keys()), device=device)
-for state, next_tokens_to_end_states in states_to_token_maps.items()
-}
-return cls(
-states_to_token_maps,
-empty_token_ids,
-fsm_finals,
-tokenizer.eos_token_id,
-states_to_token_mask,
-)
+eos_tensor = torch.tensor([tokenizer.eos_token_id], device=device)
+initial_state = states_to_token_maps.get_initial_state()
+return cls(states_to_token_maps, empty_token_ids, eos_tensor, initial_state)

@classmethod
def from_interegular_fsm(
@@ -257,17 +235,9 @@ def from_interegular_fsm(
) = _create_states_mapping_from_fsm(
interegular_fsm, tokenizer, frozen_tokens=frozen_tokens
)
-states_to_token_mask = {
-state: torch.tensor(list(next_tokens_to_end_states.keys()), device=device)
-for state, next_tokens_to_end_states in states_to_token_maps.items()
-}
-return cls(
-states_to_token_maps,
-empty_token_ids,
-fsm_finals,
-tokenizer.eos_token_id,
-states_to_token_mask,
-)
+eos_tensor = torch.tensor([tokenizer.eos_token_id], device=device)
+initial_state = states_to_token_maps.get_initial_state()
+return cls(states_to_token_maps, empty_token_ids, eos_tensor, initial_state)

def get_next_instruction(self, state: int) -> Instruction:
"""Return the next instruction for guided generation.
@@ -292,11 +262,14 @@ def get_next_instruction(self, state: int) -> Instruction:
A `Generate` instance that contains the model and the allowed token ids.
"""
-next_tokens_mask = self.states_to_token_mask.get(state)
+if state == -1:
+return Write(self.eos_tensor)
+next_tokens_mask = self.states_to_token_maps.get_allowed_tokens(state)
+# TODO: Create the Write and Generate objects within Rust instead?
if next_tokens_mask is None:
-return Write(torch.tensor([self.eos_token_id]))
+return Write(self.eos_tensor)

-return Generate(next_tokens_mask)
+return Generate(torch.tensor(next_tokens_mask))

def get_next_state(self, state: int, token_id: int) -> int:
"""Update the state of the guide.
@@ -316,19 +289,21 @@ def get_next_state(self, state: int, token_id: int) -> int:
The new state of the guide.
"""
-if token_id == self.eos_token_id or state not in self.states_to_token_maps:
+if state == -1:
return -1

-last_token_to_end_state = self.states_to_token_maps[state]
-next_state = last_token_to_end_state.get(token_id)
+next_state = self.states_to_token_maps.get_next_state(state, token_id)
if next_state is None:
-next_state = -1
-
-return next_state
+return -1
+else:
+return next_state

def is_final_state(self, state: int) -> bool:
"""Determine whether the current state of the guide is a final state."""
-return state in self.final_states
+return state == -1 or self.states_to_token_maps.is_final_state(state)

def copy(self):
return self

+def get_index_dict(self):
+"""Returns the Index as a Python Dict object."""
+return self.states_to_token_maps.get_index_dict()
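
Taken together, the guide now defers its state bookkeeping to the Rust-backed Index instead of Python dicts. A minimal usage sketch of the resulting API, assuming a `tokenizer` with an `eos_token_id` and that `Write`/`Generate` expose their token tensor as `.tokens` (both are assumptions, not part of this diff):

    guide = RegexGuide.from_regex(r"[0-9]+", tokenizer)
    state = guide.initial_state
    while not guide.is_final_state(state):
        instruction = guide.get_next_instruction(state)
        # A real caller would sample from the allowed tokens; we take the first.
        token_id = int(instruction.tokens[0])
        state = guide.get_next_state(state, token_id)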
17 changes: 17 additions & 0 deletions python/outlines_core/fsm/outlines_core_rs.pyi
@@ -86,3 +86,20 @@ class Vocabulary:
Gets the string representation of the vocabulary.
"""
...

class Index:
def get_allowed_tokens(self, state: int) -> Optional[List[int]]:
"""Returns allowed tokens in this state."""
...
def get_next_state(self, state: int, token_id: int) -> Optional[int]:
"""Updates the state."""
...
def is_final_state(self, state: int) -> bool:
"""Determines whether the current state is a final state."""
...
def get_index_dict(self) -> Dict[int, Dict[int, int]]:
"""Returns the Index as a Python Dict object."""
...
def get_initial_state(self) -> int:
"""Returns the ID of the initial state of the input FSM automata."""
...
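
For illustration, a short sketch of exercising these stubs directly, assuming `index` is an Index instance built elsewhere (e.g. by `create_fsm_index_tokenizer` in regex.py below):

    state = index.get_initial_state()
    allowed = index.get_allowed_tokens(state)   # Optional[List[int]]
    if allowed is not None:
        next_state = index.get_next_state(state, allowed[0])
        state = next_state if next_state is not None else -1
    assert isinstance(index.get_index_dict(), dict)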
14 changes: 4 additions & 10 deletions python/outlines_core/fsm/regex.py
@@ -24,6 +24,7 @@

from .outlines_core_rs import ( # noqa: F401
FSMInfo,
+Index,
Vocabulary,
_walk_fsm,
create_fsm_index_end_to_end,
@@ -438,7 +439,7 @@ def create_fsm_index_tokenizer(
fsm: BetterFSM,
tokenizer,
frozen_tokens: Optional[Iterable[str]] = None,
-) -> Tuple[Dict[int, Dict[int, int]], Set[int]]:
+) -> Tuple[Index, Set[int]]:
"""Construct an FMS index from a tokenizer.
This uses the end-to-end approach of `create_fsm_index_end_to_end`.
@@ -469,18 +470,11 @@
"""
tokens_to_token_ids, empty_token_ids = reduced_vocabulary(tokenizer)

-states_to_token_subsets = create_fsm_index_end_to_end(
+states_to_token_subsets = Index( # type: ignore
fsm.fsm_info,
Vocabulary.from_dict(tokens_to_token_ids),
+tokenizer.eos_token_id,
frozenset(frozen_tokens) if frozen_tokens is not None else frozenset(),
)

-# Allow transitions to EOS from all terminals FSM states that are
-# reachable
-# TODO: Do we really need this anymore?
-for state in fsm.fsm_info.finals:
-subset = states_to_token_subsets.get(state)
-if subset is not None:
-subset[tokenizer.eos_token_id] = state

return states_to_token_subsets, empty_token_ids
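
The entry point keeps its shape but now returns the Rust Index; a hedged sketch, assuming `fsm` is a BetterFSM and `tokenizer` exposes an `eos_token_id`:

    index, empty_token_ids = create_fsm_index_tokenizer(fsm, tokenizer)
    state = index.get_initial_state()       # replaces the old dict lookup
    transitions = index.get_index_dict()    # Dict[int, Dict[int, int]] view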
132 changes: 132 additions & 0 deletions src/index.rs
@@ -0,0 +1,132 @@
/// Construct an Index.
use crate::prelude::{State, TransitionKey};
use crate::regex::{get_vocabulary_transition_keys, state_scan_tokens};
use crate::vocabulary::Vocabulary;
use std::collections::{HashMap, HashSet};

pub type Result<T, E = crate::Error> = std::result::Result<T, E>;

#[derive(Debug)]
pub struct FSMInfo {
pub(crate) initial: State,
pub(crate) finals: HashSet<State>,
pub(crate) transitions: HashMap<(State, TransitionKey), State>,
pub(crate) alphabet_anything_value: TransitionKey,
pub(crate) alphabet_symbol_mapping: HashMap<String, TransitionKey>,
}

impl FSMInfo {
pub fn new(
initial: State,
finals: HashSet<State>,
transitions: HashMap<(State, TransitionKey), State>,
alphabet_anything_value: TransitionKey,
alphabet_symbol_mapping: HashMap<String, TransitionKey>,
) -> Self {
Self {
initial,
finals,
transitions,
alphabet_anything_value,
alphabet_symbol_mapping,
}
}
}

#[derive(Debug)]
pub struct Index {
pub(crate) initial: u32,
finals: HashSet<u32>,
states_to_token_subsets: HashMap<u32, HashMap<u32, u32>>,
eos_token_id: u32,
}

impl Index {
pub fn new(
fsm_info: &FSMInfo,
vocabulary: &Vocabulary,
eos_token_id: u32,
frozen_tokens: HashSet<String>,
) -> Result<Self> {
let mut states_to_token_subsets: HashMap<u32, HashMap<u32, u32>> = HashMap::new();
let mut seen: HashSet<State> = HashSet::new();
let mut next_states: HashSet<State> = HashSet::from([fsm_info.initial]);

let vocabulary_transition_keys = get_vocabulary_transition_keys(
&fsm_info.alphabet_symbol_mapping,
fsm_info.alphabet_anything_value,
vocabulary,
&frozen_tokens,
);

while let Some(start_state) = next_states.iter().cloned().next() {
next_states.remove(&start_state);

let token_ids_end_states = state_scan_tokens(
&fsm_info.transitions,
fsm_info.initial,
&fsm_info.finals,
vocabulary,
&vocabulary_transition_keys,
start_state,
);

for (token_id, end_state) in &token_ids_end_states {
let inner_map = states_to_token_subsets.entry(start_state).or_default();
inner_map.insert(*token_id, *end_state);

if !seen.contains(end_state) {
next_states.insert(*end_state);
}
}

if fsm_info.finals.contains(&start_state) && !token_ids_end_states.is_empty() {
let inner_map = states_to_token_subsets.entry(start_state).or_default();
inner_map.insert(eos_token_id, start_state);
}

seen.insert(start_state);
}

let is_valid = states_to_token_subsets
.values()
.flat_map(|token_id_end_states| token_id_end_states.values())
.any(|end_state| fsm_info.finals.contains(end_state));

if is_valid {
Ok(Self {
initial: fsm_info.initial,
finals: fsm_info.finals.clone(),
states_to_token_subsets,
eos_token_id,
})
} else {
Err(crate::Error::IndexError)
}
}

pub(crate) fn allowed_tokens(&self, state: u32) -> Option<Vec<u32>> {
self.states_to_token_subsets
.get(&state)
.map_or_else(|| None, |res| Some(res.keys().cloned().collect()))
}

pub(crate) fn next_state(&self, state: u32, token_id: u32) -> Option<u32> {
if token_id == self.eos_token_id {
return None;
}
Some(*self.states_to_token_subsets.get(&state)?.get(&token_id)?)
}

pub(crate) fn initial(&self) -> u32 {
self.initial
}

pub(crate) fn is_final(&self, state: u32) -> bool {
self.finals.contains(&state)
}

pub(crate) fn index(&self) -> &HashMap<u32, HashMap<u32, u32>> {
&self.states_to_token_subsets
}
}
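
In outline, `Index::new` breadth-first-scans the FSM from its initial state, records a token-id to next-state map for each reachable state, gives every final state that accepts at least one token an EOS self-loop, and rejects the index if no transition ever reaches a final state. A Python sketch of the same traversal; `scan` stands in for `state_scan_tokens`, and every name here is illustrative rather than the crate's API:

    def build_index(initial, finals, scan, eos_token_id):
        # `scan(state)` is assumed to return (token_id, end_state) pairs for
        # every vocabulary token that can be consumed from `state`.
        table, seen, frontier = {}, set(), {initial}
        while frontier:
            state = frontier.pop()
            pairs = list(scan(state))
            for token_id, end_state in pairs:
                table.setdefault(state, {})[token_id] = end_state
                if end_state not in seen:
                    frontier.add(end_state)
            # Final states that accept at least one token also accept EOS,
            # looping back onto themselves.
            if state in finals and pairs:
                table.setdefault(state, {})[eos_token_id] = state
            seen.add(state)
        # Valid only if the vocabulary can actually complete a match.
        if not any(end in finals for row in table.values() for end in row.values()):
            raise ValueError("The vocabulary does not allow a match for the regex")
        return table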