From de75c871ee651358cbeb51ec74b44d22e604c206 Mon Sep 17 00:00:00 2001 From: avishaihalev Date: Fri, 14 Jun 2024 23:53:17 -0700 Subject: [PATCH 01/12] add ParallelBatchCollector --- mmap_ninja/src/mmap_ninja/parallel.py | 146 ++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 mmap_ninja/src/mmap_ninja/parallel.py diff --git a/mmap_ninja/src/mmap_ninja/parallel.py b/mmap_ninja/src/mmap_ninja/parallel.py new file mode 100644 index 0000000..b56e9cb --- /dev/null +++ b/mmap_ninja/src/mmap_ninja/parallel.py @@ -0,0 +1,146 @@ +from functools import partial +from tqdm.auto import tqdm + +try: + from joblib import Parallel, delayed +except ImportError: + Parallel, delayed = None, None + HAS_JOBLIB = False +else: + HAS_JOBLIB = True + + +class ParallelBatchCollector: + _parallel: Parallel = None + + def __init__(self, indexable, batch_size, n_jobs=None, verbose=False, **kwargs): + self.indexable, self._obj_length, self._num_batches = self.verify(indexable, batch_size) + self.batch_size = batch_size + + self._pbar = self._init_pbar(verbose) + self._parallel = self.begin(n_jobs, **kwargs) + self._batch_num = 0 + self._exhausted = False + + @staticmethod + def verify(indexable, batch_size): + try: + _ = indexable.__getitem__ + except AttributeError: + if callable(indexable): + indexable = _IndexableWrap(indexable) + else: + msg = 'indexable must implement __getitem__ or be callable and take one integer argument.' + raise TypeError(msg) + + try: + length = len(indexable) + except TypeError: + length = None + num_batches = None + else: + num_batches = length // batch_size + (length % batch_size != 0) + + return indexable, length, num_batches + + @staticmethod + def begin(n_jobs: int, **kwargs): + if n_jobs in (None, 1): + return + elif not HAS_JOBLIB: + msg = 'joblib is not installed. Install joblib or run with n_jobs=None to ignore parallelization.' 
+ raise ImportError(msg) + + _parallel = Parallel(n_jobs=n_jobs, **kwargs) + _parallel.__enter__() + return _parallel + + def batches(self): + while not self.exhausted(): + yield self.collect_batch() + + def collect_batch(self): + if self._parallel is None: + batch = self._collect_no_parallel_batch() + else: + batch = self._collect_parallel_batch() + + self._update_pbar(batch) + return batch + + def _init_pbar(self, verbose): + if not verbose: + return None + return tqdm(total=self._obj_length) + + def _update_pbar(self, batch): + if self._pbar is not None: + self._pbar.update(len(batch)) + + def _collect_no_parallel_batch(self): + results = [_get_from_indexable(self.indexable, j) for j in self._rng()] + + if self.exhausted(results): + results = [r for r in results if r is not None] + + return results + + def _collect_parallel_batch(self): + func = delayed(partial(_get_from_indexable, self.indexable)) + + results = self._parallel(func(j) for j in self._rng()) + + if self.exhausted(results): + results = [r for r in results if r is not None] + self._parallel.__exit__(None, None, None) + + return results + + def exhausted(self, results=()): + self._exhausted = self._exhausted or any(r is None for r in results) or self.completed_batches() + return self._exhausted + + def completed_batches(self): + return self._num_batches is not None and self._batch_num == self._num_batches + + def _rng(self): + start = self.batch_size * self._batch_num + stop = self.batch_size * (1 + self._batch_num) + + self._batch_num += 1 + + return range(start, stop) + + +class _IndexableWrap: + def __init__(self, func): + self._func = func + + def __getitem__(self, item): + return self._func(item) + + @property + def wrapped(self): + return self._func + + +class _IndexableLengthWrap(_IndexableWrap): + def __init__(self, func, length): + super().__init__(func) + self.length = length + + def __len__(self): + return self.length + + +def make_indexable(func, length=None): + if length is not None: + 
return _IndexableLengthWrap(func, length) + return _IndexableWrap(func) + + +def _get_from_indexable(indexable, item,): + try: + return indexable[item] + except (IndexError, KeyError): + return None \ No newline at end of file From 33a02b91caa11d612e9aedc3012a3ff4f68d1157 Mon Sep 17 00:00:00 2001 From: avishaihalev Date: Fri, 14 Jun 2024 23:54:02 -0700 Subject: [PATCH 02/12] add from_indexable_base --- mmap_ninja/src/mmap_ninja/base.py | 36 +++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/mmap_ninja/src/mmap_ninja/base.py b/mmap_ninja/src/mmap_ninja/base.py index e445bd2..45ceb81 100644 --- a/mmap_ninja/src/mmap_ninja/base.py +++ b/mmap_ninja/src/mmap_ninja/base.py @@ -4,6 +4,8 @@ from pathlib import Path from typing import Union, Sequence, List +from .parallel import ParallelBatchCollector + def _bytes_to_int(inp: bytes, fmt: str = " int: """ @@ -229,6 +231,40 @@ def from_generator_base(out_dir, sample_generator, batch_size, batch_ctor, exten return memmap +def from_indexable_base(out_dir, indexable, batch_size, batch_ctor, extend_fn=None, n_jobs=None, verbose=False, **kwargs): + """ + Creates an output from a generator, flushing every batch to disk. + + :param out_dir: The output directory. + :param indexable: An object that supports __getitem__. indexable[i] will become memmap[i] + :param batch_size: The batch size, which controls how often the output should be written to disk. + :param batch_ctor: The constructor used to initialize the output. + :param extend_fn: Functon to call when doing .extend. By default, this will call memmap.extend(samples) + :param n_jobs: number of jobs to iterate through indexable with. Default=None corresponds to no parallelization. + :param verbose: whether to print progress meter. + :param kwargs: Additional keyword arguments to be passed when initializing the output. 
+ :return: + """ + out_dir = Path(out_dir) + out_dir.mkdir(exist_ok=True) + memmap = None + if kwargs.pop("verbose", False): + from tqdm.auto import tqdm + + batch_collector = ParallelBatchCollector(indexable, batch_size, n_jobs, verbose) + + for samples in batch_collector.batches(): + if memmap is None: + memmap = batch_ctor(out_dir, samples, **kwargs) + else: + if extend_fn is not None: + extend_fn(memmap, samples) + else: + memmap.extend(samples) + + return memmap + + class Wrapped: def __init__(self, data, wrapper_fn, copy_before_wrapper_fn=True): self.data = data From 575132fdb548f520ea66d727984d3e8309bd44ca Mon Sep 17 00:00:00 2001 From: avishaihalev Date: Fri, 14 Jun 2024 23:57:10 -0700 Subject: [PATCH 03/12] add from_indexable clsmethod --- mmap_ninja/src/mmap_ninja/ragged.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mmap_ninja/src/mmap_ninja/ragged.py b/mmap_ninja/src/mmap_ninja/ragged.py index 4fcd84f..666f767 100644 --- a/mmap_ninja/src/mmap_ninja/ragged.py +++ b/mmap_ninja/src/mmap_ninja/ragged.py @@ -190,3 +190,14 @@ def from_generator(cls, out_dir: Union[str, Path], sample_generator, batch_size: batch_ctor=cls.from_lists, **kwargs, ) + + @classmethod + def from_indexable(cls, out_dir: Union[str, Path], indexable, batch_size: int, verbose=False, **kwargs): + return base.from_indexable_base( + out_dir=out_dir, + indexable=indexable, + batch_size=batch_size, + verbose=verbose, + batch_ctor=cls.from_lists, + **kwargs, + ) From 3ff90b6d318acb47715f0bca284e2ef5a07a238e Mon Sep 17 00:00:00 2001 From: avishaihalev Date: Fri, 14 Jun 2024 23:57:26 -0700 Subject: [PATCH 04/12] add optional dep --- mmap_ninja/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/mmap_ninja/pyproject.toml b/mmap_ninja/pyproject.toml index 3dbda9d..ff5ba5e 100644 --- a/mmap_ninja/pyproject.toml +++ b/mmap_ninja/pyproject.toml @@ -32,3 +32,4 @@ Homepage = "https://github.com/hristo-vrigazov/mmap.ninja" [project.optional-dependencies] dev = ["pytest", 
"tqdm", "pytest-dotenv"] +parallel = ["joblib"] From 7e9a44ac571074e37b4f8abfc1a0132fc003f740 Mon Sep 17 00:00:00 2001 From: avishaihalev Date: Sat, 15 Jun 2024 00:02:54 -0700 Subject: [PATCH 05/12] fix whitespace err --- mmap_ninja/src/mmap_ninja/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmap_ninja/src/mmap_ninja/parallel.py b/mmap_ninja/src/mmap_ninja/parallel.py index b56e9cb..194b26f 100644 --- a/mmap_ninja/src/mmap_ninja/parallel.py +++ b/mmap_ninja/src/mmap_ninja/parallel.py @@ -143,4 +143,4 @@ def _get_from_indexable(indexable, item,): try: return indexable[item] except (IndexError, KeyError): - return None \ No newline at end of file + return None From 1c73be5a9c8d5b9b3da55e4ee8d7bdb62f60ae47 Mon Sep 17 00:00:00 2001 From: avishaihalev Date: Sat, 15 Jun 2024 00:15:20 -0700 Subject: [PATCH 06/12] use exhausted instead of None --- mmap_ninja/src/mmap_ninja/parallel.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mmap_ninja/src/mmap_ninja/parallel.py b/mmap_ninja/src/mmap_ninja/parallel.py index 194b26f..ee99c23 100644 --- a/mmap_ninja/src/mmap_ninja/parallel.py +++ b/mmap_ninja/src/mmap_ninja/parallel.py @@ -10,6 +10,9 @@ HAS_JOBLIB = True +EXHAUSTED = '__EXHAUSTED__' + + class ParallelBatchCollector: _parallel: Parallel = None @@ -81,7 +84,7 @@ def _collect_no_parallel_batch(self): results = [_get_from_indexable(self.indexable, j) for j in self._rng()] if self.exhausted(results): - results = [r for r in results if r is not None] + results = [r for r in results if r != EXHAUSTED] return results @@ -91,13 +94,13 @@ def _collect_parallel_batch(self): results = self._parallel(func(j) for j in self._rng()) if self.exhausted(results): - results = [r for r in results if r is not None] + results = [r for r in results if r != EXHAUSTED] self._parallel.__exit__(None, None, None) return results def exhausted(self, results=()): - self._exhausted = self._exhausted or any(r is None for r in results) or 
self.completed_batches() + self._exhausted = self._exhausted or any(r == EXHAUSTED for r in results) or self.completed_batches() return self._exhausted def completed_batches(self): @@ -143,4 +146,4 @@ def _get_from_indexable(indexable, item,): try: return indexable[item] except (IndexError, KeyError): - return None + return EXHAUSTED From 9bbeb2862ea11f04ce45d2569ee0fedda57658b5 Mon Sep 17 00:00:00 2001 From: avishaihalev Date: Sat, 15 Jun 2024 00:31:03 -0700 Subject: [PATCH 07/12] fix EXHAUSTED check --- mmap_ninja/src/mmap_ninja/parallel.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mmap_ninja/src/mmap_ninja/parallel.py b/mmap_ninja/src/mmap_ninja/parallel.py index ee99c23..0ab8484 100644 --- a/mmap_ninja/src/mmap_ninja/parallel.py +++ b/mmap_ninja/src/mmap_ninja/parallel.py @@ -84,7 +84,7 @@ def _collect_no_parallel_batch(self): results = [_get_from_indexable(self.indexable, j) for j in self._rng()] if self.exhausted(results): - results = [r for r in results if r != EXHAUSTED] + results = [r for r in results if not isinstance(r, str) or r != EXHAUSTED] return results @@ -94,13 +94,18 @@ def _collect_parallel_batch(self): results = self._parallel(func(j) for j in self._rng()) if self.exhausted(results): - results = [r for r in results if r != EXHAUSTED] + results = [r for r in results if not isinstance(r, str) or r != EXHAUSTED] self._parallel.__exit__(None, None, None) return results def exhausted(self, results=()): - self._exhausted = self._exhausted or any(r == EXHAUSTED for r in results) or self.completed_batches() + self._exhausted = ( + self._exhausted or + any(isinstance(r, str) and r == EXHAUSTED for r in results) or + self.completed_batches() + ) + return self._exhausted def completed_batches(self): From cca0d45345927ca818ee8e267d3c14f793365b20 Mon Sep 17 00:00:00 2001 From: avishaihalev Date: Sat, 15 Jun 2024 00:34:15 -0700 Subject: [PATCH 08/12] add n_jobs kwarg --- mmap_ninja/src/mmap_ninja/ragged.py | 3 ++- 1 file 
changed, 2 insertions(+), 1 deletion(-) diff --git a/mmap_ninja/src/mmap_ninja/ragged.py b/mmap_ninja/src/mmap_ninja/ragged.py index 666f767..7cddb1b 100644 --- a/mmap_ninja/src/mmap_ninja/ragged.py +++ b/mmap_ninja/src/mmap_ninja/ragged.py @@ -192,11 +192,12 @@ def from_generator(cls, out_dir: Union[str, Path], sample_generator, batch_size: ) @classmethod - def from_indexable(cls, out_dir: Union[str, Path], indexable, batch_size: int, verbose=False, **kwargs): + def from_indexable(cls, out_dir: Union[str, Path], indexable, batch_size: int, n_jobs=None, verbose=False, **kwargs): return base.from_indexable_base( out_dir=out_dir, indexable=indexable, batch_size=batch_size, + n_jobs=n_jobs, verbose=verbose, batch_ctor=cls.from_lists, **kwargs, From 6205cac2174af1ffb7536d2dd8875c386cc76827 Mon Sep 17 00:00:00 2001 From: avishaihalev Date: Mon, 17 Jun 2024 13:21:45 -0700 Subject: [PATCH 09/12] change exhausted --- mmap_ninja/src/mmap_ninja/parallel.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mmap_ninja/src/mmap_ninja/parallel.py b/mmap_ninja/src/mmap_ninja/parallel.py index 0ab8484..1094982 100644 --- a/mmap_ninja/src/mmap_ninja/parallel.py +++ b/mmap_ninja/src/mmap_ninja/parallel.py @@ -1,3 +1,4 @@ +from enum import Enum from functools import partial from tqdm.auto import tqdm @@ -10,7 +11,11 @@ HAS_JOBLIB = True -EXHAUSTED = '__EXHAUSTED__' +class _Exhausted(Enum): + exhausted = 'EXHAUSTED' + + +EXHAUSTED = _Exhausted.exhausted class ParallelBatchCollector: @@ -84,7 +89,7 @@ def _collect_no_parallel_batch(self): results = [_get_from_indexable(self.indexable, j) for j in self._rng()] if self.exhausted(results): - results = [r for r in results if not isinstance(r, str) or r != EXHAUSTED] + results = [r for r in results if r is not EXHAUSTED] return results @@ -94,7 +99,7 @@ def _collect_parallel_batch(self): results = self._parallel(func(j) for j in self._rng()) if self.exhausted(results): - results = [r for r in results if not 
isinstance(r, str) or r != EXHAUSTED] +            results = [r for r in results if r is not EXHAUSTED] self._parallel.__exit__(None, None, None) return results @@ -102,7 +107,7 @@ def _collect_parallel_batch(self): def exhausted(self, results=()): self._exhausted = ( self._exhausted or - any(isinstance(r, str) and r == EXHAUSTED for r in results) or + any(r is EXHAUSTED for r in results) or self.completed_batches() ) From a6c922d8e63ae7142cbb91ba0b5377c0d0843d3a Mon Sep 17 00:00:00 2001 From: avishaihalev Date: Mon, 17 Jun 2024 13:41:45 -0700 Subject: [PATCH 10/12] update docstring --- mmap_ninja/src/mmap_ninja/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mmap_ninja/src/mmap_ninja/base.py b/mmap_ninja/src/mmap_ninja/base.py index 45ceb81..c873aa1 100644 --- a/mmap_ninja/src/mmap_ninja/base.py +++ b/mmap_ninja/src/mmap_ninja/base.py @@ -233,13 +233,14 @@ def from_indexable_base(out_dir, indexable, batch_size, batch_ctor, exten def from_indexable_base(out_dir, indexable, batch_size, batch_ctor, extend_fn=None, n_jobs=None, verbose=False, **kwargs): """ - Creates an output from a generator, flushing every batch to disk. + Creates an output from an indexable object, flushing every batch to disk. Can be done in parallel. + indexable[i] (or indexable(i)) will become memmap[i]. :param out_dir: The output directory. - :param indexable: An object that supports __getitem__. indexable[i] will become memmap[i] + :param indexable: An object that supports __getitem__ or a function that takes one integer argument. :param batch_size: The batch size, which controls how often the output should be written to disk. :param batch_ctor: The constructor used to initialize the output. - :param extend_fn: Functon to call when doing .extend. By default, this will call memmap.extend(samples) + :param extend_fn: Function to call when doing .extend. By default, this will call memmap.extend(samples). :param n_jobs: number of jobs to iterate through indexable with. 
Default=None corresponds to no parallelization. :param verbose: whether to print progress meter. :param kwargs: Additional keyword arguments to be passed when initializing the output. From 1d82b57887a0da0712a0b00fd282e2accfd95359 Mon Sep 17 00:00:00 2001 From: avishaihalev Date: Mon, 17 Jun 2024 14:02:04 -0700 Subject: [PATCH 11/12] update verbose kwarg usage --- mmap_ninja/src/mmap_ninja/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mmap_ninja/src/mmap_ninja/base.py b/mmap_ninja/src/mmap_ninja/base.py index c873aa1..ff003f5 100644 --- a/mmap_ninja/src/mmap_ninja/base.py +++ b/mmap_ninja/src/mmap_ninja/base.py @@ -249,8 +249,6 @@ def from_indexable_base(out_dir, indexable, batch_size, batch_ctor, extend_fn=No out_dir = Path(out_dir) out_dir.mkdir(exist_ok=True) memmap = None - if kwargs.pop("verbose", False): - from tqdm.auto import tqdm batch_collector = ParallelBatchCollector(indexable, batch_size, n_jobs, verbose) From 3b05ed168bd45f9bfd4550d9bc219a310bb63b7c Mon Sep 17 00:00:00 2001 From: avishaihalev Date: Mon, 17 Jun 2024 14:06:15 -0700 Subject: [PATCH 12/12] add from_indexable tests --- mmap_ninja/tests/test_ragged.py | 51 +++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/mmap_ninja/tests/test_ragged.py b/mmap_ninja/tests/test_ragged.py index 1e79f37..b4e94ab 100644 --- a/mmap_ninja/tests/test_ragged.py +++ b/mmap_ninja/tests/test_ragged.py @@ -96,6 +96,57 @@ def test_from_generator(tmp_path, n): assert np.allclose(np.ones(12) * i, memmap[i]) +@pytest.fixture +def indexable_obj(request): + length, has_length = request.param + + class _Indexable: + def __init__(self, _length, _has_length): + self.length = _length + self.has_length = _has_length + + def __len__(self): + if not self.has_length: + raise TypeError + return self.length + + def __getitem__(self, item): + if 0 <= item < self.length: + return np.ones(12) * item + raise IndexError(item) + + return _Indexable(length, has_length) + + 
+@pytest.mark.parametrize("n, indexable_obj", [(30, (30, True)), (3, (3, False))], indirect=["indexable_obj"]) +@pytest.mark.parametrize("n_jobs", [1, 2]) +def test_from_indexable_obj(tmp_path, n, indexable_obj, n_jobs): + memmap = RaggedMmap.from_indexable(tmp_path / "strings_memmap", indexable_obj, 4, n_jobs=n_jobs, verbose=True) + for i in range(n): + assert np.allclose(np.ones(12) * i, memmap[i]) + + +@pytest.fixture +def indexable_func(request): + + total = request.param + + def func(item): + if 0 <= item < total: + return np.ones(12) * item + raise IndexError(item) + + return func + + +@pytest.mark.parametrize("n, indexable_func", [(30, 30), (3, 3)], indirect=["indexable_func"]) +@pytest.mark.parametrize("n_jobs", [1, 2]) +def test_from_indexable_func(tmp_path, n, indexable_func, n_jobs): + memmap = RaggedMmap.from_indexable(tmp_path / "strings_memmap", indexable_func, 4, n_jobs=n_jobs, verbose=True) + for i in range(n): + assert np.allclose(np.ones(12) * i, memmap[i]) + + def test_nd_case(tmp_path): simple = [ np.array([[11, 13], [-1, 17]]),