
Commit

Merge pull request #17 from ahalev/ahalev_parallel
Define a memmap from an indexable object + allow it to be done in parallel.
hristo-vrigazov authored Jun 19, 2024
2 parents f2b40c1 + 3b05ed1 commit 84e13a7
Showing 5 changed files with 258 additions and 0 deletions.
1 change: 1 addition & 0 deletions mmap_ninja/pyproject.toml
@@ -32,3 +32,4 @@ Homepage = "https://github.com/hristo-vrigazov/mmap.ninja"

[project.optional-dependencies]
dev = ["pytest", "tqdm", "pytest-dotenv"]
parallel = ["joblib"]
35 changes: 35 additions & 0 deletions mmap_ninja/src/mmap_ninja/base.py
@@ -4,6 +4,8 @@
from pathlib import Path
from typing import Union, Sequence, List

from .parallel import ParallelBatchCollector


def _bytes_to_int(inp: bytes, fmt: str = "<i") -> int:
"""
@@ -229,6 +231,39 @@ def from_generator_base(out_dir, sample_generator, batch_size, batch_ctor, exten
return memmap


def from_indexable_base(out_dir, indexable, batch_size, batch_ctor, extend_fn=None, n_jobs=None, verbose=False, **kwargs):
"""
Creates an output from an indexable object, flushing every batch to disk. Can be done in parallel.
indexable[i] (or indexable(i)) will become memmap[i].
:param out_dir: The output directory.
:param indexable: An object that supports __getitem__ or a function that takes one integer argument.
:param batch_size: The batch size, which controls how often the output should be written to disk.
:param batch_ctor: The constructor used to initialize the output.
:param extend_fn: Function to call when doing .extend. By default, this will call memmap.extend(samples).
:param n_jobs: Number of jobs used to iterate through the indexable. The default of None corresponds to no parallelization.
:param verbose: Whether to display a progress bar.
:param kwargs: Additional keyword arguments to be passed when initializing the output.
:return: The newly created memmap.
"""
out_dir = Path(out_dir)
out_dir.mkdir(exist_ok=True)
memmap = None

batch_collector = ParallelBatchCollector(indexable, batch_size, n_jobs, verbose)

for samples in batch_collector.batches():
if memmap is None:
memmap = batch_ctor(out_dir, samples, **kwargs)
else:
if extend_fn is not None:
extend_fn(memmap, samples)
else:
memmap.extend(samples)

return memmap


class Wrapped:
def __init__(self, data, wrapper_fn, copy_before_wrapper_fn=True):
self.data = data
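
For illustration only (not part of this commit), a minimal sketch of driving the base helper directly, with RaggedMmap.from_lists as the batch_ctor; the read_row function and output directory are invented for the example, and RaggedMmap.from_indexable below is the intended entry point:

import numpy as np

from mmap_ninja import base
from mmap_ninja.ragged import RaggedMmap


def read_row(i):
    # Exhaustion is signalled by raising IndexError, as in the tests below.
    if i >= 10:
        raise IndexError(i)
    return np.full(3, float(i))


memmap = base.from_indexable_base(
    "direct_memmap",
    read_row,
    batch_size=4,
    batch_ctor=RaggedMmap.from_lists,
)
print(memmap[3])  # [3. 3. 3.]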
159 changes: 159 additions & 0 deletions mmap_ninja/src/mmap_ninja/parallel.py
@@ -0,0 +1,159 @@
from enum import Enum
from functools import partial
from tqdm.auto import tqdm

try:
from joblib import Parallel, delayed
except ImportError:
Parallel, delayed = None, None
HAS_JOBLIB = False
else:
HAS_JOBLIB = True


class _Exhausted(Enum):
exhausted = 'EXHAUSTED'


EXHAUSTED = _Exhausted.exhausted


class ParallelBatchCollector:
_parallel: Parallel = None

def __init__(self, indexable, batch_size, n_jobs=None, verbose=False, **kwargs):
self.indexable, self._obj_length, self._num_batches = self.verify(indexable, batch_size)
self.batch_size = batch_size

self._pbar = self._init_pbar(verbose)
self._parallel = self.begin(n_jobs, **kwargs)
self._batch_num = 0
self._exhausted = False

@staticmethod
def verify(indexable, batch_size):
try:
_ = indexable.__getitem__
except AttributeError:
if callable(indexable):
indexable = _IndexableWrap(indexable)
else:
msg = 'indexable must implement __getitem__ or be callable and take one integer argument.'
raise TypeError(msg)

try:
length = len(indexable)
except TypeError:
length = None
num_batches = None
else:
num_batches = length // batch_size + (length % batch_size != 0)

return indexable, length, num_batches

@staticmethod
def begin(n_jobs: int, **kwargs):
if n_jobs in (None, 1):
return
elif not HAS_JOBLIB:
msg = 'joblib is not installed. Install joblib or run with n_jobs=None to disable parallelization.'
raise ImportError(msg)

_parallel = Parallel(n_jobs=n_jobs, **kwargs)
_parallel.__enter__()
return _parallel

def batches(self):
while not self.exhausted():
yield self.collect_batch()

def collect_batch(self):
if self._parallel is None:
batch = self._collect_no_parallel_batch()
else:
batch = self._collect_parallel_batch()

self._update_pbar(batch)
return batch

def _init_pbar(self, verbose):
if not verbose:
return None
return tqdm(total=self._obj_length)

def _update_pbar(self, batch):
if self._pbar is not None:
self._pbar.update(len(batch))

def _collect_no_parallel_batch(self):
results = [_get_from_indexable(self.indexable, j) for j in self._rng()]

if self.exhausted(results):
results = [r for r in results if r is not EXHAUSTED]

return results

def _collect_parallel_batch(self):
func = delayed(partial(_get_from_indexable, self.indexable))

results = self._parallel(func(j) for j in self._rng())

if self.exhausted(results):
results = [r for r in results if r is not EXHAUSTED]
self._parallel.__exit__(None, None, None)

return results

def exhausted(self, results=()):
self._exhausted = (
self._exhausted or
any(r is EXHAUSTED for r in results) or
self.completed_batches()
)

return self._exhausted

def completed_batches(self):
return self._num_batches is not None and self._batch_num == self._num_batches

def _rng(self):
start = self.batch_size * self._batch_num
stop = self.batch_size * (1 + self._batch_num)

self._batch_num += 1

return range(start, stop)


class _IndexableWrap:
def __init__(self, func):
self._func = func

def __getitem__(self, item):
return self._func(item)

@property
def wrapped(self):
return self._func


class _IndexableLengthWrap(_IndexableWrap):
def __init__(self, func, length):
super().__init__(func)
self.length = length

def __len__(self):
return self.length


def make_indexable(func, length=None):
if length is not None:
return _IndexableLengthWrap(func, length)
return _IndexableWrap(func)


def _get_from_indexable(indexable, item):
try:
return indexable[item]
except (IndexError, KeyError):
return EXHAUSTED
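
A standalone usage sketch (not part of this commit) of the collector and wrapper defined above; make_row and its sizes are invented for the example. Note that parallel.py imports tqdm unconditionally, so tqdm must be installed even for the sequential path:

import numpy as np

from mmap_ninja.parallel import ParallelBatchCollector, make_indexable


def make_row(i):
    if i >= 10:
        raise IndexError(i)
    return np.ones(4) * i


# Wrapping with an explicit length gives the collector a __len__, so it can
# size its progress bar and stop via completed_batches rather than relying
# on IndexError alone.
indexable = make_indexable(make_row, length=10)

collector = ParallelBatchCollector(indexable, batch_size=4, n_jobs=None)
for batch in collector.batches():
    print(len(batch))  # 4, then 4, then 2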
12 changes: 12 additions & 0 deletions mmap_ninja/src/mmap_ninja/ragged.py
@@ -190,3 +190,15 @@ def from_generator(cls, out_dir: Union[str, Path], sample_generator, batch_size:
batch_ctor=cls.from_lists,
**kwargs,
)

@classmethod
def from_indexable(cls, out_dir: Union[str, Path], indexable, batch_size: int, n_jobs=None, verbose=False, **kwargs):
return base.from_indexable_base(
out_dir=out_dir,
indexable=indexable,
batch_size=batch_size,
n_jobs=n_jobs,
verbose=verbose,
batch_ctor=cls.from_lists,
**kwargs,
)
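
A minimal usage sketch (not part of this commit), mirroring the tests below; the directory name and data are illustrative, and n_jobs greater than 1 additionally requires the joblib extra:

import numpy as np

from mmap_ninja.ragged import RaggedMmap

# Any object supporting __getitem__ works as the source; a plain list also
# provides __len__, so the number of batches is known up front.
rows = [np.arange(n) for n in range(1, 100)]
memmap = RaggedMmap.from_indexable("rows_memmap", rows, batch_size=32)
print(memmap[5])  # [0 1 2 3 4 5]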
51 changes: 51 additions & 0 deletions mmap_ninja/tests/test_ragged.py
@@ -96,6 +96,57 @@ def test_from_generator(tmp_path, n):
assert np.allclose(np.ones(12) * i, memmap[i])


@pytest.fixture
def indexable_obj(request):
length, has_length = request.param

class _Indexable:
def __init__(self, _length, _has_length):
self.length = _length
self.has_length = _has_length

def __len__(self):
if not self.has_length:
raise TypeError
return self.length

def __getitem__(self, item):
if 0 <= item < self.length:
return np.ones(12) * item
raise IndexError(item)

return _Indexable(length, has_length)


@pytest.mark.parametrize("n, indexable_obj", [(30, (30, True)), (3, (3, False))], indirect=["indexable_obj"])
@pytest.mark.parametrize("n_jobs", [1, 2])
def test_from_indexable_obj(tmp_path, n, indexable_obj, n_jobs):
memmap = RaggedMmap.from_indexable(tmp_path / "strings_memmap", indexable_obj, 4, n_jobs=n_jobs, verbose=True)
for i in range(n):
assert np.allclose(np.ones(12) * i, memmap[i])


@pytest.fixture
def indexable_func(request):

total = request.param

def func(item):
if 0 <= item < total:
return np.ones(12) * item
raise IndexError(item)

return func


@pytest.mark.parametrize("n, indexable_func", [(30, 30), (3, 3)], indirect=["indexable_func"])
@pytest.mark.parametrize("n_jobs", [1, 2])
def test_from_indexable_func(tmp_path, n, indexable_func, n_jobs):
memmap = RaggedMmap.from_indexable(tmp_path / "strings_memmap", indexable_func, 4, n_jobs=n_jobs, verbose=True)
for i in range(n):
assert np.allclose(np.ones(12) * i, memmap[i])


def test_nd_case(tmp_path):
simple = [
np.array([[11, 13], [-1, 17]]),
