Bug description
Using a StreamingDataLoader with num_workers=0 works, but resuming its state does not: an explicit length check on the state fails.
num_workers=0 is perhaps not very meaningful for real applications, but it is useful for debugging and testing. Alternatively, if that is difficult to support, StreamingDataLoader could simply enforce num_workers >= 1. I think we should do something about this, since 0 is the default for the dataloader: users may forget to set it and then run into this error, which could confuse them.
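If enforcing num_workers >= 1 is the chosen route, a minimal sketch of a constructor-time guard could look like the following. This is a hypothetical illustration, not the actual Lightning implementation; the function name is made up. It would turn the confusing resume-time error into an immediate, actionable one:

```python
def check_streaming_num_workers(num_workers: int) -> int:
    """Hypothetical guard: reject num_workers=0 (the DataLoader default),
    which state resuming currently cannot handle."""
    if num_workers < 1:
        raise ValueError(
            "StreamingDataLoader requires num_workers >= 1 to support resuming "
            f"from a saved state; got num_workers={num_workers}."
        )
    return num_workers
```

Such a check would run in StreamingDataLoader.__init__ before delegating to the base DataLoader, so the failure happens at construction rather than two steps later during state validation.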
What version are you seeing the problem on?
master
How to reproduce the bug
import torch


def run():
    checkpoint_path = "checkpoint.pt"

    # Save a checkpoint
    train_dataloader = create_dataloader()
    train_iterator = iter(train_dataloader)
    next(train_iterator)
    next(train_iterator)
    torch.save(train_dataloader.state_dict(), checkpoint_path)

    # Reset and attempt resume
    train_dataloader = create_dataloader()
    state = {"train_dataloader": train_dataloader}
    train_dataloader.load_state_dict(torch.load(checkpoint_path))
    train_iterator = iter(train_dataloader)
    next(train_iterator)
    next(train_iterator)


def create_dataloader():
    from lightning.data import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader
    from lightning.data.streaming.item_loader import TokensLoader

    train_datasets = [
        StreamingDataset(
            input_dir="/teamspace/s3_connections/tinyllama-template/slimpajama/train",
            item_loader=TokensLoader(block_size=4),
        ),
        StreamingDataset(
            input_dir="/teamspace/s3_connections/tinyllama-template/starcoder",
            item_loader=TokensLoader(block_size=4),
        ),
    ]
    combined_dataset = CombinedStreamingDataset(datasets=train_datasets)
    train_dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=0)  # <--- BUG WHEN num_workers=0
    return train_dataloader


if __name__ == "__main__":
    run()
Error messages and logs
Traceback (most recent call last):
File "/teamspace/studios/this_studio/repro_worker.py", line 50, in <module>
run()
File "/teamspace/studios/this_studio/repro_worker.py", line 25, in run
next(train_iterator)
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataloader.py", line 432, in __iter__
for batch in super().__iter__():
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 438, in __iter__
return self._get_iterator()
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataloader.py", line 504, in _get_iterator
return _SingleProcessDataLoaderIter(self)
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 669, in __init__
self._dataset_fetcher = _DatasetKind.create_fetcher(
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
self.dataset_iter = iter(dataset)
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 83, in __iter__
self._iterator = _CombinedDatasetIterator(
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 126, in __init__
self._dataset_iters = [iter(dataset) for dataset in datasets]
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 126, in <listcomp>
self._dataset_iters = [iter(dataset) for dataset in datasets]
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataset.py", line 146, in __iter__
self._validate_state_dict()
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataset.py", line 328, in _validate_state_dict
raise ValueError(
ValueError: The provided `num_workers` state doesn't match the current one. Found `1` instead of `0`.
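One plausible reading of the mismatch ("Found `1` instead of `0`"): a single-process loader still iterates in one process, so the saved state may normalize num_workers=0 to 1, while validation compares against the raw configured value. The sketch below reproduces that pattern with made-up function names; it is a hypothesis about the failure mode, not the actual code in dataset.py:

```python
def save_worker_state(num_workers: int) -> dict:
    # Hypothetical: a single-process loader effectively runs one "worker"
    # (the main process), so the saved state records 1 even for num_workers=0.
    return {"num_workers": num_workers or 1}


def validate_worker_state(state: dict, current_num_workers: int) -> None:
    # Hypothetical: validation compares against the raw configured value,
    # so resuming with num_workers=0 always fails.
    if state["num_workers"] != current_num_workers:
        raise ValueError(
            "The provided `num_workers` state doesn't match the current one. "
            f"Found `{state['num_workers']}` instead of `{current_num_workers}`."
        )
```

Under this reading, either the save path or the validation path would need to normalize consistently for num_workers=0 to round-trip.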
Environment
Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): master (2.2dev)
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
More info
No response