Skip to content

Commit

Permalink
readme updated and minor bugs fixed for 'resume streamingDataloader f…
Browse files Browse the repository at this point in the history
…rom a checkpoint'
  • Loading branch information
deependujha committed Jul 9, 2024
1 parent c4c9117 commit d259cd9
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 2 deletions.
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,48 @@ from litdata import StreamingDataset
dataset = StreamingDataset(input_dir="local:/data/shared-drive/some-data")
```

</details>

<details>
<summary> ✅ Optimize dataset in distributed environment</summary>
&nbsp;

Lightning can distribute large workloads across hundreds of machines in parallel. This can reduce the time to complete a data processing task from weeks to minutes by scaling to enough machines.

To apply the `optimize` operator across multiple machines, simply provide the `num_nodes` and `machine` arguments to it as follows:

```python
import os
from litdata import optimize, Machine

def compress(index):
return (index, index ** 2)

optimize(
fn=compress,
inputs=list(range(100)),
num_workers=2,
output_dir="my_output",
chunk_bytes="64MB",
num_nodes=2,
machine=Machine.DATA_PREP, # You can select between dozens of optimized machines
)
```

If the `output_dir` is a local path, the optimized dataset will be present in: `/teamspace/jobs/{job_name}/nodes-0/my_output`. Otherwise, it will be stored in the specified `output_dir`.

Read the optimized dataset:

```python
from litdata import StreamingDataset

output_dir = "/teamspace/jobs/litdata-optimize-2024-07-08/nodes-0/my_output"

dataset = StreamingDataset(output_dir)

print(dataset[:])
```

</details>

&nbsp;
Expand Down
2 changes: 1 addition & 1 deletion src/litdata/streaming/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def __iter__(self) -> Iterator[Any]:
num_samples_yielded = None

if self._num_samples_yielded is not None and worker_env.rank in self._num_samples_yielded:
num_samples_yielded = self._num_samples_yielded[worker_env.rank]
num_samples_yielded = self._num_samples_yielded.get(worker_env.rank, 0)

self._iterator = _CombinedDatasetIterator(
self._datasets,
Expand Down
2 changes: 1 addition & 1 deletion src/litdata/streaming/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
self._num_samples_yielded_combined = obj["num_samples_yielded"]

# Used to restart on the next DataLoader worker from the previous run.
self._latest_worker_idx = obj["latest_worker_idx"] + 1
self._latest_worker_idx = (obj["latest_worker_idx"] + 1) % (self.num_workers if self.num_workers > 0 else 1)
self._worker_idx_iter = iter(self._worker_idx)
for _ in range(self._latest_worker_idx):
next(self._worker_idx_iter)
Expand Down
2 changes: 2 additions & 0 deletions src/litdata/streaming/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if state_dict:
# the state is restored within the workers
self._state_dict = state_dict
self.subsampled_files = state_dict["subsampled_files"]
self.region_of_interest = state_dict["region_of_interest"]

def _validate_state_dict(self) -> None:
assert self._state_dict
Expand Down

0 comments on commit d259cd9

Please sign in to comment.