From e124bd0ee3ea9c9c93029a266691795f6bd68cc0 Mon Sep 17 00:00:00 2001
From: "Mauricio A. Rovira Galvez" <8482308+marovira@users.noreply.github.com>
Date: Fri, 26 Apr 2024 15:52:14 -0700
Subject: [PATCH] [brief] Better handling for printing in distributed
 training.

[detailed]
- Ensure only rank 0 prints messages to the screen.
- Arguably I could move this function to the core.logging module, but I'm
  not entirely sure yet that I want something like that there.
---
 src/helios/trainer.py | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/src/helios/trainer.py b/src/helios/trainer.py
index 21ede1e..1e70667 100644
--- a/src/helios/trainer.py
+++ b/src/helios/trainer.py
@@ -567,14 +567,13 @@ def _print_header(
         """Print the Helios header with system info to the logs."""
         root_logger = logging.get_root_logger()
 
-        if self.rank == 0:
-            print(core.get_env_info_str())
+        self._print(core.get_env_info_str())
 
         if for_training:
             if chkpt_path is not None:
                 msg = f"Resuming training from checkpoint {str(chkpt_path)}"
                 root_logger.info(msg)
-                print(f"{msg}\n")
+                self._print(f"{msg}\n")
             else:
                 root_logger.info(core.get_env_info_str())
         else:
@@ -585,7 +584,7 @@ def _print_header(
                 else "Testing from loaded model"
             )
             root_logger.info(msg)
-            print(f"{msg}\n")
+            self._print(f"{msg}\n")
 
     def _validate_flags(self):
         """Ensure that all the settings and flags are valid."""
@@ -1019,3 +1018,18 @@ def _validate(self, val_cycle: int) -> None:
 
         if self._is_distributed:
             td.barrier()
+
+    def _print(self, *args: typing.Any, **kwargs: typing.Any) -> None:
+        """
+        Wrap Python's print function for distributed training.
+
+        Specifically, this function will ensure that only rank 0 prints messages
+        to the screen. All other ranks will do nothing.
+
+        Args:
+            args: positional arguments for print.
+            kwargs: keyword arguments for print.
+        """
+        if self.rank != 0:
+            return
+        print(*args, **kwargs)
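
As an aside, the guard that this patch centralizes in Trainer._print can be sketched on its
own. The snippet below is a minimal, hypothetical illustration and is not part of the commit:
it assumes torch.distributed is the source of the rank, whereas the Trainer in the patch
already stores it on self.rank, and the helper name rank_zero_print is made up for the example.

    # Minimal sketch of rank-guarded printing. Assumes torch.distributed provides the
    # rank; rank_zero_print is an illustrative name, not part of the helios API.
    import typing

    import torch.distributed as dist


    def rank_zero_print(*args: typing.Any, **kwargs: typing.Any) -> None:
        """Forward to print() only on rank 0 so distributed runs do not duplicate output."""
        # Single-process runs have no initialized process group; treat them as rank 0
        # so the helper still prints outside of distributed training.
        rank = dist.get_rank() if dist.is_initialized() else 0
        if rank == 0:
            print(*args, **kwargs)


    if __name__ == "__main__":
        rank_zero_print("Resuming training from checkpoint chkpt.pth")

Keeping the check in a single wrapper, as the patch does, avoids repeating an
"if self.rank == 0:" guard at every call site that writes to the screen.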