Feat: Updates readme and a few nitpicks #223

Merged
42 changes: 42 additions & 0 deletions README.md
@@ -520,6 +520,48 @@ from litdata import StreamingDataset
dataset = StreamingDataset(input_dir="local:/data/shared-drive/some-data")
```

</details>

<details>
<summary> ✅ Optimize dataset in distributed environment</summary>
&nbsp;

Lightning can distribute large workloads across hundreds of machines in parallel, reducing the time to complete a data processing task from weeks to minutes.

To apply the `optimize` operator across multiple machines, provide the `num_nodes` and `machine` arguments as follows:

```python
import os
from litdata import optimize, Machine

def compress(index):
    return (index, index ** 2)

optimize(
    fn=compress,
    inputs=list(range(100)),
    num_workers=2,
    output_dir="my_output",
    chunk_bytes="64MB",
    num_nodes=2,
    machine=Machine.DATA_PREP,  # You can select between dozens of optimized machines
)
```

If the `output_dir` is a local path, the optimized dataset will be present in: `/teamspace/jobs/{job_name}/nodes-0/my_output`. Otherwise, it will be stored in the specified `output_dir`.

Read the optimized dataset:

```python
from litdata import StreamingDataset

output_dir = "/teamspace/jobs/litdata-optimize-2024-07-08/nodes.0/my_output"

dataset = StreamingDataset(output_dir)

print(dataset[:])
```

</details>

&nbsp;
2 changes: 1 addition & 1 deletion src/litdata/streaming/combined.py
@@ -139,7 +139,7 @@ def __iter__(self) -> Iterator[Any]:
         num_samples_yielded = None

         if self._num_samples_yielded is not None and worker_env.rank in self._num_samples_yielded:
-            num_samples_yielded = self._num_samples_yielded[worker_env.rank]
+            num_samples_yielded = self._num_samples_yielded.get(worker_env.rank, 0)

         self._iterator = _CombinedDatasetIterator(
             self._datasets,
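
A note on this change: plain `dict` indexing raises `KeyError` when the key is absent, whereas `.get(key, default)` returns the fallback instead. The enclosing `if` already checks `worker_env.rank in self._num_samples_yielded`, so `.get(worker_env.rank, 0)` acts as an extra defensive guard. A minimal standalone sketch (illustrative only, not litdata code):

```python
# Illustrative only: why .get with a default is the safer lookup when a
# worker rank may be missing from the per-rank counters.
num_samples_yielded = {0: 128, 1: 96}  # hypothetical rank -> samples yielded

missing_rank = 2
# num_samples_yielded[missing_rank]  # would raise KeyError
count = num_samples_yielded.get(missing_rank, 0)  # falls back to 0
print(count)  # prints: 0
```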
5 changes: 4 additions & 1 deletion tests/streaming/test_dataset.py
@@ -807,7 +807,7 @@ def _get_simulated_s3_dataloader(cache_dir, data_dir):
     return StreamingDataLoader(dataset, batch_size=2, num_workers=1)


-@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on Windows and macOS")
+@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="Not tested on Windows and macOS")
 @mock.patch.dict(os.environ, {}, clear=True)
 def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
     """This test is constructed to test resuming from a chunk past the first chunk, when subsequent chunks don't have
@@ -830,6 +830,7 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
     train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
     batches_to_fetch = 16
     batch_to_resume_from = None
+    dataloader_state = None
     for i, batch in enumerate(train_dataloader):
         if i == batches_to_fetch:
             dataloader_state = train_dataloader.state_dict()
@@ -840,6 +841,8 @@ def test_dataset_resume_on_future_chunks(tmpdir, monkeypatch):
     shutil.rmtree(s3_cache_dir)
     os.mkdir(s3_cache_dir)
     train_dataloader = _get_simulated_s3_dataloader(s3_cache_dir, data_dir)
+    assert dataloader_state is not None
+    assert batch_to_resume_from is not None
     train_dataloader.load_state_dict(dataloader_state)
     # The next batch after resuming must match what we should have gotten next in the initial loop
     assert torch.equal(next(iter(train_dataloader)), batch_to_resume_from)
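
For readers unfamiliar with the pattern this test exercises, here is a minimal sketch of saving and restoring dataloader progress with `state_dict`/`load_state_dict` on `StreamingDataLoader`. The dataset path and batch count are placeholders, not the test's fixtures:

```python
# Minimal sketch of the mid-epoch checkpoint/resume pattern tested above.
from litdata import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset("my_output")  # placeholder: any optimized dataset dir
dataloader = StreamingDataLoader(dataset, batch_size=2, num_workers=1)

state = None
for i, batch in enumerate(dataloader):
    if i == 16:  # stop partway through the epoch
        state = dataloader.state_dict()  # capture progress
        break

# Later (e.g. after a restart), rebuild the dataloader and resume.
dataloader = StreamingDataLoader(dataset, batch_size=2, num_workers=1)
assert state is not None
dataloader.load_state_dict(state)
resumed_batch = next(iter(dataloader))  # continues from the saved position
```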