Skip to content

Commit

Permalink
Allow threads on Index init
Browse files Browse the repository at this point in the history
  • Loading branch information
torymur committed Dec 11, 2024
1 parent 3d01211 commit 4bdec21
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
38 changes: 30 additions & 8 deletions benchmarks/bench_regex_guide.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from concurrent.futures import ThreadPoolExecutor

import psutil
from outlines_core.fsm.guide import RegexGuide

from .common import setup_tokenizer
Expand All @@ -14,7 +17,6 @@
"complex_span_constrained_relation_extraction": "(['\"\\ ,]?((?:of|resulting|case|which|cultures|a|core|extreme|selflessness|spiritual|various|However|both|vary|in|other|secular|the|religious|among|moral|and|It|object|worldviews|altruism|traditional|material|aspect|or|life|beings|virtue|is|however|opposite|concern|an|practice|it|for|s|quality|religions|In|Altruism|animals|happiness|many|become|principle|human|selfishness|may|synonym)['\"\\ ,]?)+['\"\\ ,]?\\s\\|\\s([^|\\(\\)\n]{1,})\\s\\|\\s['\"\\ ,]?((?:of|resulting|case|which|cultures|a|core|extreme|selflessness|spiritual|various|However|both|vary|in|other|secular|the|religious|among|moral|and|It|object|worldviews|altruism|traditional|material|aspect|or|life|beings|virtue|is|however|opposite|concern|an|practice|it|for|s|quality|religions|In|Altruism|animals|happiness|many|become|principle|human|selfishness|may|synonym)['\"\\ ,]?)+['\"\\ ,]?(\\s\\|\\s\\(([^|\\(\\)\n]{1,})\\s\\|\\s([^|\\(\\)\n]{1,})\\))*\\n)*",
}


class RegexGuideBenchmark:
params = regex_samples.keys()

Expand All @@ -25,13 +27,33 @@ def setup(self, pattern_name):
def time_regex_to_guide(self, pattern_name):
    """Time a single-threaded build of the guide for the current pattern.

    ``pattern_name`` is the benchmark parameter; the body itself only reads
    ``self.pattern`` and ``self.tokenizer``.
    """
    pattern, tokenizer = self.pattern, self.tokenizer
    RegexGuide.from_regex(pattern, tokenizer)

def time_regex_to_guide_parallel(self, pattern_name):
    """Time building the guide concurrently, once per worker thread.

    The default GIL switch interval is 5ms (0.005s), which does not help
    CPU-heavy tasks: ideally this parallel case would take about as long as
    the single-threaded one, but it does not, because of the GIL.
    """
    # psutil.cpu_count(logical=False) returns None when the number of
    # physical cores cannot be determined; fall back to one worker so the
    # list repetition below does not raise TypeError.
    core_count = psutil.cpu_count(logical=False) or 1
    with ThreadPoolExecutor(max_workers=core_count) as executor:
        # list() consumes the results so that an exception raised in a
        # worker propagates here instead of being silently dropped.
        list(executor.map(self._from_regex, [pattern_name] * core_count))

def time_regex_to_guide_parallel_with_custom_switch_interval(self, pattern_name):
    """Time the parallel build with a much longer GIL switch interval.

    This benchmark shows that when the GIL's switch interval is set to be
    longer, the runtime of the parallel case on physical cores gets much
    closer to the single-threaded case.
    """
    import sys

    # The switch interval is a process-wide setting: remember the current
    # value and restore it afterwards so it cannot leak into whatever runs
    # next in this process.
    previous_interval = sys.getswitchinterval()
    sys.setswitchinterval(5)
    try:
        # psutil.cpu_count(logical=False) returns None when the number of
        # physical cores cannot be determined; fall back to one worker so
        # the list repetition below does not raise TypeError.
        core_count = psutil.cpu_count(logical=False) or 1
        with ThreadPoolExecutor(max_workers=core_count) as executor:
            # list() consumes the results so that an exception raised in a
            # worker propagates here instead of being silently dropped.
            list(executor.map(self._from_regex, [pattern_name] * core_count))
    finally:
        sys.setswitchinterval(previous_interval)

def _from_regex(self, pattern_name):
    """Worker used by the parallel benchmarks: build one guide.

    ``pattern_name`` is accepted so the method can be used with
    ``executor.map``; the guide is built from ``self.pattern`` and
    ``self.tokenizer``.
    """
    pattern, tokenizer = self.pattern, self.tokenizer
    RegexGuide.from_regex(pattern, tokenizer)

class MemoryRegexGuideBenchmark:
params = ["simple_phone", "complex_span_constrained_relation_extraction"]
class MemoryRegexGuideBenchmark:
params = ["simple_phone", "complex_span_constrained_relation_extraction"]

def setup(self, pattern_name):
self.tokenizer = setup_tokenizer()
self.pattern = regex_samples[pattern_name]
def setup(self, pattern_name):
self.tokenizer = setup_tokenizer()
self.pattern = regex_samples[pattern_name]

def peakmem_regex_to_guide(self, pattern_name):
RegexGuide.from_regex(self.pattern, self.tokenizer)
def peakmem_regex_to_guide(self, pattern_name):
RegexGuide.from_regex(self.pattern, self.tokenizer)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ test = [
"datasets",
"pillow",
"asv",
"psutil",
"setuptools-rust",
]

Expand Down
9 changes: 6 additions & 3 deletions src/python_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,17 @@ pub struct PyIndex(Index);
impl PyIndex {
#[new]
fn new(
py: Python<'_>,
fsm_info: &PyFSMInfo,
vocabulary: &PyVocabulary,
eos_token_id: u32,
frozen_tokens: FxHashSet<String>,
) -> PyResult<Self> {
Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens)
.map(PyIndex)
.map_err(Into::into)
py.allow_threads(|| {
Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens)
.map(PyIndex)
.map_err(Into::into)
})
}

fn __reduce__(&self) -> PyResult<(PyObject, (Vec<u8>,))> {
Expand Down

0 comments on commit 4bdec21

Please sign in to comment.