Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] : Fix resume issues with combined streaming dataset in dataloader #362

Draft
wants to merge 38 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
0ba50c5
chore: Add tests for CombinedStreamingDataset in test_dataloader.py
bhimrazy Sep 3, 2024
3756b9e
Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy Sep 3, 2024
f8ed272
Adds resuming for dataloading states for combined dataset case with w…
bhimrazy Sep 3, 2024
8c53791
update
bhimrazy Sep 3, 2024
e193eb9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2024
837efde
update
bhimrazy Sep 3, 2024
242a13c
Refactor dataloader to fix num_samples_yieled calculation
bhimrazy Sep 5, 2024
dff88ca
Adds more tests
bhimrazy Sep 5, 2024
bc137a3
removes the subtraction in the epoch
bhimrazy Sep 5, 2024
e38592a
update initialize part
bhimrazy Sep 5, 2024
e095407
updated epoch numbers
bhimrazy Sep 5, 2024
97510e2
format imports
bhimrazy Sep 5, 2024
5c6925d
reverted current epoch
bhimrazy Sep 6, 2024
8502bb1
removed combined data test and moved to `test_combined.py`
bhimrazy Sep 6, 2024
5cda0c6
reverted epoch and also moved the test combined dataset
bhimrazy Sep 6, 2024
8cdafeb
Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy Sep 6, 2024
560032c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
f099242
revert epoch
bhimrazy Sep 6, 2024
dc484a0
updated current epoch
bhimrazy Sep 9, 2024
53b360c
fix epoch number
bhimrazy Sep 9, 2024
0c7cf3e
updated params
bhimrazy Sep 9, 2024
8010d33
Update current_epoch in test_dataloader.py
bhimrazy Sep 9, 2024
7355532
Update num_workers in test_combined.py
bhimrazy Sep 9, 2024
d40e3ca
updated the conditions
bhimrazy Sep 9, 2024
be77b58
updated tests: added case for the complete last iteration
bhimrazy Sep 9, 2024
86ccd99
Refactor test_combined.py to fix restore state issue
bhimrazy Sep 9, 2024
206b574
Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy Sep 17, 2024
5927331
fix: separated test cases for complete and partial last epoch.
bhimrazy Sep 17, 2024
da10c3f
Merge branch 'fix/combined-dataset-loading-states' of github.com:bhim…
bhimrazy Sep 17, 2024
17de9e7
fix type errors
bhimrazy Sep 17, 2024
a05e6f1
Refactor test_combined.py: Remove print statement
bhimrazy Sep 17, 2024
3762b11
Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy Sep 19, 2024
d506972
Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy Sep 22, 2024
868aa42
Adds conftest for combined dataset to reuse
bhimrazy Sep 22, 2024
357c29c
Simplified tests with parameterize to test for different conditions
bhimrazy Sep 22, 2024
ec3b840
update test
bhimrazy Sep 22, 2024
0ee0617
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2024
424bc64
Merge branch 'main' into fix/combined-dataset-loading-states
bhimrazy Dec 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions src/litdata/streaming/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,14 +665,16 @@ def state_dict(self) -> Dict[str, Any]:
"latest_worker_idx": self._latest_worker_idx,
}

num_samples_yieled = [0 for _ in range(len(list(self._num_samples_yielded_combined.values())[0]))]
# Initialize a list to track the number of samples yielded for each dataset
num_samples_yieled = [0 for _ in range(len(self.dataset._datasets))]

for worker_idx in self._num_samples_yielded_combined:
for dataset_idx, samples_yieled in enumerate(self._num_samples_yielded_combined[worker_idx]):
num_samples_yieled[dataset_idx] += samples_yieled

return {
"dataset": self.dataset.state_dict(self.num_workers, self.batch_size, num_samples_yieled),
"current_epoch": self.current_epoch if self.restore else self.current_epoch - 1,
"current_epoch": self.current_epoch,
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"latest_worker_idx": self._latest_worker_idx,
"num_samples_yielded": deepcopy(self._num_samples_yielded_combined),
}
Expand Down Expand Up @@ -701,21 +703,25 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:

# Inform we are resuming and disable resetting the StreamingDataLoader state.
# This is toggled back to False when the `__iter__` method of the StreamingDataLoader completes.
# self.restore = True

if isinstance(self.dataset, CombinedStreamingDataset):
self.dataset._set_use_streaming_dataloader(True)
self.dataset.load_state_dict(obj)

# Inform that the dataloader is resuming.
# TODO: Check if the number of samples yielded is less than the length of the dataset.
# Also, `len` is not available for CombinedStreamingDataset in case of provided weights.
self.restore = True
total_samples_yielded = sum([sum(samples) for samples in self._num_samples_yielded_combined.values()])

# Check if we need to restore for the case without weights.
if self.dataset._iterate_over_all and total_samples_yielded < len(self.dataset): # type: ignore
self.restore = True

# Check if we need to restore for the case with weights.
# Note: `len` is not available for CombinedStreamingDataset in case of provided weights.
# TODO: handle the case with weights.
if not self.dataset._iterate_over_all:
self.restore = True

elif isinstance(self.dataset, StreamingDataset):
self.dataset.load_state_dict(obj["dataset"])

# Inform that the dataloader is resuming.
if self._num_samples_yielded_streaming < len(self.dataset):
self.restore = True
else:
Expand Down
30 changes: 15 additions & 15 deletions tests/streaming/test_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [2, 0]},
},
Expand Down Expand Up @@ -456,7 +456,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 1,
"num_samples_yielded": {0: [2, 0], 1: [2, 0]},
},
Expand Down Expand Up @@ -493,7 +493,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 2,
"num_samples_yielded": {0: [2, 0], 1: [2, 0], 2: [2, 0]},
},
Expand Down Expand Up @@ -530,7 +530,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [3, 1], 1: [2, 0], 2: [2, 0]},
},
Expand Down Expand Up @@ -567,7 +567,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 1,
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 0]},
},
Expand Down Expand Up @@ -604,7 +604,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 2,
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]},
},
Expand Down Expand Up @@ -641,7 +641,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 0,
"current_epoch": 1,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [3, 1]},
},
Expand Down Expand Up @@ -681,7 +681,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [2, 0]},
},
Expand Down Expand Up @@ -718,7 +718,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 1,
"num_samples_yielded": {0: [2, 0], 1: [2, 0]},
},
Expand Down Expand Up @@ -755,7 +755,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 2,
"num_samples_yielded": {0: [2, 0], 1: [2, 0], 2: [2, 0]},
},
Expand Down Expand Up @@ -792,7 +792,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [3, 1], 1: [2, 0], 2: [2, 0]},
},
Expand Down Expand Up @@ -829,7 +829,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 1,
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 0]},
},
Expand Down Expand Up @@ -866,7 +866,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 2,
"num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]},
},
Expand Down Expand Up @@ -903,7 +903,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
"region_of_interest": ANY,
},
},
"current_epoch": 1,
"current_epoch": 2,
"latest_worker_idx": 0,
"num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [3, 1]},
},
Expand All @@ -920,6 +920,6 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
states_23.append(dataloader.state_dict())

assert sum(not torch.equal(b1, b2) for b1, b2 in zip(batches_2[2:], batches_23)) == 0
assert states_23[0]["current_epoch"] == 1
assert states_23[0]["current_epoch"] == 2

assert not dataloader.restore
54 changes: 54 additions & 0 deletions tests/streaming/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,57 @@ def test_resume_dataloader_with_new_dataset(tmpdir):
dataloader.load_state_dict(dataloader_state)
for _ in dataloader:
assert dataloader.current_epoch == 2, "Current epoch should be 2"


@pytest.mark.timeout(120)
def test_combined_dataset_dataloader_states(tmpdir):
    """Check saving and reloading StreamingDataLoader state over a CombinedStreamingDataset.

    Two datasets of 50 integer items each are written under ``tmpdir`` and
    combined without weights. The test then verifies that:

    * loading a freshly taken state dict switches the dataloader into
      ``restore`` mode;
    * after a partial first epoch, reloading the state resumes the same epoch
      and yields the remaining batches;
    * once the resumed epoch completes, ``restore`` is cleared and the next
      iteration runs as epoch 2.
    """
    # Build two on-disk datasets, each holding the integers 0..49.
    datasets = [str(tmpdir.join(f"dataset_{i}")) for i in range(2)]
    for dataset in datasets:
        cache = Cache(input_dir=dataset, chunk_bytes="64MB")
        for i in range(50):
            cache[i] = i
        cache.done()
        cache.merge()

    dataset_1 = StreamingDataset(datasets[0], shuffle=True)
    dataset_2 = StreamingDataset(datasets[1], shuffle=True)
    combined_dataset = CombinedStreamingDataset(datasets=[dataset_1, dataset_2])

    # Test dataloader without explicit num workers
    dataloader = StreamingDataLoader(combined_dataset, batch_size=4)
    assert not dataloader.restore
    dataloader.load_state_dict(dataloader.state_dict())
    assert dataloader.restore
    batch = next(iter(dataloader))
    assert len(batch) == 4, "Batch size should be 4"
    assert len(dataloader) == 25, "Dataloader length should be 25 (50+50 items / batch size 4)"

    # Test dataloader with num workers
    dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=2)
    assert len(dataloader) == 25, "Dataloader length should be 25 (50+50 items / batch size 4)"

    # Verify dataloader state after partial iteration.
    # The loop stops once batch index 10 has been received, i.e. after 11 batches.
    for batch_idx, _ in enumerate(dataloader):
        assert dataloader.current_epoch == 1, "Current epoch should be 1"
        if batch_idx == 10:
            break

    dataloader.load_state_dict(dataloader.state_dict())
    assert dataloader.restore

    # Verify remaining batches in the first epoch
    count = 0
    for _ in dataloader:
        assert dataloader.current_epoch == 1, "Current epoch should be 1"
        count += 1
    # The check is an exact equality, so the message states the exact expectation.
    assert count == 15, "There should be exactly 15 batches remaining in the first epoch"
    assert not dataloader.restore

    # Verify batches in the second epoch
    count = 0
    for _ in dataloader:
        assert dataloader.current_epoch == 2, "Current epoch should be 2"
        count += 1
    assert count >= 25, "There should be at least 25 batches in the second epoch"

    # TODO: Add more conditions to check the state of the dataloader
Loading