diff --git a/src/litdata/utilities/env.py b/src/litdata/utilities/env.py index 7d90d4de..28db9518 100644 --- a/src/litdata/utilities/env.py +++ b/src/litdata/utilities/env.py @@ -57,14 +57,14 @@ def detect(cls) -> "_DistributedEnv": global_rank = 0 num_nodes = 1 - if os.environ["WORLD_SIZE"] is not None: - world_size = os.environ["WORLD_SIZE"] + if os.environ.get("WORLD_SIZE") is not None: + world_size = int(os.environ.get("WORLD_SIZE")) - if os.environ["GLOBAL_RANK"] is not None: - global_rank = os.environ["GLOBAL_RANK"] + if os.environ.get("GLOBAL_RANK") is not None: + global_rank = int(os.environ.get("GLOBAL_RANK")) - if os.environ["NNODES"] is not None: - num_nodes = os.environ["NNODES"] + if os.environ.get("NNODES") is not None: + num_nodes = int(os.environ.get("NNODES")) if world_size in (None, -1, 0): world_size = 1