From d3cab050e20c38f361d938bf75d317b0178cc3b8 Mon Sep 17 00:00:00 2001
From: "Mauricio A. Rovira Galvez" <8482308+marovira@users.noreply.github.com>
Date: Fri, 26 Apr 2024 12:35:56 -0700
Subject: [PATCH] [brief] Fixes the distributed flag when using torchrun.

[detailed]
- Due to the order of operations, we cannot rely on the distributed info
  from the core module to determine the world size when torchrun is used.
  Since the distributed package hasn't been initialized yet, we have to
  pull the world size from the environment variable directly.
---
 src/helios/trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/helios/trainer.py b/src/helios/trainer.py
index edc70d8..21ede1e 100644
--- a/src/helios/trainer.py
+++ b/src/helios/trainer.py
@@ -3,6 +3,7 @@
 import dataclasses
 import enum
 import itertools
+import os
 import pathlib
 import re
 import time
@@ -686,7 +687,7 @@ def _setup_device_flags(self, use_cpu: bool | None):
         self._is_distributed = (
             len(self._gpu_ids) > 1
             if not self._is_torchrun
-            else dist.get_dist_info().world_size > 1
+            else int(os.environ["WORLD_SIZE"]) > 1
         )

     def _save_checkpoint(self, state: TrainingState) -> None:
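
Note (not part of the patch): a minimal sketch of the distinction the commit message describes. torchrun exports WORLD_SIZE (along with RANK and LOCAL_RANK) in the environment before the training script starts, whereas torch.distributed can only report a world size after init_process_group() has been called. The helper name is_distributed_run below is hypothetical and only illustrates the order-of-operations issue under that assumption.

    import os

    import torch.distributed as dist


    def is_distributed_run() -> bool:
        """Return True when running with more than one process."""
        if dist.is_available() and dist.is_initialized():
            # Only valid after init_process_group() has been called.
            return dist.get_world_size() > 1
        # Before initialization, fall back to the variable torchrun sets.
        return int(os.environ.get("WORLD_SIZE", "1")) > 1

Because _setup_device_flags runs before the process group is initialized, only the environment-variable branch is usable at that point, which is what the patch switches to.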