
Uneven number of batches returned across ranks in StreamingDataset/DataLoader #233

Closed · awaelchli opened this issue on Jul 15, 2024 · 3 comments · Fixed by #237

Labels: bug (Something isn't working), help wanted (Extra attention is needed)

Comments

awaelchli (Contributor) commented on Jul 15, 2024

🐛 Bug

The StreamingDataLoader/Dataset returns an uneven number of batches across the ranks.

Example:

import torch
from lightning.fabric import Fabric
from tqdm import tqdm
from litdata import optimize, StreamingDataLoader, StreamingDataset, TokensLoader


def tokenize(item):
    size = torch.randint(10, 20, size=(1, )).item()
    yield torch.randint(0, 1000, size=(size, ))


def get_dataloader():    
    train_dataset = StreamingDataset(
        input_dir="data/fake-data",
        item_loader=TokensLoader(block_size=10),
        # shuffle=True,
        # drop_last=True,
    )
    train_dataloader = StreamingDataLoader(
        train_dataset, 
        batch_size=2, 
        num_workers=1,
        # drop_last seems to have an influence here: 
        drop_last=True
    )
    return train_dataloader


def main():
    torch.manual_seed(42)
    fabric = Fabric(accelerator="cpu", devices=4)
    fabric.launch()

    if fabric.global_rank == 0:
        optimize(
            fn=tokenize,
            inputs=list(range(100)),
            output_dir="data/fake-data",
            num_workers=2,
            chunk_size=100,
            mode="overwrite"
        )
    fabric.barrier()
    
    train_dataloader = get_dataloader()

    # print(f"Rank {fabric.global_rank}: Length = {len(train_dataloader)}")
    fabric.barrier()
    
    print("Start fetching")
    monitor = tqdm if fabric.global_rank == 0 else lambda x: x
    
    batches_fetched = 0
    for _ in monitor(train_dataloader):
        batches_fetched += 1
        pass

    print(f"Rank {fabric.global_rank} finished. Batches fetched: {batches_fetched}, length: {len(train_dataloader)}")
    fabric.barrier()


if __name__ == "__main__":
    main()

Output:

Rank 0 finished. Batches fetched: 16, length: 16
Rank 1 finished. Batches fetched: 17, length: 16
Rank 3 finished. Batches fetched: 17, length: 16
Rank 2 finished. Batches fetched: 17, length: 16

As you can see, the number of batches actually fetched differs across ranks, even though len(train_dataloader) reports the same value (16) on every rank.

The docs on drop_last state:

drop_last: If `True`, drops the last items to ensure that
	all processes/workers return the same amount of data.
	The argument `drop_last` is set to `True` in a distributed setting
	and `False` otherwise.

At the moment, this doesn't seem to work as described.
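
For what it's worth, a minimal sanity check can make the mismatch explicit by gathering each rank's fetched-batch count. This is just a sketch, not part of litdata: it assumes the Fabric setup from the repro above and uses Fabric.all_gather, and check_even_batches is a hypothetical helper.

import torch
from lightning.fabric import Fabric


def check_even_batches(fabric: Fabric, batches_fetched: int) -> None:
    # Gather every rank's batch count onto all ranks and compare them.
    counts = fabric.all_gather(torch.tensor(batches_fetched)).tolist()
    if fabric.global_rank == 0:
        print(f"Batches fetched per rank: {counts}")
        if len(set(counts)) != 1:
            print("Mismatch: ranks fetched different numbers of batches")

Calling this right after the fetch loop in the repro above would print something like [16, 17, 17, 17] for the run shown, while len(train_dataloader) is 16 everywhere.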

awaelchli added the bug and help wanted labels on Jul 15, 2024
tchaton (Collaborator) commented on Jul 16, 2024

Super interesting. When I run this code, I get this on the first run:

Rank 0 finished. Batches fetched: 16, length: 16
Rank 2 finished. Batches fetched: 16, length: 16
Rank 3 finished. Batches fetched: 16, length: 16
Rank 1 finished. Batches fetched: 16, length: 16

And this on the second one:

Rank 0 finished. Batches fetched: 16, length: 16
Rank 2 finished. Batches fetched: 17, length: 16
Rank 3 finished. Batches fetched: 17, length: 16
Rank 1 finished. Batches fetched: 17, length: 16

Or even:

Rank 3 finished. Batches fetched: 17, length: 17
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:07<00:00,  2.31it/s]
Rank 0 finished. Batches fetched: 17, length: 17
Rank 2 finished. Batches fetched: 17, length: 17
Rank 1 finished. Batches fetched: 17, length: 17

wtf ...

tchaton (Collaborator) commented on Jul 16, 2024

Rank 1 finished. Batches fetched: 18, length: 17
Rank 3 finished. Batches fetched: 18, length: 17
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:07<00:00,  2.30it/s]
Rank 0 finished. Batches fetched: 17, length: 17
Rank 2 finished. Batches fetched: 18, length: 17

awaelchli (Contributor, Author) commented on Jul 16, 2024

@tchaton Thanks for taking a look already. The linked PR seems to solve the problem for the toy example above, but with the real dataset I still see the issue. Here is the example with the dataset I used:

import torch
import os
from lightning.fabric import Fabric
from tqdm import tqdm
from functools import partial
import tiktoken
from litdata import StreamingDataset, StreamingDataLoader, TokensLoader
from datasets import Dataset, load_dataset
from litdata import optimize


def prepare_data():
    if os.path.isdir("data/openwebtext-debug/optimized"):
        return

    dataset = load_dataset("openwebtext", num_proc=(os.cpu_count() // 2), trust_remote_code=True)
    split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
    split_dataset["val"] = split_dataset.pop("test")
    tokenizer = tiktoken.get_encoding("gpt2")

    def tokenize(data: Dataset, index: int):
        ids = tokenizer.encode_ordinary(data[index]["text"])
        ids.append(tokenizer.eot_token)
        yield torch.tensor(ids, dtype=torch.long)

    # optimize(
    #     fn=partial(tokenize, split_dataset["train"]),
    #     inputs=list(range(len(split_dataset["train"]))),
    #     output_dir="data/openwebtext-debug/optimized/train",
    #     num_workers=64,
    #     chunk_bytes="200MB",
    # )
    optimize(
        fn=partial(tokenize, split_dataset["val"]),
        inputs=list(range(len(split_dataset["val"]))),
        output_dir="data/openwebtext-debug/optimized/val",
        num_workers=8,
        chunk_bytes="200MB",
    )


def get_val_dataloader():
    val_dataset = StreamingDataset(
        input_dir="data/openwebtext-debug/optimized/val",
        item_loader=TokensLoader(block_size=1024),
        shuffle=True,
        drop_last=True,
    )
    val_dataloader = StreamingDataLoader(
        val_dataset, batch_size=12, pin_memory=True, num_workers=8, drop_last=True
    )
    return val_dataloader


def main():
    fabric = Fabric(accelerator="cpu", devices=2)
    fabric.launch()

    if fabric.global_rank == 0:
        prepare_data()
    fabric.barrier()
    val_dataloader = get_val_dataloader()
    fabric.barrier()
    batches_fetched = 0
    monitor = tqdm if fabric.global_rank == 0 else lambda x: x
    for _ in monitor(val_dataloader):
        batches_fetched += 1
        pass

    print(f"Rank {fabric.global_rank} finished. Batches fetched: {batches_fetched}, length: {len(val_dataloader)}")
    fabric.barrier()


if __name__ == "__main__":
    main()

Output:

Rank 1 finished. Batches fetched: 177, length: 176                                                                                                   
Rank 0 finished. Batches fetched: 179, length: 176
