Fix streams #350

Merged · 3 commits · Nov 28, 2024
8 changes: 8 additions & 0 deletions changelog.md
@@ -2,11 +2,19 @@

# Unreleased

### Added

- `edsnlp.data.read_parquet` now accepts a `work_unit="fragment"` option to split tasks between workers by parquet fragment instead of by row. When this is enabled, workers no longer read every fragment and keep only 1 in n rows; instead, each worker reads all rows of 1/n of the fragments, which should be faster.
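
  A hedged usage sketch of the new option; the dataset path, converter name, and worker count below are illustrative placeholders, not values taken from this PR:

  ```python
  import edsnlp

  # Dispatch whole parquet fragments to workers instead of skipping 1 in n rows
  stream = edsnlp.data.read_parquet(
      "path/to/notes.parquet",   # placeholder path to a parquet dataset
      converter="omop",          # assumed converter; any registered converter works
      shuffle="fragment",        # optional: shuffle records within each fragment
      work_unit="fragment",      # the option added by this PR
  ).set_processing(num_cpu_workers=2)

  docs = list(stream)
  ```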

### Fixed

- Fix `join_thread` missing attribute in `SimpleQueue` when cleaning a multiprocessing executor
- Support huggingface transformers that do not set `cls_token_id` and `sep_token_id` (we now also look for these tokens in the `special_tokens_map` and `vocab` mappings)
- Fix an issue with the scorers dict changing size when evaluating during training
- Seed random states (instead of using `random.RandomState()`) when shuffling in data readers; this is important for:
  1. reproducibility
  2. ensuring, in multiprocessing mode, that the same data is shuffled in the same way in all workers
- Bubble up `BaseComponent` instantiation errors correctly
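
  Regarding the `cls_token_id` / `sep_token_id` fix above, here is a rough sketch of the described fallback using the public `transformers` API; the model name is a placeholder and the actual edsnlp logic may differ in detail:

  ```python
  from transformers import AutoTokenizer

  tok = AutoTokenizer.from_pretrained("camembert-base")  # placeholder model
  cls_id = tok.cls_token_id
  if cls_id is None:
      # Fall back to the special tokens map, then to the vocabulary
      cls_token = tok.special_tokens_map.get("cls_token")
      if cls_token is not None:
          cls_id = tok.get_vocab().get(cls_token)
  ```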

## v0.14.0 (2024-11-14)

6 changes: 6 additions & 0 deletions edsnlp/core/stream.py
@@ -791,6 +791,12 @@ def shuffle(
else False
)
stream = self
# Ensure that we have a "deterministic" random seed, meaning that if the
# user sets a global seed earlier in the program and executes the same
# program twice, the shuffling will be the same in both runs.
# This is not guaranteed by just creating random.Random(), which does not
# account for the global random state.
seed = seed if seed is not None else random.getrandbits(32)
if shuffle_reader:
if shuffle_reader not in self.reader.emitted_sentinels:
raise ValueError(f"Cannot shuffle by {shuffle_reader}")
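A minimal standalone illustration of the seeding rationale in the comment above (not edsnlp code): drawing the fallback seed from the global random state makes the derived `random.Random` reproducible whenever the program seeds `random` globally.

```python
import random

def make_rng(seed=None):
    # getrandbits() consumes the *global* random state, so a global
    # random.seed(...) call earlier in the program also fixes this value.
    seed = seed if seed is not None else random.getrandbits(32)
    return random.Random(seed)

random.seed(0)
first = [make_rng().random() for _ in range(3)]
random.seed(0)
second = [make_rng().random() for _ in range(3)]
assert first == second  # same global seed -> same derived seeds -> same shuffling
```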
21 changes: 4 additions & 17 deletions edsnlp/data/base.py
@@ -21,23 +21,6 @@


class BaseReader:
"""
The BaseReader serves as a base class for all readers. It expects two methods:

- `read_records` method which is called in the main process and should return a
generator of fragments (like filenames) with their estimated size (number of
documents)
- `unpack_tasks` method which is called in the worker processes and receives
batches of fragments and should return a list of dictionaries (one per
document), ready to be converted to a Doc object by the converter.

Additionally, the subclass should define a `DATA_FIELDS` class attribute which
contains the names of all attributes that should not be copied when the reader is
copied to the worker processes. This is useful for example when the reader holds a
reference to a large object like a DataFrame that should not be copied to the
worker processes.
"""

DATA_FIELDS: Tuple[str] = ()
read_in_worker: bool
emitted_sentinels: set
@@ -47,6 +30,9 @@ class BaseReader:
def read_records(self) -> Iterable[Any]:
raise NotImplementedError()

def extract_task(self, item):
return [item]

def worker_copy(self):
if self.read_in_worker:
return self
@@ -110,6 +96,7 @@ def __init__(
):
super().__init__()
self.shuffle = shuffle
seed = seed if seed is not None else random.getrandbits(32)
self.rng = random.Random(seed)
self.emitted_sentinels = {"dataset"}
self.loop = loop
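A rough sketch of a custom reader following this protocol, based on the docstring removed above and the new `extract_task` hook; the attributes set in `__init__` and the file-reading logic are illustrative assumptions, not edsnlp's actual implementation:

```python
from edsnlp.data.base import BaseReader


class FileListReader(BaseReader):
    DATA_FIELDS = ()        # nothing heavy to strip before copying to workers
    read_in_worker = False

    def __init__(self, paths):
        super().__init__()
        self.paths = list(paths)
        self.emitted_sentinels = {"dataset"}
        self.loop = False

    def read_records(self):
        # Called in the main process: yield lightweight task descriptors
        yield from self.paths

    def extract_task(self, item):
        # Called in the worker processes: expand a descriptor into records
        with open(item) as f:
            for i, line in enumerate(f):
                yield {"note_id": f"{item}:{i}", "note_text": line.strip()}
```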
2 changes: 1 addition & 1 deletion edsnlp/data/json.py
@@ -40,7 +40,6 @@ def __init__(
):
super().__init__()
self.shuffle = shuffle
self.rng = random.Random(seed)
self.write_in_worker = write_in_worker
self.emitted_sentinels = {"dataset"}
self.loop = loop
@@ -57,6 +56,7 @@ def __init__(
self.keep_ipynb_checkpoints = keep_ipynb_checkpoints
self.shuffle = shuffle
self.loop = loop
seed = seed if seed is not None else random.getrandbits(32)
self.rng = random.Random(seed)
for file in self.files:
if not self.fs.exists(file):
1 change: 1 addition & 0 deletions edsnlp/data/pandas.py
@@ -27,6 +27,7 @@ def __init__(
):
super().__init__()
self.shuffle = shuffle
seed = seed if seed is not None else random.getrandbits(32)
self.rng = random.Random(seed)
self.emitted_sentinels = {"dataset"}
self.loop = loop
52 changes: 45 additions & 7 deletions edsnlp/data/parquet.py
@@ -32,14 +32,21 @@ def __init__(
shuffle: Literal["dataset", "fragment", False] = False,
seed: Optional[int] = None,
loop: bool = False,
work_unit: Literal["record", "fragment"] = "record",
):
super().__init__()
self.shuffle = shuffle
self.emitted_sentinels = {"dataset"} | (
set() if shuffle == "dataset" else {"fragment"}
)
seed = seed if seed is not None else random.getrandbits(32)
self.rng = random.Random(seed)
self.loop = loop
self.work_unit = work_unit
assert not (work_unit == "fragment" and shuffle == "dataset"), (
"Cannot shuffle at the dataset level and dispatch tasks at the "
"fragment level. Set shuffle='fragment' or work_unit='record'."
)
# Either the filesystem has not been passed
# or the path is a URL (e.g. s3://) => we need to infer the filesystem
self.fs, self.path = normalize_fs_path(filesystem, path)
@@ -52,24 +59,41 @@ def __init__(
)

def read_fragment(self, fragment: ParquetFileFragment) -> Iterable[Dict]:
return dl_to_ld(fragment.to_table().to_pydict())
return (
doc
for batch in fragment.scanner().to_reader()
for doc in dl_to_ld(batch.to_pydict())
)

def extract_task(self, item):
if self.work_unit == "fragment":
records = self.read_fragment(item)
if self.shuffle == "fragment":
records = shuffle(records, self.rng)
yield from records
else:
yield item

def read_records(self) -> Iterable[Any]:
while True:
files = self.fragments
if self.shuffle == "fragment":
for file in shuffle(files, self.rng):
records = shuffle(self.read_fragment(file), self.rng)
yield from records
if self.work_unit == "fragment":
yield file
else:
yield from shuffle(self.read_fragment(file), self.rng)
yield FragmentEndSentinel(file.path)
elif self.shuffle == "dataset":
assert self.work_unit == "record"
records = (line for file in files for line in self.read_fragment(file))
records = shuffle(records, self.rng)
yield from records
yield from shuffle(records, self.rng)
else:
for file in files:
records = list(self.read_fragment(file))
yield from records
if self.work_unit == "fragment":
yield file
else:
yield from self.read_fragment(file)
yield FragmentEndSentinel(file.path)
yield DatasetEndSentinel()
if not self.loop:
@@ -158,6 +182,7 @@ def read_parquet(
shuffle: Literal["dataset", "fragment", False] = False,
seed: Optional[int] = None,
loop: bool = False,
work_unit: Literal["record", "fragment"] = "record",
**kwargs,
) -> Stream:
"""
@@ -211,6 +236,18 @@
The seed to use for shuffling.
loop: bool
Whether to loop over the data indefinitely.
work_unit: Literal["record", "fragment"]
Only affects the multiprocessing mode. If "record", every worker starts reading
the same parquet file and yields every num_workers-th record, each worker
starting at a different offset. For instance, if num_workers=2, the first worker
reads the 1st, 3rd, 5th, ... records, while the second worker reads the 2nd,
4th, 6th, ... records of the first parquet file.

If "fragment", each worker reads a different parquet file. For instance, the
first worker reads every record of the 1st parquet file, the second worker reads
every record of the 2nd parquet file, and so on. This way, no record is "wasted"
and every record loaded in memory is yielded.

converter: Optional[AsList[Union[str, Callable]]]
Converters to use to convert the parquet rows of the data source to Doc objects
These are documented on the [Converters](/data/converters) page.
@@ -237,6 +274,7 @@
shuffle=shuffle,
seed=seed,
loop=loop,
work_unit=work_unit,
)
)
if converter:
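A toy illustration of the difference between the two work units described in the docstring above, using plain Python lists in place of parquet data (not edsnlp internals):

```python
num_workers = 2
fragments = [["a0", "a1", "a2"], ["b0", "b1"], ["c0", "c1", "c2"]]

# work_unit="record": every worker scans every fragment but keeps 1/n of the rows
records = [row for fragment in fragments for row in fragment]
per_worker_records = [records[i::num_workers] for i in range(num_workers)]
# -> [['a0', 'a2', 'b1', 'c1'], ['a1', 'b0', 'c0', 'c2']]

# work_unit="fragment": whole fragments are dispatched, so no row is read and dropped
per_worker_fragments = [fragments[i::num_workers] for i in range(num_workers)]
# -> [[['a0', 'a1', 'a2'], ['c0', 'c1', 'c2']], [['b0', 'b1']]]
```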
1 change: 1 addition & 0 deletions edsnlp/data/polars.py
@@ -29,6 +29,7 @@ def __init__(
super().__init__()
self.shuffle = shuffle
self.emitted_sentinels = {"dataset"}
seed = seed if seed is not None else random.getrandbits(32)
self.rng = random.Random(seed)
self.loop = loop

1 change: 1 addition & 0 deletions edsnlp/data/spark.py
@@ -39,6 +39,7 @@ def __init__(
self.data = data
self.shuffle = shuffle
self.emitted_sentinels = {"dataset"}
seed = seed if seed is not None else random.getrandbits(32)
self.rng = random.Random(seed)
self.loop = loop
assert isinstance(
1 change: 1 addition & 0 deletions edsnlp/data/standoff.py
@@ -292,6 +292,7 @@ def __init__(
super().__init__()
self.shuffle = shuffle
self.emitted_sentinels = {"dataset"}
seed = seed if seed is not None else random.getrandbits(32)
self.rng = random.Random(seed)
self.loop = loop
self.fs, self.path = normalize_fs_path(filesystem, path)
26 changes: 16 additions & 10 deletions edsnlp/pipes/base.py
@@ -44,16 +44,22 @@ def __call__(cls, nlp=inspect.Signature.empty, *args, **kwargs):
# If this component is missing the nlp argument, we curry it with the
# provided arguments and return a CurriedFactory object.
sig = inspect.signature(cls.__init__)
bound = sig.bind_partial(None, nlp, *args, **kwargs)
bound.arguments.pop("self", None)
if (
"nlp" in sig.parameters
and sig.parameters["nlp"].default is sig.empty
and bound.arguments.get("nlp", sig.empty) is sig.empty
):
return CurriedFactory(cls, bound.arguments)
if nlp is inspect.Signature.empty:
bound.arguments.pop("nlp", None)
try:
bound = sig.bind_partial(None, nlp, *args, **kwargs)
bound.arguments.pop("self", None)
if (
"nlp" in sig.parameters
and sig.parameters["nlp"].default is sig.empty
and bound.arguments.get("nlp", sig.empty) is sig.empty
):
return CurriedFactory(cls, bound.arguments)
if nlp is inspect.Signature.empty:
bound.arguments.pop("nlp", None)
except TypeError: # pragma: no cover
if nlp is inspect.Signature.empty:
super().__call__(*args, **kwargs)
else:
super().__call__(nlp, *args, **kwargs)
return super().__call__(**bound.arguments)


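A minimal standalone illustration of why the `bind_partial` call is now wrapped in a `try`/`except TypeError`; the function and arguments are made up for the example:

```python
import inspect


def init(self, nlp, name="pipe"):
    pass


sig = inspect.signature(init)
sig.bind_partial(None, None)                  # fine: partial binding is allowed
try:
    sig.bind_partial(None, None, bogus_kw=1)  # unexpected keyword -> TypeError
except TypeError as exc:
    # In the metaclass above, this case falls back to the regular constructor
    # call so that the instantiation error is raised (bubbled up) from the
    # component itself rather than from the signature binding.
    print(type(exc).__name__, exc)
```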
9 changes: 6 additions & 3 deletions edsnlp/processing/multiprocessing.py
@@ -509,13 +509,16 @@ def iter_tasks(self, stage, stop_mode=False):
return
task_idx = 0
for item in iter(self.stream.reader.read_records()):
if self.stop:
if self.stop: # pragma: no cover
raise StopSignal()
if isinstance(item, StreamSentinel):
yield item
continue
if task_idx % pool_size == worker_idx:
yield item
for item in self.stream.reader.extract_task(item):
if self.stop: # pragma: no cover
raise StopSignal()
yield item

task_idx += 1
else:
@@ -751,7 +754,7 @@ def iter_tasks(self, stage, stop_mode=False):
]
# Get items from the previous stage
while self.num_producers_alive[stage] > 0:
if self.stop and not stop_mode:
if self.stop and not stop_mode: # pragma: no cover
raise StopSignal()

offset = (offset + 1) % len(queues)
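A toy model of the dispatch logic in `iter_tasks` above: records are round-robined across workers, sentinels are forwarded to every worker, and `extract_task` expands each dispatched item inside the worker. This is standalone code, not the actual class:

```python
def iter_worker_tasks(items, worker_idx, pool_size, extract_task, is_sentinel):
    task_idx = 0
    for item in items:
        if is_sentinel(item):
            yield item                     # every worker sees the sentinel
            continue
        if task_idx % pool_size == worker_idx:
            yield from extract_task(item)  # e.g. expand a fragment into records
        task_idx += 1


items = ["frag-0", "frag-1", "frag-2", "END"]
out = list(iter_worker_tasks(
    items,
    worker_idx=0,
    pool_size=2,
    extract_task=lambda f: [f + "/r0", f + "/r1"],
    is_sentinel=lambda x: x == "END",
))
# -> ['frag-0/r0', 'frag-0/r1', 'frag-2/r0', 'frag-2/r1', 'END']
```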
9 changes: 9 additions & 0 deletions edsnlp/processing/simple.py
@@ -66,6 +66,15 @@ def process():

with bar, stream.eval():
items = reader.read_records()
items = (
task
for item in items
for task in (
(item,)
if isinstance(item, StreamSentinel)
else reader.extract_task(item)
)
)

for stage_idx, stage in enumerate(stages):
for op in stage.cpu_ops:
4 changes: 3 additions & 1 deletion edsnlp/processing/spark.py
@@ -80,7 +80,7 @@ def getActiveSession() -> Optional["SparkSession"]: # pragma: no cover
df = reader.data
assert not reader.loop, "Looping is not supported with Spark backend."
df = (
df.sample(fraction=1.0, seed=reader.rng.randbytes(32))
df.sample(fraction=1.0, seed=reader.rng.getrandbits(32))
if reader.shuffle
else df
)
@@ -120,6 +120,8 @@ def process_partition(items): # pragma: no cover
else:
items = (item.asDict(recursive=True) for item in items)

items = (task for item in items for task in stream.reader.extract_task(item))

for stage_idx, stage in enumerate(stages):
for op in stage.cpu_ops:
items = op(items)
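A quick illustration of the one-line fix above: `Random.randbytes` returns `bytes`, whereas Spark's `DataFrame.sample(seed=...)` expects an integer seed, which `getrandbits` provides (standalone sketch):

```python
import random

rng = random.Random(0)
print(rng.getrandbits(32))  # an int in [0, 2**32), usable as an integer seed
print(rng.randbytes(4))     # a bytes object, not a valid integer seed
```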
27 changes: 27 additions & 0 deletions tests/data/test_parquet.py
@@ -4,6 +4,7 @@
import pyarrow.dataset
import pyarrow.fs
import pytest
from confit.utils.random import set_seed
from typing_extensions import Literal

import edsnlp
@@ -355,6 +356,32 @@ def test_read_shuffle_loop(
]


@pytest.mark.parametrize("num_cpu_workers", [0, 2])
@pytest.mark.parametrize("work_unit", ["record", "fragment"])
@pytest.mark.parametrize("shuffle", [False, "dataset", "fragment"])
def test_read_work_unit(
num_cpu_workers,
work_unit: Literal["record", "fragment"],
shuffle: Literal[False, "dataset", "fragment"],
):
if shuffle == "dataset" and work_unit == "fragment":
pytest.skip("Dataset-level shuffle is not supported with fragment work unit")
input_dir = Path(__file__).parent.parent.resolve() / "resources" / "docs.parquet"
set_seed(42)
stream = edsnlp.data.read_parquet(
input_dir, work_unit=work_unit, shuffle=shuffle
).set_processing(
num_cpu_workers=num_cpu_workers,
)
stream = stream.map_batches(
lambda b: "|".join(sorted([x["note_id"] for x in b])), batch_size=1000
)
if work_unit == "fragment" and num_cpu_workers == 2 or num_cpu_workers == 0:
assert list(stream) == ["subfolder/doc-1|subfolder/doc-2|subfolder/doc-3"]
else:
assert list(stream) == ["subfolder/doc-1|subfolder/doc-3", "subfolder/doc-2"]


@pytest.mark.parametrize(
"num_cpu_workers,write_in_worker",
[