Skip to content

Commit

Permalink
[brief] Better handling for printing in distributed training.
Browse files Browse the repository at this point in the history
[detailed]
- Ensure only rank 0 prints messages to the screen.
- Arguably I could move this function to the core.logging module, but
  I'm not entirely sure yet that I want something like that there.
  • Loading branch information
marovira committed Apr 26, 2024
1 parent d3cab05 commit e124bd0
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions src/helios/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,14 +567,13 @@ def _print_header(
"""Print the Helios header with system info to the logs."""
root_logger = logging.get_root_logger()

if self.rank == 0:
print(core.get_env_info_str())
self._print(core.get_env_info_str())

if for_training:
if chkpt_path is not None:
msg = f"Resuming training from checkpoint {str(chkpt_path)}"
root_logger.info(msg)
print(f"{msg}\n")
self._print(f"{msg}\n")
else:
root_logger.info(core.get_env_info_str())
else:
Expand All @@ -585,7 +584,7 @@ def _print_header(
else "Testing from loaded model"
)
root_logger.info(msg)
print(f"{msg}\n")
self._print(f"{msg}\n")

def _validate_flags(self):
"""Ensure that all the settings and flags are valid."""
Expand Down Expand Up @@ -1019,3 +1018,18 @@ def _validate(self, val_cycle: int) -> None:

if self._is_distributed:
td.barrier()

def _print(self, *args: typing.Any, **kwargs: typing.Any) -> None:
"""
Wrap Python's print function for distributed training.
Specifically, this function will ensure that only rank 0 prints messages to the
screen. All other ranks will do nothing.
Args:
args: named arguments for print.
kwargs: keyword arguments for print.
"""
if self.rank != 0:
return
print(*args, **kwargs)

0 comments on commit e124bd0

Please sign in to comment.