diff --git a/src/helios/trainer.py b/src/helios/trainer.py index edc70d8..21ede1e 100644 --- a/src/helios/trainer.py +++ b/src/helios/trainer.py @@ -3,6 +3,7 @@ import dataclasses import enum import itertools +import os import pathlib import re import time @@ -686,7 +687,7 @@ def _setup_device_flags(self, use_cpu: bool | None): self._is_distributed = ( len(self._gpu_ids) > 1 if not self._is_torchrun - else dist.get_dist_info().world_size > 1 + else int(os.environ["WORLD_SIZE"]) > 1 ) def _save_checkpoint(self, state: TrainingState) -> None: