diff --git a/src/litdata/streaming/shuffle.py b/src/litdata/streaming/shuffle.py index 357fa985..fb5ed25b 100644 --- a/src/litdata/streaming/shuffle.py +++ b/src/litdata/streaming/shuffle.py @@ -35,7 +35,7 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool): @lru_cache(maxsize=10) def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int) -> int: - _, workers_intervals = self.get_chunks_and_intervals_per_workers( # TODO: rename + _, workers_intervals = self.get_chunks_and_intervals_per_workers( distributed_env, num_workers, batch_size, current_epoch ) worker_start = distributed_env.global_rank * num_workers diff --git a/src/litdata/utilities/shuffle.py b/src/litdata/utilities/shuffle.py index cbd9e662..4684838a 100644 --- a/src/litdata/utilities/shuffle.py +++ b/src/litdata/utilities/shuffle.py @@ -52,9 +52,7 @@ def _associate_chunks_and_internals_to_workers( batch_size: int = 1, ) -> Tuple[List[List[int]], List[Any]]: num_items = sum([(interval[2] - interval[1]) for interval in chunk_intervals]) - print(f"{num_items=}") world_size = distributed_env.world_size * num_workers - print("WORLD_SIZE=", world_size) num_items_per_workers: List[int] = [ num_items // world_size + num_items % world_size if rank == world_size - 1 and not drop_last @@ -65,7 +63,6 @@ def _associate_chunks_and_internals_to_workers( ratio = batch_size num_items_per_workers = [ratio * int(item // ratio) for item in num_items_per_workers] - print(f"{num_items_per_workers=}") chunks_per_workers: List[List[int]] = [[] for _ in range(world_size)] intervals_per_workers: List[List[List[int]]] = [[] for _ in range(world_size)] @@ -105,7 +102,6 @@ def _associate_chunks_and_internals_to_workers( num_items_per_workers[rank] -= items_in_chunk break - # print(drop_last, batch_size, num_workers, [sum(interval[2] - interval[1] for interval in intervals) for intervals in intervals_per_workers]) return chunks_per_workers, intervals_per_workers