Skip to content

Commit

Permalink
Add dist env detection via env vars (#95)
Browse files Browse the repository at this point in the history
* Add dist env detection via env vars

* minor typo fix

* Use os.environ.get

* fix mypy issues

* update

---------

Co-authored-by: thomas chaton <[email protected]>
  • Loading branch information
gkroiz and tchaton authored May 13, 2024
1 parent dc29634 commit 58b158a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/litdata/utilities/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def detect(cls) -> "_DistributedEnv":
if world_size is None or world_size == -1:
world_size = 1

world_size = int(os.environ.get("WORLD_SIZE", world_size))
global_rank = int(os.environ.get("GLOBAL_RANK", global_rank))
num_nodes = int(os.environ.get("NNODES", num_nodes))

return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes)

@classmethod
Expand Down
12 changes: 12 additions & 0 deletions tests/utilities/test_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from litdata.utilities.env import _DistributedEnv


def test_distributed_env_from_env(monkeypatch):
monkeypatch.setenv("WORLD_SIZE", 2)
monkeypatch.setenv("GLOBAL_RANK", 1)
monkeypatch.setenv("NNODES", 2)

dist_env = _DistributedEnv.detect()
assert dist_env.world_size == 2
assert dist_env.global_rank == 1
assert dist_env.num_nodes == 2

0 comments on commit 58b158a

Please sign in to comment.