
Resuming Training w/ Streaming Dataset on DDP with Multiple Nodes Fails #248

Closed
schopra8 opened this issue Jul 21, 2024 · 8 comments
Labels: bug (Something isn't working), help wanted (Extra attention is needed), priority 0

Comments

@schopra8

πŸ› Bug

We trained a model for several epochs on multiple nodes, and we wanted to continue training with PyTorch Lightning and LitData.

✅ When we resume training on a single device, resumption works as expected.
✅ When we resume training on a single node with N devices, resumption works as expected.
❌ When we resume training on multiple nodes with N devices, resumption fails.

To Reproduce

Run trainer.fit from an existing checkpoint with DDP on multiple nodes (N devices each):

Stack trace:

[rank7]: Traceback (most recent call last):
[rank7]:   File "/home/train.py", line 70, in <module>
[rank7]:     main(config)
[rank7]:   File "/home/train.py", line 50, in main
[rank7]:     trainer.fit(model, datamodule=custom_data_module, ckpt_path=ckpt)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
[rank7]:     call._call_and_handle_interrupt(
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
Failures:
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 133, in __next__
[rank7]:     batch = super().__next__()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/loops/fetchers.py", line 60, in __next__
[rank7]:     batch = next(self.iterator)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 341, in __next__
[rank7]:     out = next(self._iterator)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/pytorch_lightning/utilities/combined_loader.py", line 78, in __next__
[rank7]:     out[i] = next(self.iterators[i])
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataloader.py", line 628, in __iter__
[rank7]:     for batch in super().__iter__():
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
[rank7]:     data = self._next_data()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1326, in _next_data
[rank7]:     return self._process_data(data)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
[rank7]:     data.reraise()
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/_utils.py", line 705, in reraise
[rank7]:     raise exception
[rank7]: IndexError: Caught IndexError in DataLoader worker process 1.
[rank7]: Original Traceback (most recent call last):
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
[rank7]:     fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
[rank7]:     return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
[rank7]:     self.dataset_iter = iter(dataset)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 144, in __iter__
[rank7]:     self._iterator = _CombinedDatasetIterator(
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 192, in __init__
[rank7]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/combined.py", line 192, in <listcomp>
[rank7]:     self._dataset_iters = [iter(dataset) for dataset in datasets]
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 211, in __iter__
[rank7]:     self._resume(chunks_replica, intervals_replica)
[rank7]:   File "/home/sahil/.cache/pypoetry/virtualenvs/env-auw7Hy33-py3.10/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 272, in _resume
[rank7]:     interval = self.worker_intervals[self.chunk_index]
[rank7]: IndexError: list index out of range
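
For what it's worth, my reading of the failure (an illustration only, not LitData's actual internals): on resume, the chunk index restored from the checkpoint points past the list of chunk intervals assigned to this worker under the new node/worker layout, roughly like this:

# Hypothetical illustration of the failure mode, not LitData's real code:
# the restored chunk index can exceed the (shorter) list of chunk intervals
# re-assigned to this worker when the world size / worker layout changes.
worker_intervals = [(0, 512), (512, 1024)]   # intervals this worker receives after resharding
chunk_index = 5                              # chunk index restored from the checkpoint
interval = worker_intervals[chunk_index]     # IndexError: list index out of range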

Code sample

I've scrubbed my code below:

# Imports added here for completeness (the original sample is scrubbed);
# Model, config, collate_fn, and ckpt are defined elsewhere in our code.
from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

# Create dataset module
custom_data_module = CustomDataModule(config)

# Initialize the model
model = Model()

# Define the PyTorch Lightning Trainer
wandb_logger = WandbLogger(**config.wandb_logger)
device_stats_monitor = DeviceStatsMonitor()
strategy = DDPStrategy(find_unused_parameters=True)
trainer = Trainer(
    logger=wandb_logger,
    callbacks=[device_stats_monitor],
    strategy=strategy,
    max_epochs=4,
    val_check_interval=500,
    accelerator='gpu',
    devices=8,
    num_nodes=2,
    enable_progress_bar=True,
    log_every_n_steps=50,
    precision=32,
    default_root_dir='/scratch/lightning_logs'
)
trainer.fit(model, datamodule=custom_data_module, ckpt_path=ckpt)


# Data module wrapping the LitData streaming datasets (referenced above)
class CustomDataModule(LightningDataModule):
    """
    Custom Data Module wraps training/validation StreamingDataset objects.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in (None, 'fit'):
            # Create train datasets
            self.train_datasets = []
            for ds_config in self.config.datasets.train:
                dataset = StreamingDataset(
                    input_dir=ds_config.path,
                    subsample=ds_config.subsample
                )
                self.train_datasets.append(dataset)
            self.train_dataset = CombinedStreamingDataset(self.train_datasets)

            # Create validation datasets
            self.val_datasets = []
            for ds_config in self.config.datasets.val:
                dataset = StreamingDataset(
                    input_dir=ds_config.path,
                    subsample=ds_config.subsample
                )
                self.val_datasets.append(dataset)
            self.val_dataset = CombinedStreamingDataset(self.val_datasets)

    def train_dataloader(self):
        return StreamingDataLoader(
            self.train_dataset,
            collate_fn=collate_fn,
            **self.config.dataloader
        )

    def val_dataloader(self):
        return StreamingDataLoader(
            self.val_dataset,
            collate_fn=collate_fn,
            **self.config.dataloader
        )
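
For reference, the dataloader-level state round-trip that resumption relies on looks roughly like this outside of Lightning (a minimal sketch based on LitData's documented state_dict / load_state_dict API; the path and loader settings are illustrative):

import torch
from litdata import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset(input_dir="s3://my-bucket/optimized-train")  # illustrative path
dataloader = StreamingDataLoader(dataset, batch_size=64, num_workers=4)

for step, batch in enumerate(dataloader):
    if step == 1000:
        # Persist the loader's position mid-epoch
        torch.save(dataloader.state_dict(), "dataloader_state.pt")
        break

# On the next run: restore the saved position and keep streaming from there
dataloader.load_state_dict(torch.load("dataloader_state.pt"))
for batch in dataloader:
    ...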

Expected behavior

Training resumes from the checkpoint on multiple nodes without errors.

Environment

  • PyTorch Version: 2.3.1
  • OS: Linux
  • How you installed PyTorch: poetry
  • Python version: 3.10
  • CUDA/cuDNN version: 12.1
  • GPU models and configuration: 2 nodes × 8 H100 GPUs
schopra8 added the bug and help wanted labels on Jul 21, 2024

Hi! Thanks for your contribution, great first issue!

@tchaton (Collaborator) commented Jul 21, 2024

Hey @schopra8,

Thanks for reporting the issue. Yes, we are aware of it :) and we are already looking into it. If we fix it next week, we'll make a new release. cc @awaelchli

Unfortunately, this comes as part of a series of fixes (#237) that aren't backward compatible (old checkpoints won't load, as the core logic has changed too much).

@schopra8 (Author)

np! thanks for the heads up

@schopra8 (Author)

Also wanted to flag:

I tried resuming on 1 node with N devices. Everything worked for the first couple hundred steps, but then I hit the same error. So it looks like there's a similar issue with DDP on a single node as well.

@tchaton (Collaborator) commented Jul 22, 2024

Hey @schopra8, here are the release notes: https://github.com/Lightning-AI/litdata/releases/tag/v0.2.17

Would you mind trying again with the latest version, 0.2.17?

Unfortunately, old checkpoints won't work.
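
A quick way to confirm every node/rank actually picked up the new release before retrying (this assumes litdata exposes __version__, which recent releases do):

# Run on each node before launching training
import litdata
print(litdata.__version__)  # expect "0.2.17" or newer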

@schopra8 (Author)

Awesome! I'll try it in the next 1-2 days and report back with my results.

@tchaton (Collaborator) commented Jul 22, 2024

Thanks @schopra8.

@schopra8 (Author)

@tchaton Tested this and it works! Closing this issue, since the DDP problem is solved.

We're having another issue (#263) with resuming training on a new dataset: we want to preserve optimizer states, etc. when we continue training. Any guidance would be much appreciated! A rough sketch of what we're after is below.
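
The flow we want, very roughly (a sketch only; the checkpoint keys are the standard Lightning layout, and new_data_module and the path are illustrative):

# Keep model/optimizer state from the old run, but start streaming the *new*
# dataset from the beginning (no dataloader state restore).
import torch

ckpt = torch.load("old_run/last.ckpt", map_location="cpu")  # illustrative path
model = Model()
model.load_state_dict(ckpt["state_dict"])                   # restore weights only

# Optimizer state lives under "optimizer_states" in a Lightning checkpoint; we would
# restore it manually (e.g. inside configure_optimizers) rather than passing ckpt_path,
# so the new dataset's StreamingDataLoader starts from scratch.
trainer.fit(model, datamodule=new_data_module)               # note: no ckpt_path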
