Skip to content

Commit

Permalink
check if a dict is returned at runtime.
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Mar 31, 2024
1 parent baedfb1 commit 3c4ff2c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
16 changes: 9 additions & 7 deletions deepmd/pt/utils/auto_batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,18 @@ class AutoBatchSize(AutoBatchSizeBase):
is not set
factor : float, default: 2.
increased factor
returned_dict:
if the batched method returns a dict of arrays.
"""

def __init__(

Check warning on line 27 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L27

Added line #L27 was not covered by tests
self,
initial_batch_size: int = 1024,
factor: float = 2.0,
returned_dict: bool = False,
):
super().__init__(

Check warning on line 32 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L32

Added line #L32 was not covered by tests
initial_batch_size=initial_batch_size,
factor=factor,
)
self.returned_dict = returned_dict

def is_gpu_available(self) -> bool:
"""Check if GPU is available.
Expand Down Expand Up @@ -105,9 +101,13 @@ def execute_with_batch_size(

index = 0
results = None
returned_dict = None

Check warning on line 104 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L103-L104

Added lines #L103 - L104 were not covered by tests
while index < total_size:
n_batch, result = self.execute(execute_with_batch_size, index, natoms)
if not self.returned_dict:
returned_dict = (

Check warning on line 107 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L107

Added line #L107 was not covered by tests
isinstance(result, dict) if returned_dict is None else returned_dict
)
if not returned_dict:
result = (result,) if not isinstance(result, tuple) else result

Check warning on line 111 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L110-L111

Added lines #L110 - L111 were not covered by tests
index += n_batch

Expand All @@ -116,7 +116,7 @@ def append_to_list(res_list, res):
res_list.append(res)
return res_list

Check warning on line 117 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L114-L117

Added lines #L114 - L117 were not covered by tests

if not self.returned_dict:
if not returned_dict:
results = [] if results is None else results
results = append_to_list(results, result)

Check warning on line 121 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L119-L121

Added lines #L119 - L121 were not covered by tests
else:
Expand All @@ -126,6 +126,8 @@ def append_to_list(res_list, res):
results = {

Check warning on line 126 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L126

Added line #L126 was not covered by tests
kk: append_to_list(results[kk], result[kk]) for kk in result.keys()
}
assert results is not None
assert returned_dict is not None

Check warning on line 130 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L129-L130

Added lines #L129 - L130 were not covered by tests

def concate_result(r):

Check warning on line 132 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L132

Added line #L132 was not covered by tests
if isinstance(r[0], np.ndarray):
Expand All @@ -136,7 +138,7 @@ def concate_result(r):
raise RuntimeError(f"Unexpected result type {type(r[0])}")
return ret

Check warning on line 139 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L139

Added line #L139 was not covered by tests

if not self.returned_dict:
if not returned_dict:
r_list = [concate_result(r) for r in zip(*results)]
r = tuple(r_list)
if len(r) == 1:

Check warning on line 144 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L141-L144

Added lines #L141 - L144 were not covered by tests
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/test_auto_batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def func(dd1):
def test_execute_all_dict(self):
dd0 = np.zeros((10000, 2, 1, 3, 4))
dd1 = np.ones((10000, 2, 1, 3, 4))
auto_batch_size = AutoBatchSize(256, 2.0, returned_dict=True)
auto_batch_size = AutoBatchSize(256, 2.0)

def func(dd1):
return {
Expand Down

0 comments on commit 3c4ff2c

Please sign in to comment.