diff --git a/source/tests/universal/common/cases/model/utils.py b/source/tests/universal/common/cases/model/utils.py index 87c0209169..8fe6a131ef 100644 --- a/source/tests/universal/common/cases/model/utils.py +++ b/source/tests/universal/common/cases/model/utils.py @@ -119,7 +119,7 @@ def test_has_message_passing(self) -> None: def test_forward(self) -> None: """Test forward and forward_lower.""" test_spin = getattr(self, "test_spin", False) - nf = 1 + nf = 2 natoms = 5 aprec = ( 0 @@ -127,10 +127,10 @@ def test_forward(self) -> None: else self.aprec_dict["test_forward"] ) rng = np.random.default_rng(GLOBAL_SEED) - coord = 4.0 * rng.random([natoms, 3]).reshape([nf, -1]) - atype = np.array([0, 0, 0, 1, 1], dtype=int).reshape([nf, -1]) - spin = 0.5 * rng.random([natoms, 3]).reshape([nf, -1]) - cell = 6.0 * np.eye(3).reshape([nf, 9]) + coord = 4.0 * rng.random([1, natoms, 3]).repeat(nf, 0).reshape([nf, -1]) + atype = np.array([[0, 0, 0, 1, 1] * nf], dtype=int).reshape([nf, -1]) + spin = 0.5 * rng.random([1, natoms, 3]).repeat(nf, 0).reshape([nf, -1]) + cell = 6.0 * np.repeat(np.eye(3)[None, ...], nf, axis=0).reshape([nf, 9]) coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list( coord, atype, @@ -147,9 +147,9 @@ def test_forward(self) -> None: aparam = None fparam = None if self.module.get_dim_aparam() > 0: - aparam = rng.random([nf, natoms, self.module.get_dim_aparam()]) + aparam = rng.random([1, natoms, self.module.get_dim_aparam()]).repeat(nf, 0) if self.module.get_dim_fparam() > 0: - fparam = rng.random([nf, self.module.get_dim_fparam()]) + fparam = rng.random([1, self.module.get_dim_fparam()]).repeat(nf, 0) ret = [] ret_lower = [] for module in self.modules_to_test: @@ -183,6 +183,15 @@ def test_forward(self) -> None: ret_lower.append(module.forward_lower(**input_dict_lower)) for kk in ret[0]: + # ensure the first frame and the second frame are the same + if ret[0][kk] is not None: + np.testing.assert_allclose( + ret[0][kk][0], + ret[0][kk][1], + err_msg=f"compare {kk} between frame 0 and 1", + atol=aprec, + ) + subret = [] for rr in ret: if rr is not None: