From 4b92b6d8a57a543f8d826e6c87dd4dac6b722566 Mon Sep 17 00:00:00 2001 From: Chun Cai Date: Thu, 28 Nov 2024 02:00:30 +0800 Subject: [PATCH] Perf: print summary on rank 0 (#4434) ## Summary by CodeRabbit - **Bug Fixes** - Adjusted the summary printing functionality to ensure it only executes from the main process in distributed settings, preventing duplicate outputs. --- deepmd/pt/utils/dataloader.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index 9920622792..2fea6b72d2 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -185,19 +185,21 @@ def print_summary( name: str, prob: list[float], ) -> None: - print_summary( - name, - len(self.systems), - [ss.system for ss in self.systems], - [ss._natoms for ss in self.systems], - self.batch_sizes, - [ - ss._data_system.get_sys_numb_batch(self.batch_sizes[ii]) - for ii, ss in enumerate(self.systems) - ], - prob, - [ss._data_system.pbc for ss in self.systems], - ) + rank = dist.get_rank() if dist.is_initialized() else 0 + if rank == 0: + print_summary( + name, + len(self.systems), + [ss.system for ss in self.systems], + [ss._natoms for ss in self.systems], + self.batch_sizes, + [ + ss._data_system.get_sys_numb_batch(self.batch_sizes[ii]) + for ii, ss in enumerate(self.systems) + ], + prob, + [ss._data_system.pbc for ss in self.systems], + ) _sentinel = object()