From 876a65d9e54cc0c105c10bca3fd38a5811b513da Mon Sep 17 00:00:00 2001 From: gkroiz Date: Mon, 8 Apr 2024 23:11:59 +0000 Subject: [PATCH 1/5] Add dist env detection via env vars --- src/litdata/utilities/env.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/litdata/utilities/env.py b/src/litdata/utilities/env.py index 7276a87e..874e597b 100644 --- a/src/litdata/utilities/env.py +++ b/src/litdata/utilities/env.py @@ -57,7 +57,16 @@ def detect(cls) -> "_DistributedEnv": global_rank = 0 num_nodes = 1 - if world_size is None or world_size == -1: + if os.environ["WORLD_SIZE"] is not None: + world_size = os.environ["WORLD_SIZE"] + + if os.environ["GLOBAL_RANK"] is not None: + global_rank = os.environ["GLOBAL_RANK"] + + if os.environ["NNODES"] is not None: + num_nodes = os.environ["NNODES"] + + if world_size is in [None, -1, 0]: world_size = 1 return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes) From 8aadba04f5c6d5ef52cfc03829bb786d55f7d8c4 Mon Sep 17 00:00:00 2001 From: gkroiz Date: Mon, 8 Apr 2024 23:27:51 +0000 Subject: [PATCH 2/5] minor typo fix --- src/litdata/utilities/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/utilities/env.py b/src/litdata/utilities/env.py index 874e597b..7d90d4de 100644 --- a/src/litdata/utilities/env.py +++ b/src/litdata/utilities/env.py @@ -66,7 +66,7 @@ def detect(cls) -> "_DistributedEnv": if os.environ["NNODES"] is not None: num_nodes = os.environ["NNODES"] - if world_size is in [None, -1, 0]: + if world_size in (None, -1, 0): world_size = 1 return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes) From 99f0e7b2a0dd68698dd58f771625d45fd5a4deba Mon Sep 17 00:00:00 2001 From: gkroiz Date: Tue, 9 Apr 2024 08:34:04 +0000 Subject: [PATCH 3/5] Use os.environ.get --- src/litdata/utilities/env.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/litdata/utilities/env.py 
b/src/litdata/utilities/env.py index 7d90d4de..28db9518 100644 --- a/src/litdata/utilities/env.py +++ b/src/litdata/utilities/env.py @@ -57,14 +57,14 @@ def detect(cls) -> "_DistributedEnv": global_rank = 0 num_nodes = 1 - if os.environ["WORLD_SIZE"] is not None: - world_size = os.environ["WORLD_SIZE"] + if os.environ.get("WORLD_SIZE") is not None: + world_size = int(os.environ.get("WORLD_SIZE")) - if os.environ["GLOBAL_RANK"] is not None: - global_rank = os.environ["GLOBAL_RANK"] + if os.environ.get("GLOBAL_RANK") is not None: + global_rank = int(os.environ.get("GLOBAL_RANK")) - if os.environ["NNODES"] is not None: - num_nodes = os.environ["NNODES"] + if os.environ.get("NNODES") is not None: + num_nodes = int(os.environ.get("NNODES")) if world_size in (None, -1, 0): world_size = 1 From 74d044fae53c4946202066e30451467eacf949b6 Mon Sep 17 00:00:00 2001 From: gkroiz Date: Wed, 24 Apr 2024 00:17:34 +0000 Subject: [PATCH 4/5] fix mypy issues --- src/litdata/utilities/env.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/litdata/utilities/env.py b/src/litdata/utilities/env.py index 28db9518..1005c503 100644 --- a/src/litdata/utilities/env.py +++ b/src/litdata/utilities/env.py @@ -57,18 +57,13 @@ def detect(cls) -> "_DistributedEnv": global_rank = 0 num_nodes = 1 - if os.environ.get("WORLD_SIZE") is not None: - world_size = int(os.environ.get("WORLD_SIZE")) - - if os.environ.get("GLOBAL_RANK") is not None: - global_rank = int(os.environ.get("GLOBAL_RANK")) - - if os.environ.get("NNODES") is not None: - num_nodes = int(os.environ.get("NNODES")) - - if world_size in (None, -1, 0): + if world_size is None or world_size == -1: world_size = 1 + world_size = int(os.environ.get("WORLD_SIZE", world_size)) + global_rank = int(os.environ.get("GLOBAL_RANK", global_rank)) + num_nodes = int(os.environ.get("NNODES", num_nodes)) + return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes) @classmethod From 
264f7e1bbb79b10f48573390cbbda358d0e00307 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 13 May 2024 11:11:17 +0100
Subject: [PATCH 5/5] update

---
 tests/utilities/test_env.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)
 create mode 100644 tests/utilities/test_env.py

diff --git a/tests/utilities/test_env.py b/tests/utilities/test_env.py
new file mode 100644
index 00000000..064d9380
--- /dev/null
+++ b/tests/utilities/test_env.py
@@ -0,0 +1,14 @@
+from litdata.utilities.env import _DistributedEnv
+
+
+def test_distributed_env_from_env(monkeypatch):
+    """``detect()`` reads WORLD_SIZE, GLOBAL_RANK and NNODES from the environment."""
+    # Environment values are strings; pytest warns when setenv() is given a non-str value.
+    monkeypatch.setenv("WORLD_SIZE", "2")
+    monkeypatch.setenv("GLOBAL_RANK", "1")
+    monkeypatch.setenv("NNODES", "2")
+
+    dist_env = _DistributedEnv.detect()
+    assert dist_env.world_size == 2
+    assert dist_env.global_rank == 1
+    assert dist_env.num_nodes == 2