From 0299401c015dbe82179c8feedff122d52fa09589 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Thu, 28 Nov 2024 15:37:42 +0100
Subject: [PATCH 1/3] fix: seed random states when shuffling in data readers

---
 changelog.md               |  3 +++
 edsnlp/core/stream.py      |  6 ++++++
 edsnlp/data/base.py        | 21 ++++-----------------
 edsnlp/data/json.py        |  2 +-
 edsnlp/data/pandas.py      |  1 +
 edsnlp/data/parquet.py     |  4 ++--
 edsnlp/data/polars.py      |  1 +
 edsnlp/data/spark.py       |  1 +
 edsnlp/data/standoff.py    |  1 +
 edsnlp/processing/spark.py |  2 +-
 10 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/changelog.md b/changelog.md
index 5390ea017..6eb6fa194 100644
--- a/changelog.md
+++ b/changelog.md
@@ -7,6 +7,9 @@
 - Fix `join_thread` missing attribute in `SimpleQueue` when cleaning a multiprocessing executor
 - Support huggingface transformers that do not set `cls_token_id` and `sep_token_id` (we now also look for these tokens in the `special_tokens_map` and `vocab` mappings)
 - Fix changing scorers dict size issue when evaluating during training
+- Seed random states (instead of using an unseeded `random.Random()`) when shuffling in data readers: this is important for
+  1. reproducibility
+  2. ensuring that, in multiprocessing mode, the same data is shuffled in the same way in all workers
 
 ## v0.14.0 (2024-11-14)
 
diff --git a/edsnlp/core/stream.py b/edsnlp/core/stream.py
index 53cbdaf73..1c8cd3069 100644
--- a/edsnlp/core/stream.py
+++ b/edsnlp/core/stream.py
@@ -791,6 +791,12 @@ def shuffle(
             else False
         )
         stream = self
+        # Ensure that we have a "deterministic" random seed, meaning that
+        # if the user sets a global seed earlier in the program and executes the
+        # same program twice, the shuffling should be the same in both cases.
+        # This is not guaranteed by just creating random.Random(), which does
+        # not account for the global seed.
+        seed = seed if seed is not None else random.getrandbits(32)
         if shuffle_reader:
             if shuffle_reader not in self.reader.emitted_sentinels:
                 raise ValueError(f"Cannot shuffle by {shuffle_reader}")
diff --git a/edsnlp/data/base.py b/edsnlp/data/base.py
index a7d7a876b..08ba5cf63 100644
--- a/edsnlp/data/base.py
+++ b/edsnlp/data/base.py
@@ -21,23 +21,6 @@ class BaseReader:
-    """
-    The BaseReader servers as a base class for all readers. It expects two methods:
-
-    - `read_records` method which is called in the main process and should return a
-      generator of fragments (like filenames) with their estimated size (number of
-      documents)
-    - `unpack_tasks` method which is called in the worker processes and receives
-      batches of fragments and should return a list of dictionaries (one per
-      document), ready to be converted to a Doc object by the converter.
-
-    Additionally, the subclass should define a `DATA_FIELDS` class attribute which
-    contains the names of all attributes that should not be copied when the reader is
-    copied to the worker processes.
- """ - DATA_FIELDS: Tuple[str] = () read_in_worker: bool emitted_sentinels: set @@ -47,6 +30,9 @@ class BaseReader: def read_records(self) -> Iterable[Any]: raise NotImplementedError() + def extract_task(self, item): + return [item] + def worker_copy(self): if self.read_in_worker: return self @@ -110,6 +96,7 @@ def __init__( ): super().__init__() self.shuffle = shuffle + seed = seed if seed is not None else random.getrandbits(32) self.rng = random.Random(seed) self.emitted_sentinels = {"dataset"} self.loop = loop diff --git a/edsnlp/data/json.py b/edsnlp/data/json.py index 9a254a5f8..8560b9bf0 100644 --- a/edsnlp/data/json.py +++ b/edsnlp/data/json.py @@ -40,7 +40,6 @@ def __init__( ): super().__init__() self.shuffle = shuffle - self.rng = random.Random(seed) self.write_in_worker = write_in_worker self.emitted_sentinels = {"dataset"} self.loop = loop @@ -57,6 +56,7 @@ def __init__( self.keep_ipynb_checkpoints = keep_ipynb_checkpoints self.shuffle = shuffle self.loop = loop + seed = seed if seed is not None else random.getrandbits(32) self.rng = random.Random(seed) for file in self.files: if not self.fs.exists(file): diff --git a/edsnlp/data/pandas.py b/edsnlp/data/pandas.py index 087a24d97..593669a59 100644 --- a/edsnlp/data/pandas.py +++ b/edsnlp/data/pandas.py @@ -27,6 +27,7 @@ def __init__( ): super().__init__() self.shuffle = shuffle + seed = seed if seed is not None else random.getrandbits(32) self.rng = random.Random(seed) self.emitted_sentinels = {"dataset"} self.loop = loop diff --git a/edsnlp/data/parquet.py b/edsnlp/data/parquet.py index 8492d7712..647fd4817 100644 --- a/edsnlp/data/parquet.py +++ b/edsnlp/data/parquet.py @@ -38,6 +38,7 @@ def __init__( self.emitted_sentinels = {"dataset"} | ( set() if shuffle == "dataset" else {"fragment"} ) + seed = seed if seed is not None else random.getrandbits(32) self.rng = random.Random(seed) self.loop = loop # Either the filesystem has not been passed @@ -64,8 +65,7 @@ def read_records(self) -> Iterable[Any]: yield FragmentEndSentinel(file.path) elif self.shuffle == "dataset": records = (line for file in files for line in self.read_fragment(file)) - records = shuffle(records, self.rng) - yield from records + yield from shuffle(records, self.rng) else: for file in files: records = list(self.read_fragment(file)) diff --git a/edsnlp/data/polars.py b/edsnlp/data/polars.py index acbe5a674..8669b4450 100644 --- a/edsnlp/data/polars.py +++ b/edsnlp/data/polars.py @@ -29,6 +29,7 @@ def __init__( super().__init__() self.shuffle = shuffle self.emitted_sentinels = {"dataset"} + seed = seed if seed is not None else random.getrandbits(32) self.rng = random.Random(seed) self.loop = loop diff --git a/edsnlp/data/spark.py b/edsnlp/data/spark.py index bd7c02fb8..c0b81ca81 100644 --- a/edsnlp/data/spark.py +++ b/edsnlp/data/spark.py @@ -39,6 +39,7 @@ def __init__( self.data = data self.shuffle = shuffle self.emitted_sentinels = {"dataset"} + seed = seed if seed is not None else random.getrandbits(32) self.rng = random.Random(seed) self.loop = loop assert isinstance( diff --git a/edsnlp/data/standoff.py b/edsnlp/data/standoff.py index b2ec6bca0..bcecbf4bf 100644 --- a/edsnlp/data/standoff.py +++ b/edsnlp/data/standoff.py @@ -292,6 +292,7 @@ def __init__( super().__init__() self.shuffle = shuffle self.emitted_sentinels = {"dataset"} + seed = seed if seed is not None else random.getrandbits(32) self.rng = random.Random(seed) self.loop = loop self.fs, self.path = normalize_fs_path(filesystem, path) diff --git a/edsnlp/processing/spark.py 
b/edsnlp/processing/spark.py
index 0ba03b92e..a0433c9d1 100644
--- a/edsnlp/processing/spark.py
+++ b/edsnlp/processing/spark.py
@@ -80,7 +80,7 @@ def getActiveSession() -> Optional["SparkSession"]:  # pragma: no cover
     df = reader.data
     assert not reader.loop, "Looping is not supported with Spark backend."
     df = (
-        df.sample(fraction=1.0, seed=reader.rng.randbytes(32))
+        df.sample(fraction=1.0, seed=reader.rng.getrandbits(32))
         if reader.shuffle
         else df
     )

From 45e0234a69ad27c0127f7f8529d87bebc4664ed1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Thu, 28 Nov 2024 15:42:23 +0100
Subject: [PATCH 2/3] fix: bubble BaseComponent instantiation errors correctly

---
 changelog.md         |  1 +
 edsnlp/pipes/base.py | 26 ++++++++++++++++----------
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/changelog.md b/changelog.md
index 6eb6fa194..91838e4d3 100644
--- a/changelog.md
+++ b/changelog.md
@@ -10,6 +10,7 @@
 - Seed random states (instead of using an unseeded `random.Random()`) when shuffling in data readers: this is important for
   1. reproducibility
   2. ensuring that, in multiprocessing mode, the same data is shuffled in the same way in all workers
+- Bubble BaseComponent instantiation errors correctly
 
 ## v0.14.0 (2024-11-14)
 
diff --git a/edsnlp/pipes/base.py b/edsnlp/pipes/base.py
index db1acf682..e92fc50b8 100644
--- a/edsnlp/pipes/base.py
+++ b/edsnlp/pipes/base.py
@@ -44,16 +44,22 @@ def __call__(cls, nlp=inspect.Signature.empty, *args, **kwargs):
         # If this component is missing the nlp argument, we curry it with the
         # provided arguments and return a CurriedFactory object.
         sig = inspect.signature(cls.__init__)
-        bound = sig.bind_partial(None, nlp, *args, **kwargs)
-        bound.arguments.pop("self", None)
-        if (
-            "nlp" in sig.parameters
-            and sig.parameters["nlp"].default is sig.empty
-            and bound.arguments.get("nlp", sig.empty) is sig.empty
-        ):
-            return CurriedFactory(cls, bound.arguments)
-        if nlp is inspect.Signature.empty:
-            bound.arguments.pop("nlp", None)
+        try:
+            bound = sig.bind_partial(None, nlp, *args, **kwargs)
+            bound.arguments.pop("self", None)
+            if (
+                "nlp" in sig.parameters
+                and sig.parameters["nlp"].default is sig.empty
+                and bound.arguments.get("nlp", sig.empty) is sig.empty
+            ):
+                return CurriedFactory(cls, bound.arguments)
+            if nlp is inspect.Signature.empty:
+                bound.arguments.pop("nlp", None)
+        except TypeError:  # pragma: no cover
+            if nlp is inspect.Signature.empty:
+                super().__call__(*args, **kwargs)
+            else:
+                super().__call__(nlp, *args, **kwargs)
         return super().__call__(**bound.arguments)

From 8e4b91d39518d477aca6359ffb21d30197fc61b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Thu, 28 Nov 2024 15:54:23 +0100
Subject: [PATCH 3/3] feat: support splitting parquet reading by fragment instead of rows in mp mode

---
 changelog.md                         |  4 +++
 edsnlp/data/parquet.py               | 48 +++++++++++++++++++++++++---
 edsnlp/processing/multiprocessing.py |  9 ++++--
 edsnlp/processing/simple.py          |  9 ++++++
 edsnlp/processing/spark.py           |  2 ++
 tests/data/test_parquet.py           | 27 ++++++++++++++++
 6 files changed, 91 insertions(+), 8 deletions(-)

diff --git a/changelog.md b/changelog.md
index 91838e4d3..e46fa5761 100644
--- a/changelog.md
+++ b/changelog.md
@@ -2,6 +2,10 @@
 
 # Unreleased
 
+### Added
+
+- `edsnlp.data.read_parquet` now accepts a `work_unit="fragment"` option to split tasks between workers by parquet fragment instead of row. When this is enabled, workers no longer read every fragment while skipping 1 in n rows, but instead read all rows of 1/n of the fragments, which should be faster.
+
 ### Fixed
 
 - Fix `join_thread` missing attribute in `SimpleQueue` when cleaning a multiprocessing executor
diff --git a/edsnlp/data/parquet.py b/edsnlp/data/parquet.py
index 647fd4817..11771cb64 100644
--- a/edsnlp/data/parquet.py
+++ b/edsnlp/data/parquet.py
@@ -32,6 +32,7 @@ def __init__(
         shuffle: Literal["dataset", "fragment", False] = False,
         seed: Optional[int] = None,
         loop: bool = False,
+        work_unit: Literal["record", "fragment"] = "record",
     ):
         super().__init__()
         self.shuffle = shuffle
@@ -41,6 +42,11 @@ def __init__(
         seed = seed if seed is not None else random.getrandbits(32)
         self.rng = random.Random(seed)
         self.loop = loop
+        self.work_unit = work_unit
+        assert not (work_unit == "fragment" and shuffle == "dataset"), (
+            "Cannot shuffle at the dataset level and dispatch tasks at the "
+            "fragment level. Set shuffle='fragment' or work_unit='record'."
+        )
         # Either the filesystem has not been passed
         # or the path is a URL (e.g. s3://) => we need to infer the filesystem
         self.fs, self.path = normalize_fs_path(filesystem, path)
@@ -53,23 +59,41 @@ def __init__(
         )
 
     def read_fragment(self, fragment: ParquetFileFragment) -> Iterable[Dict]:
-        return dl_to_ld(fragment.to_table().to_pydict())
+        return (
+            doc
+            for batch in fragment.scanner().to_reader()
+            for doc in dl_to_ld(batch.to_pydict())
+        )
+
+    def extract_task(self, item):
+        if self.work_unit == "fragment":
+            records = self.read_fragment(item)
+            if self.shuffle == "fragment":
+                records = shuffle(records, self.rng)
+            yield from records
+        else:
+            yield item
 
     def read_records(self) -> Iterable[Any]:
         while True:
             files = self.fragments
             if self.shuffle == "fragment":
                 for file in shuffle(files, self.rng):
-                    records = shuffle(self.read_fragment(file), self.rng)
-                    yield from records
+                    if self.work_unit == "fragment":
+                        yield file
+                    else:
+                        yield from shuffle(self.read_fragment(file), self.rng)
                     yield FragmentEndSentinel(file.path)
             elif self.shuffle == "dataset":
+                assert self.work_unit == "record"
                 records = (line for file in files for line in self.read_fragment(file))
                 yield from shuffle(records, self.rng)
             else:
                 for file in files:
-                    records = list(self.read_fragment(file))
-                    yield from records
+                    if self.work_unit == "fragment":
+                        yield file
+                    else:
+                        yield from self.read_fragment(file)
                     yield FragmentEndSentinel(file.path)
             yield DatasetEndSentinel()
             if not self.loop:
@@ -158,6 +182,7 @@ def read_parquet(
     shuffle: Literal["dataset", "fragment", False] = False,
     seed: Optional[int] = None,
     loop: bool = False,
+    work_unit: Literal["record", "fragment"] = "record",
     **kwargs,
 ) -> Stream:
     """
@@ -211,6 +236,18 @@ def read_parquet(
         The seed to use for shuffling.
     loop: bool
         Whether to loop over the data indefinitely.
+    work_unit: Literal["record", "fragment"]
+        Only affects the multiprocessing mode. If "record", every worker will start to
+        read the same parquet file and yield only every num_workers-th record, each
+        worker starting at a different offset. For instance, if num_workers=2, the
+        first worker will read the 1st, 3rd, 5th, ... records, while the second worker
+        will read the 2nd, 4th, 6th, ... records of the first parquet file.
+
+        If "fragment", each worker will read a different parquet file. For instance, the
+        first worker will read every record of the 1st parquet file, the second worker
+        will read every record of the 2nd parquet file, and so on.
This way, no record is + "wasted" and every record loaded in memory is yielded. + converter: Optional[AsList[Union[str, Callable]]] Converters to use to convert the parquet rows of the data source to Doc objects These are documented on the [Converters](/data/converters) page. @@ -237,6 +274,7 @@ def read_parquet( shuffle=shuffle, seed=seed, loop=loop, + work_unit=work_unit, ) ) if converter: diff --git a/edsnlp/processing/multiprocessing.py b/edsnlp/processing/multiprocessing.py index 673940d3d..50b50c92d 100644 --- a/edsnlp/processing/multiprocessing.py +++ b/edsnlp/processing/multiprocessing.py @@ -509,13 +509,16 @@ def iter_tasks(self, stage, stop_mode=False): return task_idx = 0 for item in iter(self.stream.reader.read_records()): - if self.stop: + if self.stop: # pragma: no cover raise StopSignal() if isinstance(item, StreamSentinel): yield item continue if task_idx % pool_size == worker_idx: - yield item + for item in self.stream.reader.extract_task(item): + if self.stop: # pragma: no cover + raise StopSignal() + yield item task_idx += 1 else: @@ -751,7 +754,7 @@ def iter_tasks(self, stage, stop_mode=False): ] # Get items from the previous stage while self.num_producers_alive[stage] > 0: - if self.stop and not stop_mode: + if self.stop and not stop_mode: # pragma: no cover raise StopSignal() offset = (offset + 1) % len(queues) diff --git a/edsnlp/processing/simple.py b/edsnlp/processing/simple.py index b3258b78b..5825c36ee 100644 --- a/edsnlp/processing/simple.py +++ b/edsnlp/processing/simple.py @@ -66,6 +66,15 @@ def process(): with bar, stream.eval(): items = reader.read_records() + items = ( + task + for item in items + for task in ( + (item,) + if isinstance(item, StreamSentinel) + else reader.extract_task(item) + ) + ) for stage_idx, stage in enumerate(stages): for op in stage.cpu_ops: diff --git a/edsnlp/processing/spark.py b/edsnlp/processing/spark.py index a0433c9d1..71494ad35 100644 --- a/edsnlp/processing/spark.py +++ b/edsnlp/processing/spark.py @@ -120,6 +120,8 @@ def process_partition(items): # pragma: no cover else: items = (item.asDict(recursive=True) for item in items) + items = (task for item in items for task in stream.reader.extract_task(item)) + for stage_idx, stage in enumerate(stages): for op in stage.cpu_ops: items = op(items) diff --git a/tests/data/test_parquet.py b/tests/data/test_parquet.py index 634e846c1..6caeb1648 100644 --- a/tests/data/test_parquet.py +++ b/tests/data/test_parquet.py @@ -4,6 +4,7 @@ import pyarrow.dataset import pyarrow.fs import pytest +from confit.utils.random import set_seed from typing_extensions import Literal import edsnlp @@ -355,6 +356,32 @@ def test_read_shuffle_loop( ] +@pytest.mark.parametrize("num_cpu_workers", [0, 2]) +@pytest.mark.parametrize("work_unit", ["record", "fragment"]) +@pytest.mark.parametrize("shuffle", [False, "dataset", "fragment"]) +def test_read_work_unit( + num_cpu_workers, + work_unit: Literal["record", "fragment"], + shuffle: Literal[False, "dataset", "fragment"], +): + if shuffle == "dataset" and work_unit == "fragment": + pytest.skip("Dataset-level shuffle is not supported with fragment work unit") + input_dir = Path(__file__).parent.parent.resolve() / "resources" / "docs.parquet" + set_seed(42) + stream = edsnlp.data.read_parquet( + input_dir, work_unit=work_unit, shuffle=shuffle + ).set_processing( + num_cpu_workers=num_cpu_workers, + ) + stream = stream.map_batches( + lambda b: "|".join(sorted([x["note_id"] for x in b])), batch_size=1000 + ) + if work_unit == "fragment" and num_cpu_workers == 2 or 
num_cpu_workers == 0: + assert list(stream) == ["subfolder/doc-1|subfolder/doc-2|subfolder/doc-3"] + else: + assert list(stream) == ["subfolder/doc-1|subfolder/doc-3", "subfolder/doc-2"] + + @pytest.mark.parametrize( "num_cpu_workers,write_in_worker", [