
Commit

Rust impl (#1)
* fix bug when visiting a state twice is needed

* fix bug when visiting a state twice is needed

* add CI
unaidedelf8777 authored Apr 11, 2024
1 parent 4bc8986 commit f7fc0bd
Showing 7 changed files with 170 additions and 44 deletions.
132 changes: 132 additions & 0 deletions .github/workflows/CI.yml
@@ -0,0 +1,132 @@
name: CI

on:
  push:
    branches:
      - main
      - master
    tags:
      - 'v*'
  workflow_dispatch:

permissions:
  contents: read

jobs:
  linux:
    runs-on: ${{ matrix.platform.runner }}
    strategy:
      matrix:
        platform:
          - runner: ubuntu-latest
            target: x86_64
          - runner: ubuntu-latest
            target: x86
          - runner: ubuntu-latest
            target: aarch64
          - runner: ubuntu-latest
            target: armv7
          - runner: ubuntu-latest
            target: s390x
          - runner: ubuntu-latest
            target: ppc64le
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Build wheels
        uses: PyO3/maturin-action@v1
        with:
          target: ${{ matrix.platform.target }}
          args: --release --out dist --find-interpreter
          sccache: 'true'
          manylinux: auto
      - name: Upload wheels
        uses: actions/upload-artifact@v4
        with:
          name: wheels-linux-${{ matrix.platform.target }}
          path: dist

  windows:
    runs-on: ${{ matrix.platform.runner }}
    strategy:
      matrix:
        platform:
          - runner: windows-latest
            target: x64
          - runner: windows-latest
            target: x86
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
          architecture: ${{ matrix.platform.target }}
      - name: Build wheels
        uses: PyO3/maturin-action@v1
        with:
          target: ${{ matrix.platform.target }}
          args: --release --out dist --find-interpreter
          sccache: 'true'
      - name: Upload wheels
        uses: actions/upload-artifact@v4
        with:
          name: wheels-windows-${{ matrix.platform.target }}
          path: dist

  macos:
    runs-on: ${{ matrix.platform.runner }}
    strategy:
      matrix:
        platform:
          - runner: macos-latest
            target: x86_64
          - runner: macos-14
            target: aarch64
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Build wheels
        uses: PyO3/maturin-action@v1
        with:
          target: ${{ matrix.platform.target }}
          args: --release --out dist --find-interpreter
          sccache: 'true'
      - name: Upload wheels
        uses: actions/upload-artifact@v4
        with:
          name: wheels-macos-${{ matrix.platform.target }}
          path: dist

  sdist:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Build sdist
        uses: PyO3/maturin-action@v1
        with:
          command: sdist
          args: --out dist
      - name: Upload sdist
        uses: actions/upload-artifact@v4
        with:
          name: wheels-sdist
          path: dist

  release:
    name: Release
    runs-on: ubuntu-latest
    if: startsWith(github.ref, 'refs/tags/')
    needs: [linux, windows, macos, sdist]
    steps:
      - uses: actions/download-artifact@v4
      - name: Publish to PyPI
        uses: PyO3/maturin-action@v1
        env:
          MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
        with:
          command: upload
          args: --non-interactive --skip-existing wheels-*/*
7 changes: 4 additions & 3 deletions function_sampler/fsm/fsm.py
@@ -2,8 +2,8 @@

import interegular

from ..cache import cache
from .regex import create_fsm_index_tokenizer, make_deterministic_fsm
import time

if TYPE_CHECKING:
    from .tokenizer_fsm_patch import Tokenizer
@@ -30,13 +30,13 @@ class RegexFSM(FSM):
"""FSM to generate text that is in the language of a regular expression."""

def __init__(self, regex_string: str, tokenizer):
@cache
def create_states_mapping(
regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int], ...]
) -> Tuple[dict, set]:
"""Create the variables related to the mapping between states and tokens
The parameters of the function are used for caching purpose
"""
start_time = time.perf_counter()
regex_pattern = interegular.parse_pattern(regex_string)
regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
@@ -53,7 +53,8 @@ def create_states_mapping(
                raise ValueError(
                    "The vocabulary does not allow us to build a sequence that matches the input regex"
                )

            end_time = time.perf_counter()
            print(f"Time taken for Rust: {end_time - start_time} seconds")
            return states_to_token_maps, empty_token_ids

        self.states_to_token_maps, self.empty_token_ids = create_states_mapping(
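Note on the cached helper above: the vocabulary is passed to create_states_mapping as a tuple of (token, id) pairs rather than as a dict, presumably so that every argument stays hashable and the @cache decorator can key on it. Below is a minimal, self-contained sketch of that pattern, using functools.lru_cache as a stand-in for this repo's own cache helper; names ending in _sketch are made up for illustration.

from functools import lru_cache
from typing import Tuple


@lru_cache(maxsize=None)
def create_states_mapping_sketch(
    regex_string: str, cacheable_vocabulary: Tuple[Tuple[str, int], ...]
) -> int:
    # The expensive FSM construction would happen here; the hashable
    # tuple-of-tuples vocabulary is what makes memoisation possible.
    return len(cacheable_vocabulary)


vocab = {"a": 1, "b": 2}
key = tuple(sorted(vocab.items()))        # dict -> hashable tuple of pairs
create_states_mapping_sketch("a|b", key)  # computed on the first call
create_states_mapping_sketch("a|b", key)  # served from the cache afterwards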
2 changes: 1 addition & 1 deletion function_sampler/fsm/fsm_utils/src/lib.rs
@@ -6,7 +6,7 @@ mod tokenizer_index;

#[pymodule]
#[pyo3(name="fsm_utils")]
fn fsm_utils(py: Python, m: &PyModule) -> PyResult<()> {
fn fsm_utils(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end_py, m)?)?;
    Ok(())
}
64 changes: 28 additions & 36 deletions function_sampler/fsm/fsm_utils/src/tokenizer_index.rs
@@ -1,11 +1,8 @@
use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyDict};
use rayon::prelude::*;
use std::collections::{HashMap, HashSet};

#[derive(Debug, Clone)]
#[derive(FromPyObject)]
#[derive(Debug, Clone, FromPyObject)]
pub struct FSMInfo {
    #[pyo3(item("initial"))]
    initial: i64,
@@ -23,7 +20,6 @@ pub struct FSMInfo {

pub type TokenVocabulary = HashMap<String, Vec<i64>>;


fn walk_fsm(
    fsm_info: &FSMInfo,
    input_string: &str,
@@ -36,14 +32,15 @@ fn walk_fsm(

    let mut current_pos = 0;
    let input_chars: Vec<char> = input_string.chars().collect();

    while current_pos < input_chars.len() {
        let mut found = false;

        // Attempt to match longer substrings first, ensuring multi-character sequences are prioritized
        for len in (1..=input_chars.len() - current_pos).rev() {
            let possible_match: String = input_chars[current_pos..current_pos+len].iter().collect();

            let possible_match: String =
                input_chars[current_pos..current_pos + len].iter().collect();

            if let Some(&trans_key) = fsm_info.alphabet_symbol_mapping.get(&possible_match) {
                if let Some(&new_state) = fsm_info.transitions.get(&(state, trans_key)) {
                    state = new_state;
@@ -61,7 +58,10 @@ fn walk_fsm(
        if !found {
            if !full_match && last_final_idx.is_some() {
                // Non-full match and we've previously encountered a final state
                return accepted_states.into_iter().take(last_final_idx.unwrap()).collect();
                return accepted_states
                    .into_iter()
                    .take(last_final_idx.unwrap())
                    .collect();
            } else {
                // No match found, or a full match is required
                return vec![];
@@ -82,19 +82,21 @@ fn state_scan_tokens(
    vocabulary: &TokenVocabulary,
    start_state: i64,
) -> HashSet<(i64, i64)> {
    vocabulary.par_iter()
    vocabulary
        .par_iter()
        .flat_map(|(token, token_ids)| {
            // For each token, perform the FSM walk in parallel.
            let state_seq = walk_fsm(fsm_info, token, start_state, false);

            if state_seq.len() < token.chars().count() {
                None
            } else {
                Some(token_ids.iter()
                    .map(move |&token_id| {
                        (token_id, *state_seq.last().unwrap())
                    })
                    .collect::<Vec<_>>())
                Some(
                    token_ids
                        .iter()
                        .map(move |&token_id| (token_id, *state_seq.last().unwrap()))
                        .collect::<Vec<_>>(),
                )
            }
        })
        // Flatten the nested structure into a single collection of pairs.
@@ -106,26 +108,27 @@ fn state_scan_tokens(
fn create_fsm_index_end_to_end(
    fsm_info: &FSMInfo,
    vocabulary: &TokenVocabulary,
) -> HashMap<i64, HashSet<(i64, i64)>> {
) -> HashMap<i64, HashMap<i64, i64>> {
    let mut states_to_token_subsets = HashMap::new();
    let mut seen = HashSet::new();
    let mut next_states = HashSet::new();
    next_states.insert(fsm_info.initial);

    while let Some(start_state) = next_states.iter().next().copied() {
        next_states.remove(&start_state);
        let token_ids_end_states = state_scan_tokens(fsm_info, vocabulary, start_state);

        for &(token_id, end_state) in &token_ids_end_states {
            states_to_token_subsets
                .entry(start_state)
                .or_insert_with(HashSet::new)
                .insert((token_id, end_state));
                .or_insert_with(HashMap::new)
                .insert(token_id, end_state);
            if !seen.contains(&end_state) {
                next_states.insert(end_state);
            }
        }

        next_states.remove(&start_state);

        seen.insert(start_state);
    }

@@ -144,22 +147,11 @@ fn create_fsm_index_end_to_end(
#[pyo3(text_signature = "(fsm_info, vocabulary, /)")]
pub fn create_fsm_index_end_to_end_py(
    py: Python<'_>,
    fsm_info_py: FSMInfo,
    vocabulary_py: TokenVocabulary,
) -> PyResult<PyObject> {

    let fsm_info: FSMInfo = fsm_info_py;
    let vocabulary: TokenVocabulary = vocabulary_py;
    let states_to_token_subsets = create_fsm_index_end_to_end(&fsm_info, &vocabulary);

    let states_to_token_subsets_py = PyDict::new_bound(py); // Assuming new_bound exists and is the correct replacement
    for (k, v) in states_to_token_subsets.iter() {
        let subset_py = PyDict::new_bound(py); // Adjusted per deprecation notice
        for (inner_k, inner_v) in v.iter() {
            subset_py.set_item(inner_k, inner_v)?;
        }
        states_to_token_subsets_py.set_item(k, subset_py)?;
    }
    fsm_info: FSMInfo,
    vocabulary: TokenVocabulary,
) -> HashMap<i64, HashMap<i64, i64>> {
    let states_to_token_subsets =
        py.allow_threads(move || create_fsm_index_end_to_end(&fsm_info, &vocabulary));

    Ok(states_to_token_subsets_py.into())
    states_to_token_subsets
}
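For readers who don't read Rust, here is a rough pure-Python sketch of what create_fsm_index_end_to_end computes: starting from the FSM's initial state, every vocabulary token is walked from each reachable state, and the result is a map from state to {token_id: end_state}. The walk_fsm_sketch below is a deliberately simplified single-character walker (the real Rust code handles multi-character alphabet symbols, final-state bookkeeping, and runs the scan in parallel with rayon), so treat it as an illustration of the worklist loop, not the library's algorithm verbatim.

from typing import Dict, List, Optional, Set, Tuple

# Hypothetical, simplified FSM description: transitions[(state, symbol)] -> next state.
Transitions = Dict[Tuple[int, str], int]


def walk_fsm_sketch(transitions: Transitions, token: str, start: int) -> Optional[int]:
    # Simplified walker: one character per step; returns the end state,
    # or None if the token cannot be consumed in full from `start`.
    state = start
    for ch in token:
        if (state, ch) not in transitions:
            return None
        state = transitions[(state, ch)]
    return state


def build_index_sketch(
    transitions: Transitions,
    initial: int,
    vocabulary: Dict[str, List[int]],
) -> Dict[int, Dict[int, int]]:
    index: Dict[int, Dict[int, int]] = {}
    seen: Set[int] = set()
    next_states: Set[int] = {initial}

    while next_states:
        # Pop the state before scanning it, mirroring the reordered removal in the Rust loop.
        start_state = next_states.pop()
        for token, token_ids in vocabulary.items():
            end_state = walk_fsm_sketch(transitions, token, start_state)
            if end_state is None:
                continue
            for token_id in token_ids:
                index.setdefault(start_state, {})[token_id] = end_state
            if end_state not in seen:
                next_states.add(end_state)
        seen.add(start_state)

    return index


# Tiny toy FSM accepting "ab": 0 --a--> 1 --b--> 2
toy_transitions = {(0, "a"): 1, (1, "b"): 2}
toy_vocab = {"a": [10], "b": [11], "ab": [12]}
print(build_index_sketch(toy_transitions, 0, toy_vocab))
# {0: {10: 1, 12: 2}, 1: {11: 2}} -- state 2 admits no tokens, so it gets no entry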
2 changes: 1 addition & 1 deletion function_sampler/fsm/regex.py
@@ -2,7 +2,6 @@

import numpy as np
from interegular.fsm import FSM, Alphabet, anything_else

from .fsm_utils import create_fsm_index_end_to_end
from .utils import reduced_vocabulary

@@ -213,6 +212,7 @@ def create_fsm_index_tokenizer(
"""
vocabulary, empty_token_ids = reduced_vocabulary(tokenizer)

fsm_info = fsm.fsm_info
# rust impl expects generic types, so just cast them.
states_to_token_subsets = create_fsm_index_end_to_end(fsm_info, dict(vocabulary)) # type: ignore
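Judging by the HashMap<i64, HashMap<i64, i64>> return type in the Rust source above, the value that comes back here is a plain nested dict mapping each FSM state to {token_id: next_state}. A toy, hand-written stand-in just to show the shape being consumed (the real map comes from the extension, not from literals like these):

# Made-up values purely to illustrate the shape of the returned index.
states_to_token_subsets = {0: {17: 1, 42: 3}, 1: {99: 3}}

state = 0
allowed_token_ids = set(states_to_token_subsets[state])  # token ids legal at this state
next_state = states_to_token_subsets[state][17]           # state reached after emitting token 17
print(allowed_token_ids, next_state)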
5 changes: 3 additions & 2 deletions function_sampler/fsm/tokenizer_fsm_patch.py
@@ -41,6 +41,9 @@ def convert_token_to_string(self, token: str) -> str:
        ...


SPIECE_UNDERLINE = "\u2581"


class TransformerTokenizer(Tokenizer):
"""Represents a tokenizer for models in the `transformers` library."""

@@ -73,8 +76,6 @@ def decode(self, token_ids: torch.LongTensor) -> List[str]:
        return text

    def convert_token_to_string(self, token: str) -> str:
        from transformers.file_utils import SPIECE_UNDERLINE

        string = self.tokenizer.convert_tokens_to_string([token])

        # A hack to handle missing spaces to HF's Llama tokenizers
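The hunk above hoists SPIECE_UNDERLINE to a module-level constant instead of importing it from transformers.file_utils inside the method, but the body of the space-fixing hack itself is collapsed in this view. For context, the usual shape of that workaround looks roughly like the sketch below; the condition used in this file may differ, so read it as a hypothetical reconstruction, not the actual method body.

SPIECE_UNDERLINE = "\u2581"  # marker SentencePiece uses for a leading space


def convert_token_to_string_sketch(tokenizer, token: str) -> str:
    # Hypothetical sketch: decoding a single SentencePiece-style token drops
    # the leading space its underline marker encodes, so re-attach it by hand.
    string = tokenizer.convert_tokens_to_string([token])
    if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
        return " " + string
    return string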
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,7 @@ classifiers = [
]

[tool.poetry.dependencies]
python = "^3.10"
python = "^3.8"
transformers = "^4.38.2"
pydantic = "^2.6.3"
diskcache = "^5.6.3"
