Skip to content

Commit

Permalink
clean up shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Jul 16, 2024
1 parent 4130f50 commit 0053486
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 5 deletions.
2 changes: 1 addition & 1 deletion src/litdata/streaming/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool):

@lru_cache(maxsize=10)
def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int) -> int:
_, workers_intervals = self.get_chunks_and_intervals_per_workers( # TODO: rename
_, workers_intervals = self.get_chunks_and_intervals_per_workers(
distributed_env, num_workers, batch_size, current_epoch
)
worker_start = distributed_env.global_rank * num_workers
Expand Down
4 changes: 0 additions & 4 deletions src/litdata/utilities/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ def _associate_chunks_and_internals_to_workers(
batch_size: int = 1,
) -> Tuple[List[List[int]], List[Any]]:
num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals])
print(f"{num_items=}")
world_size = distributed_env.world_size * num_workers
print("WORLD_SIZE=", world_size)
num_items_per_workers: List[int] = [
num_items // world_size + num_items % world_size
if rank == world_size - 1 and not drop_last
Expand All @@ -65,7 +63,6 @@ def _associate_chunks_and_internals_to_workers(
ratio = batch_size
num_items_per_workers = [ratio * int(item // ratio) for item in num_items_per_workers]

print(f"{num_items_per_workers=}")
chunks_per_workers: List[List[int]] = [[] for _ in range(world_size)]
intervals_per_workers: List[List[List[int]]] = [[] for _ in range(world_size)]

Expand Down Expand Up @@ -105,7 +102,6 @@ def _associate_chunks_and_internals_to_workers(
num_items_per_workers[rank] -= items_in_chunk
break

# print(drop_last, batch_size, num_workers, [sum(interval[2] - interval[1] for interval in intervals) for intervals in intervals_per_workers])
return chunks_per_workers, intervals_per_workers


Expand Down

0 comments on commit 0053486

Please sign in to comment.