Uneven number of batches returned across ranks in StreamingDataset/DataLoader #233
awaelchli added the **bug** (Something isn't working) and **help wanted** (Extra attention is needed) labels on Jul 15, 2024.
Super interesting. When I run this code, I am getting this on the first run:

```
Rank 0 finished. Batches fetched: 16, length: 16
Rank 2 finished. Batches fetched: 16, length: 16
Rank 3 finished. Batches fetched: 16, length: 16
Rank 1 finished. Batches fetched: 16, length: 16
```

And this on the second one:

```
Rank 0 finished. Batches fetched: 16, length: 16
Rank 2 finished. Batches fetched: 17, length: 16
Rank 3 finished. Batches fetched: 17, length: 16
Rank 1 finished. Batches fetched: 17, length: 16
```

Or even:

```
Rank 3 finished. Batches fetched: 17, length: 17
100%|████████████████████████████████████████| 17/17 [00:07<00:00, 2.31it/s]
Rank 0 finished. Batches fetched: 17, length: 17
Rank 2 finished. Batches fetched: 17, length: 17
Rank 1 finished. Batches fetched: 17, length: 17
```

wtf ...
```
Rank 1 finished. Batches fetched: 18, length: 17
Rank 3 finished. Batches fetched: 18, length: 17
100%|████████████████████████████████████████| 17/17 [00:07<00:00, 2.30it/s]
Rank 0 finished. Batches fetched: 17, length: 17
Rank 2 finished. Batches fetched: 18, length: 17
```
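The mismatch in the logs above can be checked programmatically rather than by eyeballing the per-rank prints. Below is a minimal sketch of such a check; `check_even_batches` is a hypothetical helper (not part of litdata), and in a real multi-process run the per-rank counts would be gathered first, e.g. with `torch.distributed.all_gather_object`.

```python
def check_even_batches(counts_per_rank):
    """Return True if every rank fetched the same number of batches.

    counts_per_rank: list of ints, one per rank (gathered from all
    processes, e.g. via torch.distributed.all_gather_object).
    """
    return len(set(counts_per_rank)) <= 1

# First run above: every rank fetched 16 batches -> even.
print(check_even_batches([16, 16, 16, 16]))   # True
# Second run: rank 0 fetched 16 while the others fetched 17 -> uneven.
print(check_even_batches([16, 17, 17, 17]))   # False
```

With `drop_last=True` on both the dataset and the dataloader, this check should always pass; the runs above show it does not.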
@tchaton Thanks for taking a look already. The linked PR seems to solve the problem with the toy example I have here, but in the real dataset I still saw the issue occur. Here is the example with the dataset I used:

```python
import torch
import os
from lightning.fabric import Fabric
from tqdm import tqdm
from functools import partial
import tiktoken
from litdata import StreamingDataset, StreamingDataLoader, TokensLoader
from datasets import Dataset, load_dataset
from litdata import optimize


def prepare_data():
    if os.path.isdir("data/openwebtext-debug/optimized"):
        return

    dataset = load_dataset("openwebtext", num_proc=(os.cpu_count() // 2), trust_remote_code=True)
    split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
    split_dataset["val"] = split_dataset.pop("test")
    tokenizer = tiktoken.get_encoding("gpt2")

    def tokenize(data: Dataset, index: int):
        ids = tokenizer.encode_ordinary(data[index]["text"])
        ids.append(tokenizer.eot_token)
        yield torch.tensor(ids, dtype=torch.long)

    # optimize(
    #     fn=partial(tokenize, split_dataset["train"]),
    #     inputs=list(range(len(split_dataset["train"]))),
    #     output_dir="data/openwebtext-debug/optimized/train",
    #     num_workers=64,
    #     chunk_bytes="200MB",
    # )
    optimize(
        fn=partial(tokenize, split_dataset["val"]),
        inputs=list(range(len(split_dataset["val"]))),
        output_dir="data/openwebtext-debug/optimized/val",
        num_workers=8,
        chunk_bytes="200MB",
    )


def get_val_dataloader():
    val_dataset = StreamingDataset(
        input_dir="data/openwebtext-debug/optimized/val",
        item_loader=TokensLoader(block_size=1024),
        shuffle=True,
        drop_last=True,
    )
    val_dataloader = StreamingDataLoader(
        val_dataset, batch_size=12, pin_memory=True, num_workers=8, drop_last=True
    )
    return val_dataloader


def main():
    fabric = Fabric(accelerator="cpu", devices=2)
    fabric.launch()
    if fabric.global_rank == 0:
        prepare_data()
    fabric.barrier()
    val_dataloader = get_val_dataloader()
    fabric.barrier()

    batches_fetched = 0
    monitor = tqdm if fabric.global_rank == 0 else lambda x: x
    for _ in monitor(val_dataloader):
        batches_fetched += 1
    print(f"Rank {fabric.global_rank} finished. Batches fetched: {batches_fetched}, length: {len(val_dataloader)}")
    fabric.barrier()


if __name__ == "__main__":
    main()
```

Output:
🐛 Bug
The StreamingDataLoader/Dataset returns an uneven number of batches across the ranks.
Example:
Output:
As you can see, counting the batches on each rank shows uneven amounts. However, `len(dataloader)` seems to return the correct value.

The docs on `drop_last` state:

At the moment, this doesn't seem to work as described.
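For reference, here is a minimal sketch of the arithmetic that `drop_last=True` implies: if the dataset drops the remainder when splitting items across ranks, and the dataloader drops the final partial batch, every rank should yield the same number of batches. The concrete numbers below are hypothetical, chosen only to match the batch size of 12 and 4 ranks from the repro.

```python
def expected_batches_per_rank(total_items, world_size, batch_size):
    """Number of full batches each rank should yield when drop_last=True
    is applied both when sharding items across ranks and when batching."""
    items_per_rank = total_items // world_size   # dataset-level drop_last
    return items_per_rank // batch_size          # dataloader-level drop_last

# Hypothetical: 817 blocks, 4 ranks, batch size 12 -> 17 batches on EVERY rank.
print(expected_batches_per_rank(817, 4, 12))  # 17
```

Under this contract, no rank should ever report a fetched-batch count different from `len(dataloader)` or from any other rank, which is exactly what the logs above violate.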