From de00a4d17a17f1b92e0f112cbfacd21fe9fe1a27 Mon Sep 17 00:00:00 2001 From: kc611 Date: Sun, 21 Jul 2024 15:27:22 +0530 Subject: [PATCH] Replace basic Numba FSM functions with Rust implementations --- .github/workflows/tests.yml | 4 +- .gitignore | 3 + .pre-commit-config.yaml | 24 ++ Cargo.lock | 171 ++++++++ Cargo.toml | 22 + benchmarks/asv.conf.json | 1 + benchmarks/bench_json_schema.py | 3 +- benchmarks/bench_numba_compile.py | 32 -- benchmarks/bench_regex_guide.py | 4 +- benchmarks/common.py | 9 +- pyproject.toml | 26 +- {src => python}/outlines_core/__init__.py | 0 {src => python}/outlines_core/fsm/__init__.py | 0 {src => python}/outlines_core/fsm/fsm.py | 0 {src => python}/outlines_core/fsm/guide.py | 1 - .../outlines_core/fsm/json_schema.py | 0 {src => python}/outlines_core/fsm/regex.py | 380 ++---------------- {src => python}/outlines_core/fsm/types.py | 0 .../outlines_core/integrations/utils.py | 0 .../outlines_core/models/__init__.py | 0 .../outlines_core/models/tokenizer.py | 0 .../outlines_core/models/transformers.py | 1 - {src => python}/outlines_core/py.typed | 0 setup.cfg | 2 +- setup.py | 20 + src/lib.rs | 23 ++ src/regex.rs | 311 ++++++++++++++ tests/fsm/test_fsm.py | 1 - tests/fsm/test_guide.py | 1 - tests/fsm/test_json_schema.py | 3 +- tests/fsm/test_regex.py | 127 ++---- tests/fsm/test_types.py | 1 - tests/models/test_tokenizer.py | 1 - tests/models/test_transformers.py | 3 +- 34 files changed, 680 insertions(+), 494 deletions(-) create mode 100644 Cargo.lock create mode 100644 Cargo.toml delete mode 100644 benchmarks/bench_numba_compile.py rename {src => python}/outlines_core/__init__.py (100%) rename {src => python}/outlines_core/fsm/__init__.py (100%) rename {src => python}/outlines_core/fsm/fsm.py (100%) rename {src => python}/outlines_core/fsm/guide.py (99%) rename {src => python}/outlines_core/fsm/json_schema.py (100%) rename {src => python}/outlines_core/fsm/regex.py (66%) rename {src => python}/outlines_core/fsm/types.py (100%) rename {src => python}/outlines_core/integrations/utils.py (100%) rename {src => python}/outlines_core/models/__init__.py (100%) rename {src => python}/outlines_core/models/tokenizer.py (100%) rename {src => python}/outlines_core/models/transformers.py (99%) rename {src => python}/outlines_core/py.typed (100%) create mode 100644 setup.py create mode 100644 src/lib.rs create mode 100644 src/regex.rs diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9d24534c..390db4fc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.10"] + python-version: ["3.10"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -35,7 +35,7 @@ jobs: pip install .[test] - name: Run tests run: | - pytest --cov=src/outlines_core + pytest --cov=outlines_core - name: Upload coverage data uses: actions/upload-artifact@v3 with: diff --git a/.gitignore b/.gitignore index 9add6d8c..1f7f423e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,7 @@ docs/build .idea/ *.gguf .venv +build/ benchmarks/results +target/ +*.so diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a9039b60..98ef1b34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,3 +31,27 @@ repos: args: [--allow-redefinition] exclude: ^examples/ additional_dependencies: [types-tqdm, types-Pillow] +- repo: local + hooks: + - id: cargo-fmt + name: cargo-fmt + description: Format files with cargo fmt. 
+ entry: cargo fmt + language: system + types: [rust] + args: ["--"] + - id: cargo-check + name: cargo-check + description: Check files with cargo check. + entry: cargo check + language: system + types: [rust] + pass_filenames: false + - id: cargo-clippy + name: cargo-clippy + description: Check files with cargo clippy + entry: cargo clippy + language: system + args: ["--", "-D", "warnings"] + types: [rust] + pass_filenames: false diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 00000000..b4d6172f --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,171 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + +[[package]] +name = "libc" +version = "0.2.158" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "portable-atomic" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" + +[[package]] +name = "proc-macro2" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rust_fsm" +version = "0.1.0" +dependencies = [ + "pyo3", +] + +[[package]] +name = "syn" +version = "2.0.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6af063034fc1935ede7be0122941bafa9bacb949334d090b77ca98b5817c7d9" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..48c6d076 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "rust_fsm" +version = "0.1.0" +edition = "2021" + +[lib] +name = "rust_fsm" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.22.0", features = ["extension-module"] } + +[profile.release] +opt-level = 3 +lto = true +codegen-units = 1 +strip = true +panic = 'abort' + +[features] +default = [] +e2e_experimental = [] diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index 3dc3f620..3959e2f0 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -7,6 +7,7 @@ "HEAD" ], "build_command": [ + "pip install setuptools_rust", "python -mpip install .[test]", "PIP_NO_BUILD_ISOLATION=false python -mpip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}", ], diff --git a/benchmarks/bench_json_schema.py b/benchmarks/bench_json_schema.py index 6e5b9aa8..47578cd3 100644 --- a/benchmarks/bench_json_schema.py +++ b/benchmarks/bench_json_schema.py @@ -1,7 +1,7 @@ from outlines_core.fsm.guide import RegexGuide from outlines_core.fsm.json_schema import build_regex_from_schema -from .common import ensure_numba_compiled, setup_tokenizer # noqa: E402 +from .common import setup_tokenizer # noqa: E402 simple_schema = """{ "$defs": { @@ -68,7 +68,6 @@ class JsonSchemaBenchmark: def setup(self, schema_name): self.tokenizer = setup_tokenizer() self.schema = schemas[schema_name] - ensure_numba_compiled(self.tokenizer) def time_json_schema_to_regex(self, schema_name): build_regex_from_schema(self.schema) diff --git a/benchmarks/bench_numba_compile.py b/benchmarks/bench_numba_compile.py deleted file mode 100644 index ae62702a..00000000 --- a/benchmarks/bench_numba_compile.py +++ /dev/null @@ -1,32 +0,0 @@ -import importlib - -import interegular -import numba - -from outlines_core.fsm import regex - -from .common 
import setup_tokenizer - - -class NumbaCompileBenchmark: - def setup(self): - self.tokenizer = setup_tokenizer() - self.regex = regex - original_njit = numba.njit - - def mock_njit(*args, **kwargs): - kwargs["cache"] = False - return original_njit(*args, **kwargs) - - self.original_njit = original_njit - numba.njit = mock_njit - importlib.reload(self.regex) - self.regex_pattern, _ = self.regex.make_deterministic_fsm( - interegular.parse_pattern("a").to_fsm().reduce() - ) - - def teardown(self): - numba.njit = self.original_njit - - def time_compile_numba(self): - self.regex.create_fsm_index_tokenizer(self.regex_pattern, self.tokenizer) diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py index 5d505a48..287d5f51 100644 --- a/benchmarks/bench_regex_guide.py +++ b/benchmarks/bench_regex_guide.py @@ -1,6 +1,6 @@ from outlines_core.fsm.guide import RegexGuide -from .common import ensure_numba_compiled, setup_tokenizer +from .common import setup_tokenizer regex_samples = { "email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?", @@ -20,7 +20,6 @@ class RegexGuideBenchmark: def setup(self, pattern_name): self.tokenizer = setup_tokenizer() - ensure_numba_compiled(self.tokenizer) self.pattern = regex_samples[pattern_name] def time_regex_to_guide(self, pattern_name): @@ -32,7 +31,6 @@ class MemoryRegexGuideBenchmark: def setup(self, pattern_name): self.tokenizer = setup_tokenizer() - ensure_numba_compiled(self.tokenizer) self.pattern = regex_samples[pattern_name] def peakmem_regex_to_guide(self, pattern_name): diff --git a/benchmarks/common.py b/benchmarks/common.py index db25593d..aee9cfa5 100644 --- a/benchmarks/common.py +++ b/benchmarks/common.py @@ -1,14 +1,7 @@ -from transformers import AutoTokenizer - -from outlines_core.fsm.guide import RegexGuide from outlines_core.models.transformers import TransformerTokenizer +from transformers import AutoTokenizer def setup_tokenizer(): tokenizer = AutoTokenizer.from_pretrained("gpt2") return TransformerTokenizer(tokenizer) - - -def ensure_numba_compiled(tokenizer): - RegexGuide("a", tokenizer) - return True diff --git a/pyproject.toml b/pyproject.toml index fdaa0500..c23125e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] +requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2", "setuptools-rust"] build-backend = "setuptools.build_meta" [project] @@ -29,7 +29,6 @@ dependencies = [ "cloudpickle", "diskcache", "pydantic>=2.0", - "numba", "referencing", "jsonschema", "tqdm", @@ -66,24 +65,25 @@ content-type = "text/markdown" [tool.setuptools] packages = ["outlines_core"] -package-dir = {"" = "src"} +package-dir = {"" = "python"} [tool.setuptools.package-data] "outlines" = ["py.typed"] [tool.setuptools_scm] -write_to = "src/outlines_core/_version.py" +write_to = "python/outlines_core/_version.py" [tool.pytest.ini_options] testpaths = ["tests"] filterwarnings = [ "error", - "ignore::numba.core.errors.NumbaPendingDeprecationWarning", "ignore::pydantic.warnings.PydanticDeprecatedSince20", "ignore::FutureWarning:transformers.*", "ignore::FutureWarning:huggingface_hub.*", "ignore::UserWarning", - "ignore::DeprecationWarning:pyairports.*", +] +addopts = [ + "--import-mode=importlib" ] [tool.mypy] @@ -104,13 +104,16 @@ module = [ "huggingface_hub", "interegular.*", "datasets.*", - "numba.*", + "setuptools.*", + "setuptools_rust.*", + # TODO: Add type info 
for the Rust extension + "outlines_core.fsm.rust_fsm.*", ] ignore_missing_imports = true [tool.coverage.run] omit = [ - "src/outlines_core/_version.py", + "python/outlines_core/_version.py", "tests/*", ] branch = true @@ -126,6 +129,13 @@ exclude_lines = [ ] show_missing = true +[tool.coverage.paths] +source = [ + "outlines_core", + "**/site-packages/outlines_core", +] + + [tool.diff_cover] compare_branch = "origin/main" diff_range_notation = ".." diff --git a/src/outlines_core/__init__.py b/python/outlines_core/__init__.py similarity index 100% rename from src/outlines_core/__init__.py rename to python/outlines_core/__init__.py diff --git a/src/outlines_core/fsm/__init__.py b/python/outlines_core/fsm/__init__.py similarity index 100% rename from src/outlines_core/fsm/__init__.py rename to python/outlines_core/fsm/__init__.py diff --git a/src/outlines_core/fsm/fsm.py b/python/outlines_core/fsm/fsm.py similarity index 100% rename from src/outlines_core/fsm/fsm.py rename to python/outlines_core/fsm/fsm.py diff --git a/src/outlines_core/fsm/guide.py b/python/outlines_core/fsm/guide.py similarity index 99% rename from src/outlines_core/fsm/guide.py rename to python/outlines_core/fsm/guide.py index 5bfdf81b..8f7250ef 100644 --- a/src/outlines_core/fsm/guide.py +++ b/python/outlines_core/fsm/guide.py @@ -13,7 +13,6 @@ import interegular import torch - from outlines_core.fsm.regex import ( create_fsm_index_tokenizer, make_byte_level_fsm, diff --git a/src/outlines_core/fsm/json_schema.py b/python/outlines_core/fsm/json_schema.py similarity index 100% rename from src/outlines_core/fsm/json_schema.py rename to python/outlines_core/fsm/json_schema.py diff --git a/src/outlines_core/fsm/regex.py b/python/outlines_core/fsm/regex.py similarity index 66% rename from src/outlines_core/fsm/regex.py rename to python/outlines_core/fsm/regex.py index 3c06790a..939f3d38 100644 --- a/src/outlines_core/fsm/regex.py +++ b/python/outlines_core/fsm/regex.py @@ -1,12 +1,13 @@ import re -from collections import namedtuple from functools import lru_cache from typing import ( TYPE_CHECKING, Dict, FrozenSet, Generator, + Iterable, List, + Optional, Sequence, Set, Tuple, @@ -14,8 +15,6 @@ cast, ) -import numba -import numpy as np from interegular.fsm import ( FSM, Alphabet, @@ -25,8 +24,15 @@ _AnythingElseCls, anything_else, ) -from numba.typed.typedobjectutils import _nonoptional -from tqdm import tqdm + +from .rust_fsm import ( # noqa: F401 + FSMInfo, + _walk_fsm, + create_fsm_index_end_to_end, + get_token_transition_keys, + get_vocabulary_transition_keys, + state_scan_tokens, +) if TYPE_CHECKING: from outlines_core.models.tokenizer import Tokenizer @@ -47,7 +53,6 @@ def copy(self): class BetterFSM(FSM): flat_transition_map: Dict[Tuple[int, int], int] - trans_key_to_states: Dict[int, List[int]] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -56,13 +61,10 @@ def __init__(self, *args, **kwargs): self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping) flat_transition_map = {} - trans_key_to_states = {} for from_state, trans_map in self.map.items(): for trans_key, to_state in trans_map.items(): flat_transition_map[(from_state, trans_key)] = to_state - trans_key_to_states.setdefault(trans_key, set()).add(from_state) - self.__dict__["trans_key_to_states"] = trans_key_to_states self.__dict__["flat_transition_map"] = flat_transition_map self.__dict__["_fsm_info"] = None @@ -79,95 +81,23 @@ def copy(self): @property def fsm_info(self): if self._fsm_info is None: - 
flat_transition_map_items = np.fromiter( - ((a[0], a[1], b) for a, b in self.flat_transition_map.items()), - dtype=np.dtype("int64, int64, int64"), - ) - trans_key_to_states_items = np.fromiter( - ((k, z) for k, v in self.trans_key_to_states.items() for z in v), - dtype=np.dtype("int64, int64"), - ) - alphabet_symbol_mapping_items = [ - (k, v) - for k, v in self.alphabet._symbol_mapping.items() - if k != anything_else - ] - nb_finals = np.fromiter(self.finals, dtype=np.dtype("int64")) - self.__dict__["_fsm_info"] = create_fsm_info( + anything_value = self.alphabet.anything_value + self.__dict__["_fsm_info"] = FSMInfo( self.initial, - nb_finals, - flat_transition_map_items, - trans_key_to_states_items, - self.alphabet.anything_value, - alphabet_symbol_mapping_items, + self.finals, + self.flat_transition_map, + anything_value, + # TODO FIXME: Perform this conversion in Rust? + { + k: v + for k, v in self.alphabet._symbol_mapping.items() + if k != anything_else + }, ) return self._fsm_info -nb_int_list_type = numba.types.ListType(numba.int64) -nb_int_pair_type = numba.types.UniTuple(numba.int64, 2) -nb_unicode_type = numba.types.unicode_type - - -@numba.njit(cache=True) -def create_fsm_info( - py_initial, - py_finals, - flat_transition_map_items, - trans_key_to_states_items, - py_anything_value, - alphabet_symbol_mapping_items, -): - trans_key_to_states = numba.typed.Dict.empty(numba.int64, nb_int_list_type) - for trans_key_and_state in trans_key_to_states_items: - trans_key_to_states.setdefault( - trans_key_and_state[0], numba.typed.List.empty_list(numba.int64) - ).append(trans_key_and_state[1]) - - flat_transition_map = numba.typed.Dict.empty(nb_int_pair_type, numba.int64) - for trans_key_and_state in flat_transition_map_items: - flat_transition_map[ - (trans_key_and_state[0], trans_key_and_state[1]) - ] = trans_key_and_state[2] - - # use 2-char strings so that we can represent incomplete utf-8 sequences - # as 2-hex-digit pairs - alphabet_symbol_map = numba.typed.Dict.empty(nb_unicode_type, numba.int64) - for symbol_and_trans_key in alphabet_symbol_mapping_items: - alphabet_symbol_map[symbol_and_trans_key[0]] = symbol_and_trans_key[1] - - initial = numba.int64(py_initial) - - finals = set() - for final in py_finals: - finals.add(final) - - anything_value = numba.int64(py_anything_value) - - return FSMInfo( - initial, - finals, - flat_transition_map, - trans_key_to_states, - anything_value, - alphabet_symbol_map, - ) - - -FSMInfo = namedtuple( - "FSMInfo", - [ - "initial", - "finals", - "transitions", - "trans_key_to_states", - "alphabet_anything_value", - "alphabet_symbol_mapping", - ], -) - - TransitionTrie = Dict[TransitionKey, "Union[TransitionTrie, State, None]"] @@ -425,43 +355,6 @@ def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: return new_fsm, old_to_new_states -@numba.njit(nogil=True, cache=True) -def _walk_fsm( - fsm_transitions: Dict[Tuple[int, int], int], - fsm_initial: int, - fsm_finals: Set[int], - token_transition_keys: Sequence[int], - start_state: int, - full_match: bool = True, -) -> List[int]: - state = start_state - accepted_states: List[int] = numba.typed.List.empty_list(numba.int64) - last_final_idx: int = numba.uint64(0) - - # Iterate over token transition key sequence. The transition key - # sequence represents the FSM traversal rules of the tokens symbols. 
- for i, trans_key in enumerate(token_transition_keys): - new_state = fsm_transitions.get((state, trans_key)) - - if new_state is None: - if not full_match and last_final_idx > 0: - return accepted_states[:last_final_idx] - - return numba.typed.List.empty_list(numba.int64) - - state = new_state - - if state in fsm_finals: - last_final_idx = numba.uint64(i + 1) - - accepted_states.append(_nonoptional(state)) - - if full_match and last_final_idx - 1 != i: - return numba.typed.List.empty_list(numba.int64) - - return accepted_states - - def walk_fsm( fsm: BetterFSM, token_transition_keys: Sequence[int], @@ -657,196 +550,6 @@ def get_sub_fsms_from_seq( ) -@numba.njit(cache=True, nogil=True) -def state_scan_tokens( - fsm_transitions: Dict[Tuple[int, int], int], - alphabet_symbol_mapping: Dict[str, int], - alphabet_anything_value: int, - fsm_initial: int, - fsm_finals: Set[int], - vocabulary: List[Tuple[str, Sequence[int]]], - vocabulary_transition_keys: List[Sequence[int]], - start_state: int, -) -> Set[Tuple[int, int]]: - res = set() - - for (token, token_ids), token_transition_keys in zip( - vocabulary, vocabulary_transition_keys - ): - state_seq = _walk_fsm( - fsm_transitions, - fsm_initial, - fsm_finals, - token_transition_keys, - start_state, - False, - ) - - if state_seq is not None and len(state_seq) < len(token_transition_keys): - continue - - for token_id in token_ids: - res.add((token_id, state_seq[-1])) - - return res - - -@numba.njit(cache=True, nogil=True) -def get_token_transition_keys( - alphabet_symbol_mapping: Dict[str, int], - alphabet_anything_value: int, - token_str: str, -) -> Sequence[int]: - """ - Get the sequence of transition keys for an individual string - with respect to an FSMs alphabet symbol mapping - - This requires parsing the null-byte prefix rules of a byte-fsm: - - If two characters are prefixed by \x00, they are the grouped as a hex-byte - - Otherwise they are a standalone utf-8 character - """ - token_transition_keys = [] - i = 0 - while i < len(token_str): - if token_str[i] == "\x00" and i != len(token_str) - 1: - symbol = token_str[i : i + 3] - i += 3 - else: - symbol = token_str[i] - i += 1 - - token_transition_keys.append( - alphabet_symbol_mapping.get(symbol, alphabet_anything_value) - ) - - token_transition_keys_array = np.empty(len(token_transition_keys), dtype=np.int64) - for j in range(len(token_transition_keys)): - token_transition_keys_array[j] = token_transition_keys[j] - return token_transition_keys_array - - -@numba.njit(cache=True, nogil=True) -def get_vocabulary_transition_keys( - alphabet_symbol_mapping: Dict[str, int], - alphabet_anything_value: int, - vocabulary: List[Tuple[str, Sequence[int]]], - frozen_tokens: List[str] = numba.typed.List.empty_list(numba.types.unicode_type), -) -> List[Sequence[int]]: - """ - Calculate the sequence transition keys for each token str within a vocabulary - - Parameters - ---------- - alphabet_symbol_mapping: (`Dict[str, int]`): - A mapping from an alphabet symbol in a FSM to its corresponding transition key. - alphabet_anything_value: (`int`): - The transition key for the anything_else symbol in the FSM. - vocabulary: (`List[Tuple[str, Sequence[int]]]`): - A list of tuples, each containing a token and a list of equivalent token ids. - frozen_tokens: (`List[str]`, *optional*): - A list of tokens that are kept as-is when transforming the FSM. - Defaults to an empty list. - - Returns - ------- - `List[Sequence[int]]`: - A list of token transition keys for each token in the vocabulary. 
- """ - vocab_transition_keys = numba.typed.List.empty_list(numba.int64[:]) - for token_str, _ in vocabulary: - # Since these tokens are not expanded into byte-level transitions, we can - # simply get their transition keys directly. - if token_str in frozen_tokens: - token_transition_keys = np.array( - [alphabet_symbol_mapping[token_str]], dtype=np.int64 - ) - else: - token_transition_keys = get_token_transition_keys( - alphabet_symbol_mapping, alphabet_anything_value, token_str - ) - vocab_transition_keys.append(token_transition_keys) - - return vocab_transition_keys - - -def create_fsm_index_end_to_end( - fsm_info: FSMInfo, - vocabulary: List[Tuple[str, Sequence[int]]], - frozen_tokens: List[str] = [], -) -> Dict[int, Set[Tuple[int, int]]]: - """Create an FSM state-to-vocabulary map/index through end-to-end token parsing. - - Parameters - ---------- - fsm_info: (`interegular.FSMInfo`): - The FSM information object containing the FSM's alphabet, transitions, initial - and final states, and other relevant information. - vocabulary: (`List[Tuple[str, Sequence[int]]]`): - A list of tuples, each containing a token and a list of equivalent token ids. - frozen_tokens: (`List[str]`, *optional*): - A list of tokens that are kept as-is when transforming the FSM. - - Returns - ------- - `Dict[int, Set[Tuple[int, int]]]`: - A mapping from FSM states to sets of tuples containing token ids and the end - states of the FSM after parsing the token. - """ - - # TODO: Consider using a `List` of `Set`s instead; that way we can JIT this - # code, too. - states_to_token_subsets: Dict[int, Set[Tuple[int, int]]] = {} - seen: Set[int] = set() - next_states = {fsm_info.initial} - - pbar = tqdm( - total=len(set(fsm_info.transitions.values())) - + 1, # all transitions plus initial - desc="Compiling FSM index for all state transitions", - ) - - vocabulary_transition_keys = get_vocabulary_transition_keys( - fsm_info.alphabet_symbol_mapping, - fsm_info.alphabet_anything_value, - vocabulary, - frozen_tokens=( - numba.typed.List(frozen_tokens) - if len(frozen_tokens) > 0 - else numba.typed.List.empty_list(numba.types.unicode_type) - ), - ) - - while next_states: - start_state = next_states.pop() - - token_ids_end_states = state_scan_tokens( - fsm_info.transitions, - fsm_info.alphabet_symbol_mapping, - fsm_info.alphabet_anything_value, - fsm_info.initial, - fsm_info.finals, - vocabulary, - vocabulary_transition_keys, - start_state, - ) - - for token_id_and_end_state in token_ids_end_states: - states_to_token_subsets.setdefault(start_state, set()).add( - token_id_and_end_state - ) - end_state = token_id_and_end_state[1] - if end_state not in seen: - next_states.add(end_state) - - if start_state not in seen: - pbar.update(1) - seen.add(start_state) - - pbar.close() - - return states_to_token_subsets - - re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") # The "▁*" prefix is required to handle Gemma and GPT-SW3 tokenizers, and the "\.*" @@ -887,15 +590,15 @@ def gpt2_unicode_to_bytes(): return {v: k for k, v in gpt2_bytes_to_unicode().items()} -# TODO: Cannot cache typed collections to disk, yet. 
See -# https://github.com/numba/numba/issues/4698 @lru_cache def reduced_vocabulary( tokenizer: "Tokenizer", -) -> Tuple[List[Tuple[str, Sequence[int]]], Set[int]]: +) -> Tuple[Dict[str, List[int]], Set[int]]: """Create a map from decoded vocabulary tokens to lists of equivalent token ids.""" + # TODO FIXME: See if we can get the underlying Rust tokenizers from HF and + # do all this in Rust empty_token_ids = set() - vocabulary: Dict[Union[str, Tuple[str, ...]], List[int]] = {} + vocabulary: Dict[str, List[int]] = {} for token, token_idx in tokenizer.vocabulary.items(): if token in tokenizer.special_tokens: continue @@ -927,27 +630,19 @@ def reduced_vocabulary( ) token_str = "".join(byte_symbol(b) for b in token_bytes) + assert isinstance(token_str, str) + vocabulary.setdefault(token_str, []).append(token_idx) else: - empty_token_ids.add(numba.int64(token_idx)) + empty_token_ids.add(token_idx) - vocabulary_nb = numba.typed.List.empty_list( - numba.types.Tuple( - ( - nb_unicode_type, - numba.int64[:], - ) - ) - ) - for token_str, token_ids in vocabulary.items(): - token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) - vocabulary_nb.append((token_str, token_ids_np)) - - return vocabulary_nb, empty_token_ids + return vocabulary, empty_token_ids def create_fsm_index_tokenizer( - fsm: BetterFSM, tokenizer: "Tokenizer", frozen_tokens: List[str] = [] + fsm: BetterFSM, + tokenizer: "Tokenizer", + frozen_tokens: Optional[Iterable[str]] = None, ) -> Tuple[Dict[int, Dict[int, int]], Set[int]]: """Construct an FMS index from a tokenizer. @@ -980,7 +675,9 @@ def create_fsm_index_tokenizer( vocabulary, empty_token_ids = reduced_vocabulary(tokenizer) states_to_token_subsets = create_fsm_index_end_to_end( - fsm.fsm_info, vocabulary, frozen_tokens + fsm.fsm_info, + list(vocabulary.items()), + frozenset(frozen_tokens) if frozen_tokens is not None else frozenset(), ) # Allow transitions to EOS from all terminals FSM states that are @@ -989,9 +686,6 @@ def create_fsm_index_tokenizer( for state in fsm.fsm_info.finals: subset = states_to_token_subsets.get(state) if subset is not None: - subset.add((tokenizer.eos_token_id, state)) - - # Convert to token-to-end-state maps - states_to_token_subsets = {k: dict(v) for k, v in states_to_token_subsets.items()} + subset[tokenizer.eos_token_id] = state return states_to_token_subsets, empty_token_ids diff --git a/src/outlines_core/fsm/types.py b/python/outlines_core/fsm/types.py similarity index 100% rename from src/outlines_core/fsm/types.py rename to python/outlines_core/fsm/types.py diff --git a/src/outlines_core/integrations/utils.py b/python/outlines_core/integrations/utils.py similarity index 100% rename from src/outlines_core/integrations/utils.py rename to python/outlines_core/integrations/utils.py diff --git a/src/outlines_core/models/__init__.py b/python/outlines_core/models/__init__.py similarity index 100% rename from src/outlines_core/models/__init__.py rename to python/outlines_core/models/__init__.py diff --git a/src/outlines_core/models/tokenizer.py b/python/outlines_core/models/tokenizer.py similarity index 100% rename from src/outlines_core/models/tokenizer.py rename to python/outlines_core/models/tokenizer.py diff --git a/src/outlines_core/models/transformers.py b/python/outlines_core/models/transformers.py similarity index 99% rename from src/outlines_core/models/transformers.py rename to python/outlines_core/models/transformers.py index e219d8a4..bc5ba7b6 100644 --- a/src/outlines_core/models/transformers.py +++ 
b/python/outlines_core/models/transformers.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union from datasets.fingerprint import Hasher - from outlines_core.models.tokenizer import Tokenizer if TYPE_CHECKING: diff --git a/src/outlines_core/py.typed b/python/outlines_core/py.typed similarity index 100% rename from src/outlines_core/py.typed rename to python/outlines_core/py.typed diff --git a/setup.cfg b/setup.cfg index 3eced887..a6b17955 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,4 +5,4 @@ ignore = E203,E231,E501,E741,W503,W504,C901,E731 per-file-ignores = **/__init__.py:F401,F403 exclude = - normalai/_version.py + outlines_core/_version.py diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..78b9421e --- /dev/null +++ b/setup.py @@ -0,0 +1,20 @@ +import os + +from setuptools import setup +from setuptools_rust import Binding, RustExtension + +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) + + +rust_extensions = [ + RustExtension( + "outlines_core.fsm.rust_fsm", + f"{CURRENT_DIR}/Cargo.toml", + binding=Binding.PyO3, + rustc_flags=["--crate-type=cdylib"], + ), +] + +setup( + rust_extensions=rust_extensions, +) diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 00000000..222c7265 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,23 @@ +mod regex; + +use pyo3::prelude::*; +use pyo3::wrap_pyfunction; +use regex::_walk_fsm; +use regex::create_fsm_index_end_to_end; +use regex::get_token_transition_keys; +use regex::get_vocabulary_transition_keys; +use regex::state_scan_tokens; +use regex::FSMInfo; + +#[pymodule] +fn rust_fsm(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(_walk_fsm, m)?)?; + m.add_function(wrap_pyfunction!(state_scan_tokens, m)?)?; + m.add_function(wrap_pyfunction!(get_token_transition_keys, m)?)?; + m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys, m)?)?; + m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end, m)?)?; + + m.add_class::()?; + + Ok(()) +} diff --git a/src/regex.rs b/src/regex.rs new file mode 100644 index 00000000..320346c3 --- /dev/null +++ b/src/regex.rs @@ -0,0 +1,311 @@ +use pyo3::prelude::*; + +use pyo3::types::PyDict; +use std::collections::{HashMap, HashSet}; + +pub fn _walk_fsm_internal( + fsm_transitions: &HashMap<(u32, u32), u32>, + _fsm_initial: u32, + fsm_finals: &HashSet, + token_transition_keys: &[u32], + start_state: u32, + full_match: bool, +) -> Vec { + let mut state = start_state; + let mut accepted_states = Vec::new(); + let mut last_final_idx = 0; + + for (i, &trans_key) in token_transition_keys.iter().enumerate() { + match fsm_transitions.get(&(state, trans_key)) { + Some(&new_state) => { + state = new_state; + if fsm_finals.contains(&state) { + last_final_idx = i + 1; + } + accepted_states.push(state); + } + None => { + if !full_match && last_final_idx > 0 { + return accepted_states[..last_final_idx].to_vec(); + } + return Vec::new(); + } + } + } + + if full_match && last_final_idx != token_transition_keys.len() { + return Vec::new(); + } + + accepted_states +} + +pub fn state_scan_tokens_internal( + fsm_transitions: &HashMap<(u32, u32), u32>, + fsm_initial: u32, + fsm_finals: &HashSet, + vocabulary: &[(String, Vec)], + vocabulary_transition_keys: &[Vec], + start_state: u32, +) -> HashSet<(u32, u32)> { + let mut res = HashSet::new(); + + for (vocab_item, token_transition_keys) in + vocabulary.iter().zip(vocabulary_transition_keys.iter()) + { + let token_ids: Vec = vocab_item.1.clone(); + + let state_seq = 
_walk_fsm_internal( + fsm_transitions, + fsm_initial, + fsm_finals, + token_transition_keys, + start_state, + false, + ); + + if state_seq.len() < token_transition_keys.len() { + continue; + } + + for &token_id in &token_ids { + res.insert((token_id, *state_seq.last().unwrap())); + } + } + + res +} + +pub fn get_token_transition_keys_internal( + alphabet_symbol_mapping: &HashMap, + alphabet_anything_value: u32, + token_str: &str, +) -> Vec { + let mut token_transition_keys = Vec::new(); + let mut i = 0; + let chars: Vec = token_str.chars().collect(); + + while i < chars.len() { + let symbol; + if chars[i] == '\0' && i != chars.len() - 1 { + if i + 2 < chars.len() { + symbol = format!("\0{}{}", chars[i + 1], chars[i + 2]); + i += 3; + } else { + symbol = chars[i].to_string(); + i += 1; + } + } else { + symbol = chars[i].to_string(); + i += 1; + } + + let transition_key = *alphabet_symbol_mapping + .get(&symbol) + .unwrap_or(&alphabet_anything_value); + token_transition_keys.push(transition_key); + } + + token_transition_keys +} + +pub fn get_vocabulary_transition_keys_internal( + alphabet_symbol_mapping: &HashMap, + alphabet_anything_value: u32, + vocabulary: &[(String, Vec)], + frozen_tokens: &HashSet, +) -> Vec> { + let mut vocab_transition_keys: Vec> = Vec::new(); + + for item in vocabulary.iter() { + let token_str = item.0.clone(); + + let mut token_transition_keys; + + // Since these tokens are not expanded into byte-level transitions, we + // can simply get their transition keys directly. + if frozen_tokens.contains(&token_str) { + token_transition_keys = Vec::new(); + token_transition_keys.push( + *alphabet_symbol_mapping + .get(&token_str) + .unwrap_or(&alphabet_anything_value), + ) + } else { + token_transition_keys = get_token_transition_keys_internal( + alphabet_symbol_mapping, + alphabet_anything_value, + &token_str, + ); + } + + vocab_transition_keys.push(token_transition_keys); + } + + vocab_transition_keys +} + +#[pyclass] +pub struct FSMInfo { + #[pyo3(get)] + initial: u32, + #[pyo3(get)] + finals: HashSet, + #[pyo3(get)] + transitions: HashMap<(u32, u32), u32>, + #[pyo3(get)] + alphabet_anything_value: u32, + #[pyo3(get)] + alphabet_symbol_mapping: HashMap, +} + +#[pymethods] +impl FSMInfo { + #[new] + fn new( + initial: u32, + finals: HashSet, + transitions: HashMap<(u32, u32), u32>, + alphabet_anything_value: u32, + alphabet_symbol_mapping: HashMap, + ) -> Self { + Self { + initial, + finals, + transitions, + alphabet_anything_value, + alphabet_symbol_mapping, + } + } +} + +#[pyfunction(name = "_walk_fsm")] +#[pyo3( + text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)" +)] +pub fn _walk_fsm( + fsm_transitions: HashMap<(u32, u32), u32>, + fsm_initial: u32, + fsm_finals: HashSet, + token_transition_keys: Vec, + start_state: u32, + full_match: bool, +) -> PyResult> { + Ok(_walk_fsm_internal( + &fsm_transitions, + fsm_initial, + &fsm_finals, + &token_transition_keys, + start_state, + full_match, + )) +} + +#[pyfunction(name = "state_scan_tokens")] +#[pyo3( + text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)" +)] +pub fn state_scan_tokens( + fsm_transitions: HashMap<(u32, u32), u32>, + fsm_initial: u32, + fsm_finals: HashSet, + vocabulary: Vec<(String, Vec)>, + vocabulary_transition_keys: Vec>, + start_state: u32, +) -> PyResult> { + Ok(state_scan_tokens_internal( + &fsm_transitions, + fsm_initial, + &fsm_finals, + &vocabulary, + 
&vocabulary_transition_keys, + start_state, + )) +} + +#[pyfunction(name = "get_token_transition_keys")] +#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")] +pub fn get_token_transition_keys( + alphabet_symbol_mapping: HashMap, + alphabet_anything_value: u32, + token_str: String, +) -> PyResult> { + Ok(get_token_transition_keys_internal( + &alphabet_symbol_mapping, + alphabet_anything_value, + &token_str, + )) +} + +#[pyfunction(name = "get_vocabulary_transition_keys")] +#[pyo3( + text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)" +)] +pub fn get_vocabulary_transition_keys( + alphabet_symbol_mapping: HashMap, + alphabet_anything_value: u32, + vocabulary: Vec<(String, Vec)>, + frozen_tokens: HashSet, +) -> PyResult>> { + Ok(get_vocabulary_transition_keys_internal( + &alphabet_symbol_mapping, + alphabet_anything_value, + &vocabulary, + &frozen_tokens, + )) +} + +#[allow(clippy::too_many_arguments)] +#[pyfunction(name = "create_fsm_index_end_to_end")] +#[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")] +pub fn create_fsm_index_end_to_end<'py>( + py: Python<'py>, + fsm_info: &FSMInfo, + vocabulary: Vec<(String, Vec)>, + frozen_tokens: HashSet, +) -> PyResult> { + let states_to_token_subsets = PyDict::new_bound(py); + let mut seen: HashSet = HashSet::new(); + let mut next_states: HashSet = HashSet::from_iter(vec![fsm_info.initial]); + + let vocabulary_transition_keys = get_vocabulary_transition_keys_internal( + &fsm_info.alphabet_symbol_mapping, + fsm_info.alphabet_anything_value, + &vocabulary, + &frozen_tokens, + ); + + while let Some(start_state) = next_states.iter().cloned().next() { + next_states.remove(&start_state); + + // TODO: Return Pydict directly at construction + let token_ids_end_states = state_scan_tokens_internal( + &fsm_info.transitions, + fsm_info.initial, + &fsm_info.finals, + &vocabulary, + &vocabulary_transition_keys, + start_state, + ); + + for (token_id, end_state) in token_ids_end_states { + if let Ok(Some(existing_dict)) = states_to_token_subsets.get_item(start_state) { + existing_dict.set_item(token_id, end_state).unwrap(); + } else { + let new_dict = PyDict::new_bound(py); + new_dict.set_item(token_id, end_state).unwrap(); + states_to_token_subsets + .set_item(start_state, new_dict) + .unwrap(); + } + + if !seen.contains(&end_state) { + next_states.insert(end_state); + } + } + + seen.insert(start_state); + } + + Ok(states_to_token_subsets) +} diff --git a/tests/fsm/test_fsm.py b/tests/fsm/test_fsm.py index bb074b0b..aeb7060c 100644 --- a/tests/fsm/test_fsm.py +++ b/tests/fsm/test_fsm.py @@ -1,5 +1,4 @@ import pytest - from outlines_core.fsm.fsm import RegexFSM, StopAtEosFSM diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index c48b1ad9..0bd28d4f 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -1,5 +1,4 @@ import pytest - from outlines_core.fsm.guide import Generate, RegexGuide, StopAtEOSGuide, Write diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 12b26912..3fa3d79c 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -4,8 +4,6 @@ import interegular import pytest -from pydantic import BaseModel, Field, constr - from outlines_core.fsm.json_schema import ( BOOLEAN, DATE, @@ -22,6 +20,7 @@ get_schema_from_signature, to_regex, ) +from pydantic import BaseModel, Field, constr def test_function_basic(): diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py index 
ef424156..7b0018bb 100644 --- a/tests/fsm/test_regex.py +++ b/tests/fsm/test_regex.py @@ -1,9 +1,6 @@ import interegular -import numba import numpy as np import pytest -from transformers import AutoTokenizer - from outlines_core.fsm.regex import ( _walk_fsm, create_fsm_index_end_to_end, @@ -20,6 +17,7 @@ ) from outlines_core.integrations.utils import adapt_tokenizer from outlines_core.models.transformers import TransformerTokenizer +from transformers import AutoTokenizer def identity(s): @@ -56,7 +54,7 @@ def walk_fsm_from_token_str( ) -def walk_fsm_from_token_str_numba( +def walk_fsm_from_token_str_rust( fsm, input_string: str, start_state: int, @@ -76,7 +74,7 @@ def walk_fsm_from_token_str_numba( "function", [ walk_fsm_from_token_str, - walk_fsm_from_token_str_numba, + walk_fsm_from_token_str_rust, ], ) def test_walk_fsm(function): @@ -126,7 +124,7 @@ def test_walk_fsm(function): "function", [ walk_fsm_from_token_str, - walk_fsm_from_token_str_numba, + walk_fsm_from_token_str_rust, ], ) @pytest.mark.parametrize( @@ -338,29 +336,26 @@ def test_create_fsm_index_end_to_end(): regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) vocabulary = { - "blah": numba.typed.List([0]), - "1a": numba.typed.List([1]), - "2": numba.typed.List([2]), - "0": numba.typed.List([3]), - "": numba.typed.List([4]), + "blah": [0], + "1a": [1], + "2": [2], + "0": [3], + "": [4], } - vocabulary_nb = numba.typed.List.empty_list( - numba.types.Tuple( - ( - numba.types.unicode_type, - numba.int64[:], - ) - ) - ) + vocabulary_nb = [] for token_tuple, token_ids in vocabulary.items(): token = merge_symbols(token_tuple) token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) vocabulary_nb.append((token, token_ids_np)) - res = create_fsm_index_end_to_end(regex_fsm.fsm_info, vocabulary_nb) + res = create_fsm_index_end_to_end( + regex_fsm.fsm_info, + vocabulary_nb, + frozenset(), + ) - assert res == {0: {(2, 2), (3, 1)}, 2: {(2, 2), (3, 2)}} + assert res == {0: {2: 2, 3: 1}, 2: {2: 2, 3: 2}} def test_create_fsm_index_end_to_end_multi_byte(): @@ -371,35 +366,30 @@ def test_create_fsm_index_end_to_end_multi_byte(): byte_fsm = make_byte_level_better_fsm(regex_fsm, keep_utf8=True) vocabulary = { - "blah": numba.typed.List([0]), - "😈a": numba.typed.List([1]), - "πŸ˜‡": numba.typed.List([2]), - "😍": numba.typed.List([3]), - merge_symbols(("F0", "9F", "98", "8D")): numba.typed.List([4]), # '😍' - " 😍": numba.typed.List([5]), - merge_symbols((" ", "F0", "9F", "98", "8D")): numba.typed.List([6]), # ' 😍' - merge_symbols((" ", "F0", "9F", "98")): numba.typed.List( - [7] - ), # ' 😍' incomplete - "": numba.typed.List([8]), + "blah": [0], + "😈a": [1], + "πŸ˜‡": [2], + "😍": [3], + merge_symbols(("F0", "9F", "98", "8D")): [4], # '😍' + " 😍": [5], + merge_symbols((" ", "F0", "9F", "98", "8D")): [6], # ' 😍' + merge_symbols((" ", "F0", "9F", "98")): [7], # ' 😍' incomplete + "": [8], } - vocabulary_nb = numba.typed.List.empty_list( - numba.types.Tuple( - ( - numba.types.unicode_type, - numba.int64[:], - ) - ) - ) + vocabulary_nb = [] for token_tuple, token_ids in vocabulary.items(): token_tuple_np = merge_symbols(token_tuple) token_ids_np = np.fromiter(token_ids, dtype=np.dtype("int64")) vocabulary_nb.append((token_tuple_np, token_ids_np)) - res = create_fsm_index_end_to_end(byte_fsm.fsm_info, vocabulary_nb) + res = create_fsm_index_end_to_end( + byte_fsm.fsm_info, + vocabulary_nb, + frozenset(), + ) - assert res == {0: {(5, 3), (6, 3), (7, 7), (2, 2)}, 3: {(2, 3), (3, 3), (4, 3)}} + assert res == {0: {5: 3, 6: 3, 7: 7, 2: 
2}, 3: {2: 3, 3: 3, 4: 3}} @pytest.mark.parametrize( @@ -511,7 +501,6 @@ def test_regex_index_performance(): tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenizer = TransformerTokenizer(tokenizer) - # Pre-compile Numba functions res, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer) assert len(res) > 1 @@ -531,11 +520,10 @@ def test_json_index_performance(): import json from enum import Enum + import outlines_core from line_profiler import LineProfiler # type: ignore [import] from pydantic import BaseModel, constr - import outlines_core - class Weapon(str, Enum): sword = "sword" axe = "axe" @@ -599,21 +587,19 @@ def convert_token_to_string(self, token): token_trans_keys = get_vocabulary_transition_keys( regex_fsm.fsm_info.alphabet_symbol_mapping, regex_fsm.fsm_info.alphabet_anything_value, - vocabulary, - numba.typed.List.empty_list(numba.types.unicode_type), + list(vocabulary.items()), + frozenset(), ) token_str_to_tranition_keys = { token_str: trans_key_seq - for (token_str, _), trans_key_seq in zip(vocabulary, token_trans_keys) + for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys) } # `a` and `b` both are workable, but `z` has distinct transition rules assert interegular_fsm.accepts("zaz") assert interegular_fsm.accepts("zbz") - assert (token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["b"]).all() - assert not ( - token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["z"] - ).all() + assert token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["b"] + assert not token_str_to_tranition_keys["a"] == token_str_to_tranition_keys["z"] def test_token_trans_keys_walk_fsm(): @@ -637,13 +623,13 @@ def convert_token_to_string(self, token): token_trans_keys = get_vocabulary_transition_keys( regex_fsm.fsm_info.alphabet_symbol_mapping, regex_fsm.fsm_info.alphabet_anything_value, - vocabulary, - numba.typed.List.empty_list(numba.types.unicode_type), + list(vocabulary.items()), + frozenset(), ) token_str_trans_key_seq = { token_str: trans_key_seq - for (token_str, _), trans_key_seq in zip(vocabulary, token_trans_keys) + for (token_str, _), trans_key_seq in zip(vocabulary.items(), token_trans_keys) } # verify initial state valid only for "ab" and "ac" using transition key seq @@ -655,42 +641,13 @@ def convert_token_to_string(self, token): regex_fsm.fsm_info.initial, regex_fsm.fsm_info.finals, token_trans_key_seq, - regex_fsm.fsm_info.initial, + regex_fsm.initial, False, ) is_accepted = len(state_seq) >= len(token_trans_key_seq) assert should_accept == is_accepted -def test_numba_leading_null_byte_UnicodeCharSeq_remains_broken(): - """Assert numba UnicodeCharSeq w/ leading \x00 is still broken""" - # EXPLANATION: - # https://github.com/outlines_core-dev/outlines/pull/930#issuecomment-2143535968 - - # from https://github.com/numba/numba/issues/9542 - d = numba.typed.typeddict.Dict.empty(numba.types.UnicodeCharSeq(1), numba.int64) - d["δΈ€"] = 10 # \xe4\xb8\x80 - with pytest.raises(KeyError): - str(d) - - # most characters are fine, but "\x00" is converted to "" - l = np.fromiter(["\x99", "\x00"], dtype=np.dtype("U2")) - assert str(l[0]) == "\x99" # fine - assert str(l[1]) == "" # 1-byte null converted to 0-bytes - - -@pytest.mark.parametrize("input_key", ["δΈ€", "\x00"]) -def test_numba_leading_null_byte_unicode_type_sane(input_key): - """Assert numba unicode_type w/ leading \x00 is working""" - # EXPLANATION: - # https://github.com/outlines_core-dev/outlines/pull/930#issuecomment-2143535968 - - # from 
https://github.com/numba/numba/issues/9542 - d = numba.typed.typeddict.Dict.empty(numba.types.unicode_type, numba.int64) - d["δΈ€"] = 10 # \xe4\xb8\x80 - str(d) # assert successfully interprets - - @pytest.mark.parametrize( "rare_token", [ diff --git a/tests/fsm/test_types.py b/tests/fsm/test_types.py index 2102db92..fc66bd3f 100644 --- a/tests/fsm/test_types.py +++ b/tests/fsm/test_types.py @@ -1,7 +1,6 @@ import datetime import pytest - from outlines_core.fsm.types import ( BOOLEAN, DATE, diff --git a/tests/models/test_tokenizer.py b/tests/models/test_tokenizer.py index 95e9cc8f..9457bda5 100644 --- a/tests/models/test_tokenizer.py +++ b/tests/models/test_tokenizer.py @@ -1,5 +1,4 @@ import pytest - from outlines_core.models.tokenizer import Tokenizer diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index 8ac8d466..799f7a5b 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -1,10 +1,9 @@ import pytest import torch +from outlines_core.models.transformers import TransformerTokenizer, transformers from transformers import AutoTokenizer from transformers.models.gpt2 import GPT2TokenizerFast -from outlines_core.models.transformers import TransformerTokenizer, transformers - TEST_MODEL = "hf-internal-testing/tiny-random-GPTJForCausalLM"
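
For reference, a minimal sketch of how the Rust-backed index construction is consumed from the Python side after this change. The regex pattern and the "gpt2" tokenizer below are illustrative choices that mirror the benchmarks and tests in this patch; the imported names come from the modules as modified here.

    import interegular
    from transformers import AutoTokenizer

    from outlines_core.fsm.regex import (
        create_fsm_index_tokenizer,
        make_deterministic_fsm,
    )
    from outlines_core.models.transformers import TransformerTokenizer

    # Build a deterministic FSM for a simple pattern.
    regex_fsm, _ = make_deterministic_fsm(
        interegular.parse_pattern(r"[0-9]+").to_fsm().reduce()
    )

    # Wrap a Hugging Face tokenizer so its vocabulary can be scanned.
    tokenizer = TransformerTokenizer(AutoTokenizer.from_pretrained("gpt2"))

    # The vocabulary scan now runs in the compiled rust_fsm extension. The result
    # maps each FSM state to a {token_id: next_state} dict, alongside the set of
    # token ids whose decoded string is empty.
    states_to_tokens, empty_token_ids = create_fsm_index_tokenizer(regex_fsm, tokenizer)

The call signature is unchanged for callers, but the index values are now plain dicts (token id to end state) rather than sets of tuples, as reflected in the updated tests above.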