diff --git a/changelog.md b/changelog.md
index 91838e4d3..e46fa5761 100644
--- a/changelog.md
+++ b/changelog.md
@@ -2,6 +2,10 @@
 
 # Unreleased
 
+### Added
+
+- `edsnlp.data.read_parquet` now accepts a `work_unit="fragment"` option to split tasks between workers by parquet fragment instead of by row. When this is enabled, each worker reads every row of 1 in n fragments instead of scanning every fragment and keeping only 1 in n rows, which should be faster.
+
 ### Fixed
 
 - Fix `join_thread` missing attribute in `SimpleQueue` when cleaning a multiprocessing executor
diff --git a/edsnlp/data/parquet.py b/edsnlp/data/parquet.py
index 647fd4817..11771cb64 100644
--- a/edsnlp/data/parquet.py
+++ b/edsnlp/data/parquet.py
@@ -32,6 +32,7 @@ def __init__(
         shuffle: Literal["dataset", "fragment", False] = False,
         seed: Optional[int] = None,
         loop: bool = False,
+        work_unit: Literal["record", "fragment"] = "record",
     ):
         super().__init__()
         self.shuffle = shuffle
@@ -41,6 +42,11 @@
         seed = seed if seed is not None else random.getrandbits(32)
         self.rng = random.Random(seed)
         self.loop = loop
+        self.work_unit = work_unit
+        assert not (work_unit == "fragment" and shuffle == "dataset"), (
+            "Cannot shuffle at the dataset level and dispatch tasks at the "
+            "fragment level. Set shuffle='fragment' or work_unit='record'."
+        )
         # Either the filesystem has not been passed
         # or the path is a URL (e.g. s3://) => we need to infer the filesystem
         self.fs, self.path = normalize_fs_path(filesystem, path)
@@ -53,23 +59,41 @@
         )
 
     def read_fragment(self, fragment: ParquetFileFragment) -> Iterable[Dict]:
-        return dl_to_ld(fragment.to_table().to_pydict())
+        return (
+            doc
+            for batch in fragment.scanner().to_reader()
+            for doc in dl_to_ld(batch.to_pydict())
+        )
+
+    def extract_task(self, item):
+        if self.work_unit == "fragment":
+            records = self.read_fragment(item)
+            if self.shuffle == "fragment":
+                records = shuffle(records, self.rng)
+            yield from records
+        else:
+            yield item
 
     def read_records(self) -> Iterable[Any]:
         while True:
             files = self.fragments
             if self.shuffle == "fragment":
                 for file in shuffle(files, self.rng):
-                    records = shuffle(self.read_fragment(file), self.rng)
-                    yield from records
+                    if self.work_unit == "fragment":
+                        yield file
+                    else:
+                        yield from shuffle(self.read_fragment(file), self.rng)
                     yield FragmentEndSentinel(file.path)
             elif self.shuffle == "dataset":
+                assert self.work_unit == "record"
                 records = (line for file in files for line in self.read_fragment(file))
                 yield from shuffle(records, self.rng)
             else:
                 for file in files:
-                    records = list(self.read_fragment(file))
-                    yield from records
+                    if self.work_unit == "fragment":
+                        yield file
+                    else:
+                        yield from self.read_fragment(file)
                     yield FragmentEndSentinel(file.path)
             yield DatasetEndSentinel()
             if not self.loop:
@@ -158,6 +182,7 @@ def read_parquet(
     shuffle: Literal["dataset", "fragment", False] = False,
     seed: Optional[int] = None,
     loop: bool = False,
+    work_unit: Literal["record", "fragment"] = "record",
     **kwargs,
 ) -> Stream:
     """
@@ -211,6 +236,18 @@
         The seed to use for shuffling.
     loop: bool
         Whether to loop over the data indefinitely.
+    work_unit: Literal["record", "fragment"]
+        Only affects the multiprocessing mode. If "record", every worker will read the
+        same parquet files but only yield every num_workers-th record, each worker
+        starting at a different offset. For instance, if num_workers=2, the first
+        worker will read the 1st, 3rd, 5th, ... records of the first parquet file,
+        while the second worker will read the 2nd, 4th, 6th, ... records.
+ + If "fragment", each worker will read a different parquet file. For instance, the + first worker will every record of the 1st parquet file, the second worker will + read every record of the 2nd parquet file, and so on. This way, no record is + "wasted" and every record loaded in memory is yielded. + converter: Optional[AsList[Union[str, Callable]]] Converters to use to convert the parquet rows of the data source to Doc objects These are documented on the [Converters](/data/converters) page. @@ -237,6 +274,7 @@ def read_parquet( shuffle=shuffle, seed=seed, loop=loop, + work_unit=work_unit, ) ) if converter: diff --git a/edsnlp/processing/multiprocessing.py b/edsnlp/processing/multiprocessing.py index 673940d3d..50b50c92d 100644 --- a/edsnlp/processing/multiprocessing.py +++ b/edsnlp/processing/multiprocessing.py @@ -509,13 +509,16 @@ def iter_tasks(self, stage, stop_mode=False): return task_idx = 0 for item in iter(self.stream.reader.read_records()): - if self.stop: + if self.stop: # pragma: no cover raise StopSignal() if isinstance(item, StreamSentinel): yield item continue if task_idx % pool_size == worker_idx: - yield item + for item in self.stream.reader.extract_task(item): + if self.stop: # pragma: no cover + raise StopSignal() + yield item task_idx += 1 else: @@ -751,7 +754,7 @@ def iter_tasks(self, stage, stop_mode=False): ] # Get items from the previous stage while self.num_producers_alive[stage] > 0: - if self.stop and not stop_mode: + if self.stop and not stop_mode: # pragma: no cover raise StopSignal() offset = (offset + 1) % len(queues) diff --git a/edsnlp/processing/simple.py b/edsnlp/processing/simple.py index b3258b78b..5825c36ee 100644 --- a/edsnlp/processing/simple.py +++ b/edsnlp/processing/simple.py @@ -66,6 +66,15 @@ def process(): with bar, stream.eval(): items = reader.read_records() + items = ( + task + for item in items + for task in ( + (item,) + if isinstance(item, StreamSentinel) + else reader.extract_task(item) + ) + ) for stage_idx, stage in enumerate(stages): for op in stage.cpu_ops: diff --git a/edsnlp/processing/spark.py b/edsnlp/processing/spark.py index a0433c9d1..71494ad35 100644 --- a/edsnlp/processing/spark.py +++ b/edsnlp/processing/spark.py @@ -120,6 +120,8 @@ def process_partition(items): # pragma: no cover else: items = (item.asDict(recursive=True) for item in items) + items = (task for item in items for task in stream.reader.extract_task(item)) + for stage_idx, stage in enumerate(stages): for op in stage.cpu_ops: items = op(items) diff --git a/tests/data/test_parquet.py b/tests/data/test_parquet.py index 634e846c1..6caeb1648 100644 --- a/tests/data/test_parquet.py +++ b/tests/data/test_parquet.py @@ -4,6 +4,7 @@ import pyarrow.dataset import pyarrow.fs import pytest +from confit.utils.random import set_seed from typing_extensions import Literal import edsnlp @@ -355,6 +356,32 @@ def test_read_shuffle_loop( ] +@pytest.mark.parametrize("num_cpu_workers", [0, 2]) +@pytest.mark.parametrize("work_unit", ["record", "fragment"]) +@pytest.mark.parametrize("shuffle", [False, "dataset", "fragment"]) +def test_read_work_unit( + num_cpu_workers, + work_unit: Literal["record", "fragment"], + shuffle: Literal[False, "dataset", "fragment"], +): + if shuffle == "dataset" and work_unit == "fragment": + pytest.skip("Dataset-level shuffle is not supported with fragment work unit") + input_dir = Path(__file__).parent.parent.resolve() / "resources" / "docs.parquet" + set_seed(42) + stream = edsnlp.data.read_parquet( + input_dir, work_unit=work_unit, shuffle=shuffle + 
+    ).set_processing(
+        num_cpu_workers=num_cpu_workers,
+    )
+    stream = stream.map_batches(
+        lambda b: "|".join(sorted([x["note_id"] for x in b])), batch_size=1000
+    )
+    if (work_unit == "fragment" and num_cpu_workers == 2) or num_cpu_workers == 0:
+        assert list(stream) == ["subfolder/doc-1|subfolder/doc-2|subfolder/doc-3"]
+    else:
+        assert list(stream) == ["subfolder/doc-1|subfolder/doc-3", "subfolder/doc-2"]
+
+
 @pytest.mark.parametrize(
     "num_cpu_workers,write_in_worker",
     [
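
Usage sketch (not part of the diff above): a minimal example of the new option under assumed inputs. The parquet path and the worker count are placeholders; only `read_parquet(..., shuffle=..., work_unit=...)` and `set_processing(num_cpu_workers=...)` come from the changes in this PR.

```python
import edsnlp

stream = edsnlp.data.read_parquet(
    "path/to/docs.parquet",  # placeholder: any parquet dataset with one or more fragments
    shuffle="fragment",      # shuffle="dataset" is rejected when work_unit="fragment"
    work_unit="fragment",    # dispatch whole fragments to workers instead of every n-th record
).set_processing(num_cpu_workers=2)

# Without a converter, each parquet row is yielded as a plain dict.
for record in stream:
    ...
```

With `work_unit="fragment"`, a worker only has work to do if the dataset contains at least as many fragments as there are workers; in the single-fragment test dataset above, one worker ends up yielding every record.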