clean up dataset.py
awaelchli committed Jul 16, 2024
1 parent c3edbb4 commit 99ca280
Showing 2 changed files with 9 additions and 55 deletions.
63 changes: 9 additions & 54 deletions src/litdata/streaming/dataset.py
@@ -210,20 +210,20 @@ def __iter__(self) -> "StreamingDataset":
)

worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
# print(f"{worker_rank=}, len: {len(workers_chunks)}")
worker_chunks = workers_chunks[worker_rank]
worker_intervals = workers_intervals[worker_rank]
self.worker_chunks = workers_chunks[worker_rank]
self.worker_intervals = workers_intervals[worker_rank]

# The max number of samples to return from `__next__` (in worker)
self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)

# Handle restart
if self._state_dict:
# breakpoint()
self._resume(workers_chunks, workers_intervals)
else:
# TODO: Reimplement this logic
# Find the chunks shared across multiple ranks.
# For each shared chunk, find the rank to use the chunk last and prevent deletion
# for the other ranks.
# TODO better name for worker_start end
# TODO: reimplement this logic
# worker_start = self.distributed_env.global_rank * self.num_workers
# worker_end = worker_start + self.num_workers
# chunks_indexes_skip_deletion = _find_chunks_per_ranks_on_which_to_skip_deletion(
@@ -234,9 +234,6 @@ def __iter__(self) -> "StreamingDataset":
# self.distributed_env.global_rank
# ]

self.worker_chunks = worker_chunks
self.worker_intervals = worker_intervals

self.num_chunks = len(self.worker_chunks)
self.current_indexes = []
self.chunk_index = 0
@@ -266,14 +263,10 @@ def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> No

# replay sampling from each worker / chunks using the batch size
indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers)

# print(f"indexes1 = {indexes}")

# TODO: Change _replay_chunks_sampling to accept a list
chunks_index, indexes = _replay_chunks_sampling(
{i: workers_intervals[i] for i in range(worker_start, worker_end)}, indexes
workers_intervals={i: workers_intervals[i] for i in range(worker_start, worker_end)},
indexes=indexes,
)
# print(f"{indexes=}, {chunks_index=}")

# select the chunks and intervals associated to this worker
worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
@@ -295,7 +288,6 @@ def _resume(self, workers_chunks: List[int], workers_intervals: List[Any]) -> No
current_indexes = current_indexes[indexes[worker_local_rank] :]
self.current_indexes = current_indexes

# print(f"currentindexes = {current_indexes}")
self.global_index = indexes[worker_local_rank]

# bump the chunk_index
@@ -317,16 +309,10 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:

def __next__(self) -> Any:
# Prevent to create more batch on a given process
# print(torch.distributed.get_rank(), self.global_index, len(self), self.stop_length)
worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)
# print(f"{worker_rank}, {self.global_index=}, {stop_length=}")
# TODO: This is stopping too early, length is not correct
if self.global_index >= stop_length:
if self.global_index >= self.stop_length:
self.current_epoch += 1
raise StopIteration

# print(f"{self.num_chunks=}")
# Lazily re-populate the interval to reduce memory usage.
if len(self.current_indexes) == 0:
if self.chunk_index == self.num_chunks:
@@ -487,37 +473,6 @@ def is_integer(value: str) -> bool:
except Exception:
return False


# TODO: remove
def _associate_chunks_to_workers(
worker_env: _WorkerEnv, chunks_replica: List[int], intervals_replica: List[Any]
) -> Any:
workers_chunks = {}
workers_intervals = {}

for worker_idx in range(worker_env.world_size):
worker_chunks = []
worker_intervals = []
for i, (chunk_index, chunk_interval) in enumerate(zip(chunks_replica, intervals_replica)):
if i % worker_env.world_size != worker_idx:
continue

worker_chunks.append(chunk_index)
worker_intervals.append(chunk_interval)

workers_chunks[worker_idx] = worker_chunks
workers_intervals[worker_idx] = worker_intervals

# print(
# "associate",
# [
# sum(interval[2] - interval[1] for interval in intervals)
# for worker_id, intervals in workers_intervals.items()
# ],
# )
return workers_chunks, workers_intervals


def _replay_sampling(num_samples_yielded: int, batch_size: int, num_workers: int) -> Dict[int, int]:
"""This function replays the sampling from the dataloader."""
divisible_num_batches_yielded = num_samples_yielded // (num_workers * batch_size)
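
For orientation, here is a rough, self-contained sketch of the replay idea behind `_replay_sampling`, reconstructed only from the docstring and the first line shown above: samples already yielded are attributed to the `num_workers` dataloader workers round-robin, one `batch_size`-sized batch at a time. The helper name `replay_sampling_sketch` and the remainder handling are assumptions, not litdata's exact code.

def replay_sampling_sketch(num_samples_yielded: int, batch_size: int, num_workers: int) -> dict:
    # Full rounds: every worker has already contributed the same number of complete batches.
    full_rounds = num_samples_yielded // (num_workers * batch_size)
    indexes = {i: full_rounds * batch_size for i in range(num_workers)}
    # Spread the leftover samples batch by batch, starting from worker 0 (assumed order).
    remaining = num_samples_yielded - full_rounds * num_workers * batch_size
    for worker_idx in range(num_workers):
        if remaining <= 0:
            break
        taken = min(batch_size, remaining)
        indexes[worker_idx] += taken
        remaining -= taken
    return indexes

For example, replay_sampling_sketch(10, batch_size=2, num_workers=3) gives {0: 4, 1: 4, 2: 2}: each worker accounts for one full round of 2 samples, and the 4 leftover samples go to workers 0 and 1.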
1 change: 0 additions & 1 deletion tests/streaming/test_dataset.py
@@ -32,7 +32,6 @@
_INDEX_FILENAME,
Dir,
StreamingDataset,
_associate_chunks_to_workers,
_replay_chunks_sampling,
_replay_sampling,
)
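
Stepping back, the main functional change in this commit is that the per-worker sample budget is now stored once as `self.stop_length` in `__iter__` (instead of being recomputed on every `__next__` call), alongside `self.worker_chunks` / `self.worker_intervals`. Below is a minimal sketch of that computation, assuming an interval layout of `[*, begin, end, *]` so that `interval[2] - interval[1]` counts the samples the worker reads from one chunk; the concrete numbers are made up.

# Hypothetical intervals for one worker: two chunks covering samples [0, 50) and [50, 120).
worker_intervals = [[0, 0, 50, 50], [50, 50, 120, 120]]

# Same expression as in the diff above: total number of samples this worker may yield.
stop_length = sum(interval[2] - interval[1] for interval in worker_intervals)
print(stop_length)  # 120 -> __next__ raises StopIteration once 120 samples have been yielded

The global worker id used to index into `workers_chunks` / `workers_intervals` follows the formula shown in the diff, `distributed_env.global_rank * worker_env.world_size + worker_env.rank`; for instance, rank 1 with 4 dataloader workers and local worker 2 maps to worker 6.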
