Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Apr 9, 2024
1 parent 3baa28c commit bb47541
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
23 changes: 14 additions & 9 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,23 +209,26 @@ def _make_preset_out_bias(
def _fill_stat_with_global(

Check warning on line 209 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L209

Added line #L209 was not covered by tests
atomic_stat: Union[np.ndarray, None],
global_stat: np.ndarray,
):
):
"""This function is used to fill atomic stat with global stat.
Parameters
----------
atomic_stat : Union[np.ndarray, None]
The atomic stat.
global_stat : np.ndarray
The global stat.
if the atomic stat is None, use global stat.
if the atomic stat is None, use global stat.
if the atomic stat is not None, but has nan values (missing atypes), fill with global stat.
"""
if atomic_stat is None:
return global_stat

Check warning on line 225 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L224-L225

Added lines #L224 - L225 were not covered by tests
else:
return np.nan_to_num(np.where(np.isnan(atomic_stat) & ~np.isnan(global_stat), global_stat, atomic_stat))
return np.nan_to_num(

Check warning on line 227 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L227

Added line #L227 was not covered by tests
np.where(
np.isnan(atomic_stat) & ~np.isnan(global_stat), global_stat, atomic_stat
)
)


def compute_output_stats(
Expand Down Expand Up @@ -383,11 +386,13 @@ def compute_output_stats(
std_atom_e[kk] = None

Check warning on line 386 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L385-L386

Added lines #L385 - L386 were not covered by tests
# use global bias to fill missing atomic bias
if kk in bias_atom_g:
bias_atom_e[kk] = _fill_stat_with_global(bias_atom_e[kk], bias_atom_g[kk])
bias_atom_e[kk] = _fill_stat_with_global(

Check warning on line 389 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L388-L389

Added lines #L388 - L389 were not covered by tests
bias_atom_e[kk], bias_atom_g[kk]
)
std_atom_e[kk] = _fill_stat_with_global(std_atom_e[kk], std_atom_g[kk])

Check warning on line 392 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L392

Added line #L392 was not covered by tests
else:
raise RuntimeError("Fail to compute stat.")

Check warning on line 394 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L394

Added line #L394 was not covered by tests

if stat_file_path is not None:
_save_to_file(stat_file_path, bias_atom_e, std_atom_e)

Check warning on line 397 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L396-L397

Added lines #L396 - L397 were not covered by tests

Expand Down Expand Up @@ -566,8 +571,8 @@ def compute_output_stats_atomic(
if missing_types > 0:
nan_padding = np.empty((missing_types, bias_atom_e[kk].shape[1]))
nan_padding.fill(np.nan)
bias_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding],axis=0)
std_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding],axis=0)
bias_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding], axis=0)
std_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding], axis=0)

Check warning on line 575 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L570-L575

Added lines #L570 - L575 were not covered by tests
else:
# this key does not have atomic labels, skip it.
continue

Check warning on line 578 in deepmd/pt/utils/stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/stat.py#L578

Added line #L578 was not covered by tests
Expand Down
10 changes: 6 additions & 4 deletions source/tests/pt/model/test_atomic_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,11 @@ def raise_error():
for kk in ["foo"]:
np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4)
assert False
class TestAtomicModelStatMergeGlobalAtomic(unittest.TestCase, TestCaseSingleFrameWithNlist):


class TestAtomicModelStatMergeGlobalAtomic(
unittest.TestCase, TestCaseSingleFrameWithNlist
):
def tearDown(self):
self.tempdir.cleanup()

Expand Down Expand Up @@ -398,8 +402,6 @@ def raise_error():
ret3 = cvt_ret(ret3)
expected_ret3 = {}
# new bias [2, -5]
expected_ret3["foo"] = np.array(
[[3, 4, -2], [6, 0, 1]]
).reshape(2, 3, 1)
expected_ret3["foo"] = np.array([[3, 4, -2], [6, 0, 1]]).reshape(2, 3, 1)
for kk in ["foo"]:
np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4)

0 comments on commit bb47541

Please sign in to comment.