diff --git a/deepmd/pt/utils/auto_batch_size.py b/deepmd/pt/utils/auto_batch_size.py
index eec664db2d..13264a336c 100644
--- a/deepmd/pt/utils/auto_batch_size.py
+++ b/deepmd/pt/utils/auto_batch_size.py
@@ -21,8 +21,6 @@ class AutoBatchSize(AutoBatchSizeBase):
         is not set
     factor : float, default: 2.
         increased factor
-    returned_dict:
-        if the batched method returns a dict of arrays.
     """
 
@@ -30,13 +28,11 @@ def __init__(
         self,
         initial_batch_size: int = 1024,
         factor: float = 2.0,
-        returned_dict: bool = False,
     ):
         super().__init__(
             initial_batch_size=initial_batch_size,
             factor=factor,
         )
-        self.returned_dict = returned_dict
 
     def is_gpu_available(self) -> bool:
         """Check if GPU is available.
@@ -105,9 +101,13 @@ def execute_with_batch_size(
         index = 0
         results = None
+        returned_dict = None
         while index < total_size:
             n_batch, result = self.execute(execute_with_batch_size, index, natoms)
-            if not self.returned_dict:
+            returned_dict = (
+                isinstance(result, dict) if returned_dict is None else returned_dict
+            )
+            if not returned_dict:
                 result = (result,) if not isinstance(result, tuple) else result
             index += n_batch
@@ -116,7 +116,7 @@ def append_to_list(res_list, res):
                 res_list.append(res)
                 return res_list
 
-            if not self.returned_dict:
+            if not returned_dict:
                 results = [] if results is None else results
                 results = append_to_list(results, result)
             else:
@@ -126,6 +126,8 @@ def append_to_list(res_list, res):
                 results = {
                     kk: append_to_list(results[kk], result[kk]) for kk in result.keys()
                 }
+        assert results is not None
+        assert returned_dict is not None
 
         def concate_result(r):
             if isinstance(r[0], np.ndarray):
@@ -136,7 +138,7 @@ def concate_result(r):
                 ret = np.concatenate(r, axis=0)
             elif isinstance(r[0], torch.Tensor):
                 ret = torch.cat(r, dim=0)
             else:
                 raise RuntimeError(f"Unexpected result type {type(r[0])}")
             return ret
 
-        if not self.returned_dict:
+        if not returned_dict:
             r_list = [concate_result(r) for r in zip(*results)]
             r = tuple(r_list)
             if len(r) == 1:
diff --git a/source/tests/pt/test_auto_batch_size.py b/source/tests/pt/test_auto_batch_size.py
index 0d2f5a483e..71194e001e 100644
--- a/source/tests/pt/test_auto_batch_size.py
+++ b/source/tests/pt/test_auto_batch_size.py
@@ -24,7 +24,7 @@ def func(dd1):
     def test_execute_all_dict(self):
         dd0 = np.zeros((10000, 2, 1, 3, 4))
         dd1 = np.ones((10000, 2, 1, 3, 4))
-        auto_batch_size = AutoBatchSize(256, 2.0, returned_dict=True)
+        auto_batch_size = AutoBatchSize(256, 2.0)
 
         def func(dd1):
             return {
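
For reference, below is a small standalone sketch of the accumulation pattern this change introduces: whether the batched callable returns a dict of arrays is now inferred from the first batch (via isinstance(result, dict)) rather than declared up front through the removed returned_dict argument. The names run_batched and batch_func are illustrative only and are not part of the deepmd API; this is a minimal sketch, not the actual execute_all implementation.

# Standalone sketch (not deepmd code) of the pattern introduced above:
# the dict-vs-tuple output layout is inferred from the first batch result
# instead of being passed in as `returned_dict`.
import numpy as np


def run_batched(batch_func, total_size, batch_size):
    """Call `batch_func(start, end)` over slices and stitch the results."""
    index = 0
    results = None
    returned_dict = None  # unknown until the first batch is seen
    while index < total_size:
        n_batch = min(batch_size, total_size - index)
        result = batch_func(index, index + n_batch)
        # Infer the output layout once, from the first batch.
        returned_dict = (
            isinstance(result, dict) if returned_dict is None else returned_dict
        )
        if not returned_dict:
            result = (result,) if not isinstance(result, tuple) else result
            results = [] if results is None else results
            results.append(result)
        else:
            results = {kk: [] for kk in result} if results is None else results
            for kk in result:
                results[kk].append(result[kk])
        index += n_batch
    assert results is not None
    if not returned_dict:
        return tuple(np.concatenate(r, axis=0) for r in zip(*results))
    return {kk: np.concatenate(vv, axis=0) for kk, vv in results.items()}


# Both output layouts now work without any constructor flag:
dd = np.ones((10, 3))
as_tuple = run_batched(lambda i, j: dd[i:j] * 2, 10, 4)
as_dict = run_batched(lambda i, j: {"x": dd[i:j] * 2}, 10, 4)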