From 4bdec212917aea4f12e8454fcd028b7ea80e0082 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Tue, 10 Dec 2024 18:28:33 +0000 Subject: [PATCH] Allow threads on Index init --- benchmarks/bench_regex_guide.py | 38 ++++++++++++++++++++++++++------- pyproject.toml | 1 + src/python_bindings/mod.rs | 9 +++++--- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/benchmarks/bench_regex_guide.py b/benchmarks/bench_regex_guide.py index a47adae2..32316242 100644 --- a/benchmarks/bench_regex_guide.py +++ b/benchmarks/bench_regex_guide.py @@ -1,3 +1,6 @@ +from concurrent.futures import ThreadPoolExecutor + +import psutil from outlines_core.fsm.guide import RegexGuide from .common import setup_tokenizer @@ -14,7 +17,6 @@ "complex_span_constrained_relation_extraction": "(['\"\\ ,]?((?:of|resulting|case|which|cultures|a|core|extreme|selflessness|spiritual|various|However|both|vary|in|other|secular|the|religious|among|moral|and|It|object|worldviews|altruism|traditional|material|aspect|or|life|beings|virtue|is|however|opposite|concern|an|practice|it|for|s|quality|religions|In|Altruism|animals|happiness|many|become|principle|human|selfishness|may|synonym)['\"\\ ,]?)+['\"\\ ,]?\\s\\|\\s([^|\\(\\)\n]{1,})\\s\\|\\s['\"\\ ,]?((?:of|resulting|case|which|cultures|a|core|extreme|selflessness|spiritual|various|However|both|vary|in|other|secular|the|religious|among|moral|and|It|object|worldviews|altruism|traditional|material|aspect|or|life|beings|virtue|is|however|opposite|concern|an|practice|it|for|s|quality|religions|In|Altruism|animals|happiness|many|become|principle|human|selfishness|may|synonym)['\"\\ ,]?)+['\"\\ ,]?(\\s\\|\\s\\(([^|\\(\\)\n]{1,})\\s\\|\\s([^|\\(\\)\n]{1,})\\))*\\n)*", } - class RegexGuideBenchmark: params = regex_samples.keys() @@ -25,13 +27,33 @@ def setup(self, pattern_name): def time_regex_to_guide(self, pattern_name): RegexGuide.from_regex(self.pattern, self.tokenizer) + def time_regex_to_guide_parallel(self, pattern_name): + # Default GIL switch interval is 5ms (0.005), which isn't helpful for cpu heavy tasks, + # this parallel case should be relatively close in runtime to one thread, but it is not, + # because of the GIL. + core_count = psutil.cpu_count(logical=False) + with ThreadPoolExecutor(max_workers=core_count) as executor: + list(executor.map(self._from_regex, [pattern_name] * core_count)) + + def time_regex_to_guide_parallel_with_custom_switch_interval(self, pattern_name): + # This test is to show, that if GIL's switch interval is set to be longer, then the parallel + # test's runtime on physical cores will be much closer to the one-threaded case. + import sys + sys.setswitchinterval(5) + + core_count = psutil.cpu_count(logical=False) + with ThreadPoolExecutor(max_workers=core_count) as executor: + list(executor.map(self._from_regex, [pattern_name] * core_count)) + + def _from_regex(self, pattern_name): + RegexGuide.from_regex(self.pattern, self.tokenizer) -class MemoryRegexGuideBenchmark: - params = ["simple_phone", "complex_span_constrained_relation_extraction"] + class MemoryRegexGuideBenchmark: + params = ["simple_phone", "complex_span_constrained_relation_extraction"] - def setup(self, pattern_name): - self.tokenizer = setup_tokenizer() - self.pattern = regex_samples[pattern_name] + def setup(self, pattern_name): + self.tokenizer = setup_tokenizer() + self.pattern = regex_samples[pattern_name] - def peakmem_regex_to_guide(self, pattern_name): - RegexGuide.from_regex(self.pattern, self.tokenizer) + def peakmem_regex_to_guide(self, pattern_name): + RegexGuide.from_regex(self.pattern, self.tokenizer) diff --git a/pyproject.toml b/pyproject.toml index 57090988..3bde97b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ test = [ "datasets", "pillow", "asv", + "psutil", "setuptools-rust", ] diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 76a9a1ec..84b8746d 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -80,14 +80,17 @@ pub struct PyIndex(Index); impl PyIndex { #[new] fn new( + py: Python<'_>, fsm_info: &PyFSMInfo, vocabulary: &PyVocabulary, eos_token_id: u32, frozen_tokens: FxHashSet, ) -> PyResult { - Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens) - .map(PyIndex) - .map_err(Into::into) + py.allow_threads(|| { + Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens) + .map(PyIndex) + .map_err(Into::into) + }) } fn __reduce__(&self) -> PyResult<(PyObject, (Vec,))> {