From 130a388dac09e6bdaefca1e27941348ddaee953c Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Tue, 10 Dec 2024 18:28:33 +0000 Subject: [PATCH] Allow threads on Index init --- benchmarks/bench_regex_guide.py | 25 +++++++++++++++++++++++++ pyproject.toml | 1 + src/python_bindings/mod.rs | 9 ++++++--- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py index a47adae2..5dda576a 100644 --- a/benchmarks/bench_regex_guide.py +++ b/benchmarks/bench_regex_guide.py @@ -1,3 +1,6 @@ +from concurrent.futures import ThreadPoolExecutor + +import psutil from outlines_core.fsm.guide import RegexGuide from .common import setup_tokenizer @@ -25,6 +28,28 @@ def setup(self, pattern_name): def time_regex_to_guide(self, pattern_name): RegexGuide.from_regex(self.pattern, self.tokenizer) + def time_regex_to_guide_parallel(self, pattern_name): + # Default GIL switch interval is 5ms (0.005), which isn't helpful for cpu heavy tasks, + # this parallel case should be relatively close in runtime to one thread, but it is not, + # because of the GIL. + core_count = psutil.cpu_count(logical=False) + with ThreadPoolExecutor(max_workers=core_count) as executor: + list(executor.map(self._from_regex, [pattern_name] * core_count)) + + def time_regex_to_guide_parallel_with_custom_switch_interval(self, pattern_name): + # This test is to show, that if GIL's switch interval is set to be longer, then the parallel + # test's runtime on physical cores will be much closer to the one-threaded case. + import sys + + sys.setswitchinterval(5) + + core_count = psutil.cpu_count(logical=False) + with ThreadPoolExecutor(max_workers=core_count) as executor: + list(executor.map(self._from_regex, [pattern_name] * core_count)) + + def _from_regex(self, pattern_name): + RegexGuide.from_regex(self.pattern, self.tokenizer) + class MemoryRegexGuideBenchmark: params = ["simple_phone", "complex_span_constrained_relation_extraction"] diff --git a/pyproject.toml b/pyproject.toml index 57090988..3bde97b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ test = [ "datasets", "pillow", "asv", + "psutil", "setuptools-rust", ] diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 76a9a1ec..84b8746d 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -80,14 +80,17 @@ pub struct PyIndex(Index); impl PyIndex { #[new] fn new( + py: Python<'_>, fsm_info: &PyFSMInfo, vocabulary: &PyVocabulary, eos_token_id: u32, frozen_tokens: FxHashSet, ) -> PyResult { - Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens) - .map(PyIndex) - .map_err(Into::into) + py.allow_threads(|| { + Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens) + .map(PyIndex) + .map_err(Into::into) + }) } fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> {