Skip to content

Commit

Permalink
feat: support splitting parquet reading by fragment instead of rows i…
Browse files Browse the repository at this point in the history
…n mp mode
  • Loading branch information
percevalw committed Nov 28, 2024
1 parent 7ac9895 commit 14d1728
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 8 deletions.
4 changes: 4 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

# Unreleased

### Added

- `edsnlp.data.read_parquet` now accept a `work_unit="fragment"` option to split tasks between workers by parquet fragment instead of row. When this is enabled, workers do not read every fragment while skipping 1 in n rows, but read all rows of 1/n fragments, which should be faster.

### Fixed

- Fix `join_thread` missing attribute in `SimpleQueue` when cleaning a multiprocessing executor
Expand Down
48 changes: 43 additions & 5 deletions edsnlp/data/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
shuffle: Literal["dataset", "fragment", False] = False,
seed: Optional[int] = None,
loop: bool = False,
work_unit: Literal["record", "fragment"] = "record",
):
super().__init__()
self.shuffle = shuffle
Expand All @@ -41,6 +42,11 @@ def __init__(
seed = seed if seed is not None else random.getrandbits(32)
self.rng = random.Random(seed)
self.loop = loop
self.work_unit = work_unit
assert not (work_unit == "fragment" and shuffle == "dataset"), (
"Cannot shuffle at the dataset level and dispatch tasks at the "
"fragment level. Set shuffle='fragment' or work_unit='record'."
)
# Either the filesystem has not been passed
# or the path is a URL (e.g. s3://) => we need to infer the filesystem
self.fs, self.path = normalize_fs_path(filesystem, path)
Expand All @@ -53,23 +59,41 @@ def __init__(
)

def read_fragment(self, fragment: ParquetFileFragment) -> Iterable[Dict]:
return dl_to_ld(fragment.to_table().to_pydict())
return (
doc
for batch in fragment.scanner().to_reader()
for doc in dl_to_ld(batch.to_pydict())
)

def extract_task(self, item):
if self.work_unit == "fragment":
records = self.read_fragment(item)
if self.shuffle == "fragment":
records = shuffle(records, self.rng)
yield from records
else:
yield item

def read_records(self) -> Iterable[Any]:
while True:
files = self.fragments
if self.shuffle == "fragment":
for file in shuffle(files, self.rng):
records = shuffle(self.read_fragment(file), self.rng)
yield from records
if self.work_unit == "fragment":
yield file
else:
yield from shuffle(self.read_fragment(file), self.rng)
yield FragmentEndSentinel(file.path)
elif self.shuffle == "dataset":
assert self.work_unit == "record"
records = (line for file in files for line in self.read_fragment(file))
yield from shuffle(records, self.rng)
else:
for file in files:
records = list(self.read_fragment(file))
yield from records
if self.work_unit == "fragment":
yield file
else:
yield from self.read_fragment(file)
yield FragmentEndSentinel(file.path)
yield DatasetEndSentinel()
if not self.loop:
Expand Down Expand Up @@ -158,6 +182,7 @@ def read_parquet(
shuffle: Literal["dataset", "fragment", False] = False,
seed: Optional[int] = None,
loop: bool = False,
work_unit: Literal["record", "fragment"] = "record",
**kwargs,
) -> Stream:
"""
Expand Down Expand Up @@ -211,6 +236,18 @@ def read_parquet(
The seed to use for shuffling.
loop: bool
Whether to loop over the data indefinitely.
work_unit: Literal["record", "fragment"]
Only affects the multiprocessing mode. If "record", every worker will start to
read the same parquet file and yield each every num_workers-th record, starting
at an offset each. For instance, if num_workers=2, the first worker will read
the 1st, 3rd, 5th, ... records, while the second worker will read the 2nd, 4th,
6th, ... records of the first parquet file.
If "fragment", each worker will read a different parquet file. For instance, the
first worker will every record of the 1st parquet file, the second worker will
read every record of the 2nd parquet file, and so on. This way, no record is
"wasted" and every record loaded in memory is yielded.
converter: Optional[AsList[Union[str, Callable]]]
Converters to use to convert the parquet rows of the data source to Doc objects
These are documented on the [Converters](/data/converters) page.
Expand All @@ -237,6 +274,7 @@ def read_parquet(
shuffle=shuffle,
seed=seed,
loop=loop,
work_unit=work_unit,
)
)
if converter:
Expand Down
9 changes: 6 additions & 3 deletions edsnlp/processing/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,13 +509,16 @@ def iter_tasks(self, stage, stop_mode=False):
return
task_idx = 0
for item in iter(self.stream.reader.read_records()):
if self.stop:
if self.stop: # pragma: no cover
raise StopSignal()
if isinstance(item, StreamSentinel):
yield item
continue
if task_idx % pool_size == worker_idx:
yield item
for item in self.stream.reader.extract_task(item):
if self.stop: # pragma: no cover
raise StopSignal()
yield item

task_idx += 1
else:
Expand Down Expand Up @@ -751,7 +754,7 @@ def iter_tasks(self, stage, stop_mode=False):
]
# Get items from the previous stage
while self.num_producers_alive[stage] > 0:
if self.stop and not stop_mode:
if self.stop and not stop_mode: # pragma: no cover
raise StopSignal()

offset = (offset + 1) % len(queues)
Expand Down
9 changes: 9 additions & 0 deletions edsnlp/processing/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ def process():

with bar, stream.eval():
items = reader.read_records()
items = (
task
for item in items
for task in (
(item,)
if isinstance(item, StreamSentinel)
else reader.extract_task(item)
)
)

for stage_idx, stage in enumerate(stages):
for op in stage.cpu_ops:
Expand Down
2 changes: 2 additions & 0 deletions edsnlp/processing/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def process_partition(items): # pragma: no cover
else:
items = (item.asDict(recursive=True) for item in items)

items = (task for item in items for task in stream.reader.extract_task(item))

for stage_idx, stage in enumerate(stages):
for op in stage.cpu_ops:
items = op(items)
Expand Down
27 changes: 27 additions & 0 deletions tests/data/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pyarrow.dataset
import pyarrow.fs
import pytest
from confit.utils.random import set_seed
from typing_extensions import Literal

import edsnlp
Expand Down Expand Up @@ -355,6 +356,32 @@ def test_read_shuffle_loop(
]


@pytest.mark.parametrize("num_cpu_workers", [0, 2])
@pytest.mark.parametrize("work_unit", ["record", "fragment"])
@pytest.mark.parametrize("shuffle", [False, "dataset", "fragment"])
def test_read_work_unit(
num_cpu_workers,
work_unit: Literal["record", "fragment"],
shuffle: Literal[False, "dataset", "fragment"],
):
if shuffle == "dataset" and work_unit == "fragment":
pytest.skip("Dataset-level shuffle is not supported with fragment work unit")
input_dir = Path(__file__).parent.parent.resolve() / "resources" / "docs.parquet"
set_seed(42)
stream = edsnlp.data.read_parquet(
input_dir, work_unit=work_unit, shuffle=shuffle
).set_processing(
num_cpu_workers=num_cpu_workers,
)
stream = stream.map_batches(
lambda b: "|".join(sorted([x["note_id"] for x in b])), batch_size=1000
)
if work_unit == "fragment" and num_cpu_workers == 2 or num_cpu_workers == 0:
assert list(stream) == ["subfolder/doc-1|subfolder/doc-2|subfolder/doc-3"]
else:
assert list(stream) == ["subfolder/doc-1|subfolder/doc-3", "subfolder/doc-2"]


@pytest.mark.parametrize(
"num_cpu_workers,write_in_worker",
[
Expand Down

0 comments on commit 14d1728

Please sign in to comment.