Skip to content

Commit

Permalink
Replace basic Numba FSM functions with Rust implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
kc611 authored and brandonwillard committed Aug 20, 2024
1 parent 002b771 commit de00a4d
Show file tree
Hide file tree
Showing 34 changed files with 680 additions and 494 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.10"]
python-version: ["3.10"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -35,7 +35,7 @@ jobs:
pip install .[test]
- name: Run tests
run: |
pytest --cov=src/outlines_core
pytest --cov=outlines_core
- name: Upload coverage data
uses: actions/upload-artifact@v3
with:
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ docs/build
.idea/
*.gguf
.venv
build/
benchmarks/results
target/
*.so
24 changes: 24 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,27 @@ repos:
args: [--allow-redefinition]
exclude: ^examples/
additional_dependencies: [types-tqdm, types-Pillow]
- repo: local
hooks:
- id: cargo-fmt
name: cargo-fmt
description: Format files with cargo fmt.
entry: cargo fmt
language: system
types: [rust]
args: ["--"]
- id: cargo-check
name: cargo-check
description: Check files with cargo check.
entry: cargo check
language: system
types: [rust]
pass_filenames: false
- id: cargo-clippy
name: cargo-clippy
description: Check files with cargo clippy
entry: cargo clippy
language: system
args: ["--", "-D", "warnings"]
types: [rust]
pass_filenames: false
171 changes: 171 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
name = "rust_fsm"
version = "0.1.0"
edition = "2021"

[lib]
name = "rust_fsm"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.22.0", features = ["extension-module"] }

[profile.release]
opt-level = 3
lto = true
codegen-units = 1
strip = true
panic = 'abort'

[features]
default = []
e2e_experimental = []
1 change: 1 addition & 0 deletions benchmarks/asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"HEAD"
],
"build_command": [
"pip install setuptools_rust",
"python -mpip install .[test]",
"PIP_NO_BUILD_ISOLATION=false python -mpip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}",
],
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/bench_json_schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from outlines_core.fsm.guide import RegexGuide
from outlines_core.fsm.json_schema import build_regex_from_schema

from .common import ensure_numba_compiled, setup_tokenizer # noqa: E402
from .common import setup_tokenizer # noqa: E402

simple_schema = """{
"$defs": {
Expand Down Expand Up @@ -68,7 +68,6 @@ class JsonSchemaBenchmark:
def setup(self, schema_name):
self.tokenizer = setup_tokenizer()
self.schema = schemas[schema_name]
ensure_numba_compiled(self.tokenizer)

def time_json_schema_to_regex(self, schema_name):
build_regex_from_schema(self.schema)
Expand Down
32 changes: 0 additions & 32 deletions benchmarks/bench_numba_compile.py

This file was deleted.

4 changes: 1 addition & 3 deletions benchmarks/bench_regex_guide.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from outlines_core.fsm.guide import RegexGuide

from .common import ensure_numba_compiled, setup_tokenizer
from .common import setup_tokenizer

regex_samples = {
"email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?",
Expand All @@ -20,7 +20,6 @@ class RegexGuideBenchmark:

def setup(self, pattern_name):
self.tokenizer = setup_tokenizer()
ensure_numba_compiled(self.tokenizer)
self.pattern = regex_samples[pattern_name]

def time_regex_to_guide(self, pattern_name):
Expand All @@ -32,7 +31,6 @@ class MemoryRegexGuideBenchmark:

def setup(self, pattern_name):
self.tokenizer = setup_tokenizer()
ensure_numba_compiled(self.tokenizer)
self.pattern = regex_samples[pattern_name]

def peakmem_regex_to_guide(self, pattern_name):
Expand Down
9 changes: 1 addition & 8 deletions benchmarks/common.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
from transformers import AutoTokenizer

from outlines_core.fsm.guide import RegexGuide
from outlines_core.models.transformers import TransformerTokenizer
from transformers import AutoTokenizer


def setup_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("gpt2")
return TransformerTokenizer(tokenizer)


def ensure_numba_compiled(tokenizer):
RegexGuide("a", tokenizer)
return True
Loading

0 comments on commit de00a4d

Please sign in to comment.