[brief] Fixes the distributed flag when using torchrun.
[detailed]
- Due to the order of operations, we cannot rely on the distributed info
  from the core module to determine the world size when torchrun is
  used. Since the distributed package hasn't been initialized yet, we
  have to pull the world size from the environment variable directly.
marovira committed Apr 26, 2024
1 parent: b936461 · commit: d3cab05
Showing 1 changed file with 2 additions and 1 deletion.
src/helios/trainer.py (2 additions & 1 deletion)
@@ -3,6 +3,7 @@
 import dataclasses
 import enum
 import itertools
+import os
 import pathlib
 import re
 import time
@@ -686,7 +687,7 @@ def _setup_device_flags(self, use_cpu: bool | None):
         self._is_distributed = (
             len(self._gpu_ids) > 1
             if not self._is_torchrun
-            else dist.get_dist_info().world_size > 1
+            else int(os.environ["WORLD_SIZE"]) > 1
         )
 
     def _save_checkpoint(self, state: TrainingState) -> None:
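
For context on the ordering issue described in the commit message: torchrun exports WORLD_SIZE (together with RANK, LOCAL_RANK, MASTER_ADDR, and MASTER_PORT) before the training script starts, while querying the world size through the process group is only possible after init_process_group() has run. Below is a minimal sketch of that distinction using plain torch.distributed rather than helios' own dist helpers; the helper name and the fallback default of 1 for non-torchrun runs are illustrative, not part of helios.

import os

import torch.distributed as torch_dist


def world_size_before_init() -> int:
    """Report the world size without requiring an initialized process group.

    Under torchrun, WORLD_SIZE is already present in the environment when the
    script starts, so it can be read before init_process_group() is called;
    torch_dist.get_world_size() would raise at that point.
    """
    if torch_dist.is_available() and torch_dist.is_initialized():
        # Safe only once init_process_group() has completed.
        return torch_dist.get_world_size()
    # Fall back to the launcher-provided environment variable. Defaulting to 1
    # (an assumption made for this sketch) covers single-process runs that
    # were not started through torchrun.
    return int(os.environ.get("WORLD_SIZE", "1"))


is_distributed = world_size_before_init() > 1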
