
Resuming StreamingDataloader with num_workers=0 fails #24

Open
tchaton opened this issue Feb 26, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@tchaton
Collaborator

tchaton commented Feb 26, 2024

Bug description

Using a StreamingDataloader with num_workers=0 works, but resuming the state does not. There is an explicit length check for the state that fails.

Using num_workers=0 is probably not very meaningful for real applications, but it is useful for debugging and testing. Alternatively, if supporting it is difficult, StreamingDataloader could simply force num_workers>=1. I think we should do something about it, since 0 is the default for the DataLoader, and users who forget to set it will run into this error, which could confuse them. A minimal user-side guard is sketched below.
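
If forcing the value is the route taken, a guard along these lines would cover the default case (a sketch only; safe_num_workers is a hypothetical helper, not part of lightning.data):

def safe_num_workers(requested: int) -> int:
    # StreamingDataLoader state resumption currently assumes at least one worker,
    # so bump 0 (the PyTorch DataLoader default) up to 1 and say so.
    if requested < 1:
        print("num_workers=0 is not resumable with StreamingDataLoader; using 1 instead.")
        return 1
    return requested

# e.g. StreamingDataLoader(combined_dataset, batch_size=4, num_workers=safe_num_workers(0))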

What version are you seeing the problem on?

master

How to reproduce the bug

import torch


def run():
    checkpoint_path = "checkpoint.pt"

    # Save a checkpoint
    train_dataloader = create_dataloader()
    train_iterator = iter(train_dataloader)
    next(train_iterator)
    next(train_iterator)
    torch.save(train_dataloader.state_dict(), checkpoint_path)

    # Reset and attempt resume
    train_dataloader = create_dataloader()
    state = {"train_dataloader": train_dataloader}
    train_dataloader.load_state_dict(torch.load(checkpoint_path))
    train_iterator = iter(train_dataloader)
    next(train_iterator)
    next(train_iterator)


def create_dataloader():
    from lightning.data import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader
    from lightning.data.streaming.item_loader import TokensLoader

    train_datasets = [
        StreamingDataset(
            input_dir="/teamspace/s3_connections/tinyllama-template/slimpajama/train",
            item_loader=TokensLoader(block_size=4),
        ),
        StreamingDataset(
            input_dir="/teamspace/s3_connections/tinyllama-template/starcoder",
            item_loader=TokensLoader(block_size=4),
        ),
    ]
    combined_dataset = CombinedStreamingDataset(datasets=train_datasets)
    train_dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=0)  # <--- BUG WHEN NUM WORKERS=0
    return train_dataloader


if __name__ == "__main__":
    run()

Error messages and logs

Traceback (most recent call last):
  File "/teamspace/studios/this_studio/repro_worker.py", line 50, in <module>
    run()
  File "/teamspace/studios/this_studio/repro_worker.py", line 25, in run
    next(train_iterator)
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataloader.py", line 432, in __iter__
    for batch in super().__iter__():
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 438, in __iter__
    return self._get_iterator()
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataloader.py", line 504, in _get_iterator
    return _SingleProcessDataLoaderIter(self)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 669, in __init__
    self._dataset_fetcher = _DatasetKind.create_fetcher(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
    return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
    self.dataset_iter = iter(dataset)
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 83, in __iter__
    self._iterator = _CombinedDatasetIterator(
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 126, in __init__
    self._dataset_iters = [iter(dataset) for dataset in datasets]
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 126, in <listcomp>
    self._dataset_iters = [iter(dataset) for dataset in datasets]
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataset.py", line 146, in __iter__
    self._validate_state_dict()
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataset.py", line 328, in _validate_state_dict
    raise ValueError(
ValueError: The provided `num_workers` state doesn't match the current one. Found `1` instead of `0`.
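
For context, the check that raises this error presumably looks something like the following rough sketch, inferred from the traceback and the error message (not the actual lightning.data/litdata source):

def _validate_state_dict_sketch(state: dict, current_num_workers: int) -> None:
    # The saved state records the num_workers it was created with; resuming with
    # a different value (here: 0 workers against a state recorded as 1) is rejected.
    saved_num_workers = state["num_workers"]
    if saved_num_workers != current_num_workers:
        raise ValueError(
            "The provided `num_workers` state doesn't match the current one. "
            f"Found `{saved_num_workers}` instead of `{current_num_workers}`."
        )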

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): master (2.2dev)
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

Moved from Lightning-AI/pytorch-lightning#19335, submitted by @awaelchli

@tchaton tchaton added enhancement New feature or request help wanted Extra attention is needed labels Feb 26, 2024
@awaelchli awaelchli added bug Something isn't working and removed enhancement New feature or request help wanted Extra attention is needed labels Feb 28, 2024
@shreyanssethi

Following up on this, what is the recommended practice/solution? I was able to get it to pass the check on load_state_dict by loading my checkpoint, manually changing num_workers from 0 to 1, and saving it again, but I wanted to know if there's a better workaround. (A sketch of that manual patch is below.)
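
Roughly, the manual adjustment looks like this. Where exactly num_workers sits inside the saved state is an assumption, so inspect your own checkpoint before patching it:

import torch

def patch_num_workers(obj, new_value):
    # Recursively overwrite every `num_workers` entry in the saved state so it
    # matches the value the resumed run will use.
    if isinstance(obj, dict):
        for key, value in obj.items():
            if key == "num_workers":
                obj[key] = new_value
            else:
                patch_num_workers(value, new_value)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            patch_num_workers(item, new_value)

state = torch.load("checkpoint.pt")
patch_num_workers(state, 1)  # make the saved state claim num_workers=1
torch.save(state, "checkpoint.pt")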

@lukasschmit

Hi, I've also encountered this issue, but with non-zero num_workers. Even more bizarrely, I hit it about 75% of the way through a validation epoch when training with the latest PyTorch Lightning trainer, and AFTER successfully getting through the first validation epoch and even saving a .ckpt!

Any ideas/updates on what's going on here?

Traceback (most recent call last):
  File "/valohai/repository/ml/polle/train_polle.py", line 184, in train
    trainer.fit(lightning_module, datamodule=datamodule, ckpt_path=polle_path)
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/trainer/trainer.py", line 1030, in _run_stage
    self.fit_loop.run()
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/training_epoch_loop.py", line 212, in advance
    batch, _, __ = next(data_fetcher)
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fetchers.py", line 133, in __next__
    batch = super().__next__()
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loops/fetchers.py", line 60, in __next__
    batch = next(self.iterator)
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/utilities/combined_loader.py", line 341, in __next__
    out = next(self._iterator)
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/utilities/combined_loader.py", line 78, in __next__
    out[i] = next(self.iterators[i])
  File "/usr/local/lib/python3.11/dist-packages/litdata/streaming/dataloader.py", line 620, in __iter__
    for batch in super().__iter__():
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
    return self._process_data(data)
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.11/dist-packages/torch/_utils.py", line 705, in reraise
    raise exception
ValueError: Caught ValueError in DataLoader worker process 16.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
    fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
    return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
    self.dataset_iter = iter(dataset)
  File "/usr/local/lib/python3.11/dist-packages/litdata/streaming/dataset.py", line 199, in __iter__
    self._validate_state_dict()
  File "/usr/local/lib/python3.11/dist-packages/litdata/streaming/dataset.py", line 388, in _validate_state_dict
    raise ValueError(
ValueError: The provided `num_workers` state doesn't match the current one. Found `24` instead of `16`.

@tchaton
Collaborator Author

tchaton commented Jul 23, 2024

Hey @lukasschmit, LitData doesn't support changing the number of workers when resuming. Can you restart with 16 workers?

Projects
None yet
Development

No branches or pull requests

4 participants