Skip to content

Commit

Permalink
chore(tests): ensure the same result of frame 0 and 1 (#4442)
Browse files Browse the repository at this point in the history
Copied from njzjz/deepmd-gnn#27.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Bug Fixes**
- Enhanced the robustness of the testing framework to ensure consistent
model output across multiple frames of input data.
- Added assertions to validate output equivalence for the first and
second frames.

- **Tests**
- Adjusted the testing methods to accommodate changes in input
dimensionality and ensure proper validation of model behavior.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 30, 2024
1 parent 03c6e49 commit db0a2a3
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions source/tests/universal/common/cases/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,18 @@ def test_has_message_passing(self) -> None:
def test_forward(self) -> None:
"""Test forward and forward_lower."""
test_spin = getattr(self, "test_spin", False)
nf = 1
nf = 2
natoms = 5
aprec = (
0
if self.aprec_dict.get("test_forward", None) is None
else self.aprec_dict["test_forward"]
)
rng = np.random.default_rng(GLOBAL_SEED)
coord = 4.0 * rng.random([natoms, 3]).reshape([nf, -1])
atype = np.array([0, 0, 0, 1, 1], dtype=int).reshape([nf, -1])
spin = 0.5 * rng.random([natoms, 3]).reshape([nf, -1])
cell = 6.0 * np.eye(3).reshape([nf, 9])
coord = 4.0 * rng.random([1, natoms, 3]).repeat(nf, 0).reshape([nf, -1])
atype = np.array([[0, 0, 0, 1, 1] * nf], dtype=int).reshape([nf, -1])
spin = 0.5 * rng.random([1, natoms, 3]).repeat(nf, 0).reshape([nf, -1])
cell = 6.0 * np.repeat(np.eye(3)[None, ...], nf, axis=0).reshape([nf, 9])
coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list(
coord,
atype,
Expand All @@ -147,9 +147,9 @@ def test_forward(self) -> None:
aparam = None
fparam = None
if self.module.get_dim_aparam() > 0:
aparam = rng.random([nf, natoms, self.module.get_dim_aparam()])
aparam = rng.random([1, natoms, self.module.get_dim_aparam()]).repeat(nf, 0)
if self.module.get_dim_fparam() > 0:
fparam = rng.random([nf, self.module.get_dim_fparam()])
fparam = rng.random([1, self.module.get_dim_fparam()]).repeat(nf, 0)
ret = []
ret_lower = []
for module in self.modules_to_test:
Expand Down Expand Up @@ -183,6 +183,15 @@ def test_forward(self) -> None:
ret_lower.append(module.forward_lower(**input_dict_lower))

for kk in ret[0]:
# ensure the first frame and the second frame are the same
if ret[0][kk] is not None:
np.testing.assert_allclose(
ret[0][kk][0],
ret[0][kk][1],
err_msg=f"compare {kk} between frame 0 and 1",
atol=aprec,
)

subret = []
for rr in ret:
if rr is not None:
Expand Down

0 comments on commit db0a2a3

Please sign in to comment.