Skip to content

Commit

Permalink
Fix set_drop_last and test
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Jul 17, 2024
1 parent 0053486 commit 3e94cd6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/litdata/streaming/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def __init__(
profile_dir: Optional[str] = None,
prefetch_factor: Optional[int] = None,
shuffle: Optional[bool] = None,
drop_last: Optional[bool] = False,
drop_last: Optional[bool] = None,
collate_fn: Optional[Callable] = None,
**kwargs: Any,
) -> None: # pyright: ignore
Expand Down
44 changes: 27 additions & 17 deletions tests/streaming/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,36 +600,46 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk
dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)

L = len(dataset)
assert len(dataset) == L
assert L == 20

for i in range(L):
sequence = dataset[i]
assert sequence[0].item() == i * block_size
assert sequence[-1].item() == (i + 1) * block_size - 1

monkeypatch.setenv("WORLD_SIZE", "2")
monkeypatch.setenv("GLOBAL_RANK", "0")
monkeypatch.setenv("NNODES", "1")
dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
dataloader = StreamingDataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)
assert dataset.drop_last # in distributed setting, this is forced automatically

dataset.distributed_env = _DistributedEnv(2, 0, 1)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)

assert len(dataloader) == 5

expected = [[0, 10], [20, 30], [40, 50], [60, 70], [80, 90]]
# L = 20 items, world size 2, num workers 2
# L / (2 * 2) = 5 items per worker
# drop last -> 4 items per worker
# batch size = 2 -> 2 batches per worker; 2 workers -> len(dataloader) = 4
assert len(dataloader) == 4

expected = [[0, 10], [40, 50], [20, 30], [60, 70]]
returned = []
for batch_idx, batch in enumerate(dataloader):
x = [batch[0][0].item(), batch[1][0].item()]
# breakpoint()
# assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
returned.append(batch[:, 0].tolist())
assert returned == expected

dataset.distributed_env = _DistributedEnv(2, 1, 1)
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

assert len(dataloader) == 5

expected = [[100, 110], [120, 130], [140, 150], [160, 170], [180, 190]]
monkeypatch.setenv("WORLD_SIZE", "2")
monkeypatch.setenv("GLOBAL_RANK", "1")
monkeypatch.setenv("NNODES", "1")
dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
dataloader = StreamingDataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)
assert dataset.drop_last # in distributed setting, this is forced automatically

assert len(dataloader) == 4

expected = [[80, 90], [120, 130], [100, 110], [140, 150]]
returned = []
for batch_idx, batch in enumerate(dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
returned.append(batch[:, 0].tolist())
assert returned == expected


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
Expand Down

0 comments on commit 3e94cd6

Please sign in to comment.