Skip to content

Commit

Permalink
fix bug of rcut_smth >= rcut.
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Mar 5, 2024
1 parent c31d376 commit 6899200
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 13 deletions.
13 changes: 8 additions & 5 deletions source/tests/common/dpmodel/case_single_frame_with_nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ def setUp(self):
self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc])
self.cell = 2.0 * np.eye(3).reshape([1, 9])
# sel = [5, 2]
self.sel = [5, 2]
self.rcut = 0.4
self.rcut_smth = 2.2
self.sel = [16, 8]
self.rcut = 2.2
self.rcut_smth = 0.4
self.atol = 1e-12


class TestCaseSingleFrameWithNlist:
Expand Down Expand Up @@ -51,8 +52,10 @@ def setUp(self):
],
dtype=int,
).reshape([1, self.nloc, sum(self.sel)])
self.rcut = 0.4
self.rcut_smth = 2.2
self.rcut = 2.2
self.rcut_smth = 0.4
self.atol = 1e-12

# permutations
self.perm = np.array([2, 0, 1, 3], dtype=np.int32)
inv_perm = np.array([1, 2, 0, 3], dtype=np.int32)
Expand Down
36 changes: 33 additions & 3 deletions source/tests/pt/model/test_dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,24 +65,29 @@ def test_self_consistency(self):
np.testing.assert_allclose(
to_numpy_array(ret0["energy"]),
to_numpy_array(ret1["energy"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_redu"]),
to_numpy_array(ret1["energy_redu"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_r"]),
to_numpy_array(ret1["energy_derv_r"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_c_redu"]),
to_numpy_array(ret1["energy_derv_c_redu"]),
atol=self.atol,
)
ret0 = md0.forward_common(*args, do_atomic_virial=True)
ret1 = md1.forward_common(*args, do_atomic_virial=True)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_c"]),
to_numpy_array(ret1["energy_derv_c"]),
atol=self.atol,
)

coord_ext, atype_ext, mapping = extend_coord_with_ghosts(
Expand All @@ -106,6 +111,7 @@ def test_self_consistency(self):
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_c_redu"]),
to_numpy_array(ret2["energy_derv_c_redu"]),
atol=self.atol,
)

def test_dp_consistency(self):
Expand Down Expand Up @@ -141,10 +147,12 @@ def test_dp_consistency(self):
np.testing.assert_allclose(
ret0["energy"],
to_numpy_array(ret1["energy"]),
atol=self.atol,
)
np.testing.assert_allclose(
ret0["energy_redu"],
to_numpy_array(ret1["energy_redu"]),
atol=self.atol,
)

def test_dp_consistency_nopbc(self):
Expand Down Expand Up @@ -180,10 +188,12 @@ def test_dp_consistency_nopbc(self):
np.testing.assert_allclose(
ret0["energy"],
to_numpy_array(ret1["energy"]),
atol=self.atol,
)
np.testing.assert_allclose(
ret0["energy_redu"],
to_numpy_array(ret1["energy_redu"]),
atol=self.atol,
)

def test_prec_consistency(self):
Expand Down Expand Up @@ -231,6 +241,7 @@ def test_prec_consistency(self):
np.testing.assert_allclose(
to_numpy_array(model_l_ret_32[ii]),
to_numpy_array(model_l_ret_64[ii]),
atol=self.atol,
)


Expand Down Expand Up @@ -263,24 +274,29 @@ def test_self_consistency(self):
np.testing.assert_allclose(
to_numpy_array(ret0["energy"]),
to_numpy_array(ret1["energy"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_redu"]),
to_numpy_array(ret1["energy_redu"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_r"]),
to_numpy_array(ret1["energy_derv_r"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_c_redu"]),
to_numpy_array(ret1["energy_derv_c_redu"]),
atol=self.atol,
)
ret0 = md0.forward_common_lower(*args, do_atomic_virial=True)
ret1 = md1.forward_common_lower(*args, do_atomic_virial=True)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_c"]),
to_numpy_array(ret1["energy_derv_c"]),
atol=self.atol,
)

def test_dp_consistency(self):
Expand Down Expand Up @@ -310,10 +326,12 @@ def test_dp_consistency(self):
np.testing.assert_allclose(
ret0["energy"],
to_numpy_array(ret1["energy"]),
atol=self.atol,
)
np.testing.assert_allclose(
ret0["energy_redu"],
to_numpy_array(ret1["energy_redu"]),
atol=self.atol,
)

def test_prec_consistency(self):
Expand Down Expand Up @@ -363,6 +381,7 @@ def test_prec_consistency(self):
np.testing.assert_allclose(
to_numpy_array(model_l_ret_32[ii]),
to_numpy_array(model_l_ret_64[ii]),
atol=self.atol,
)

def test_jit(self):
Expand Down Expand Up @@ -447,7 +466,7 @@ def test_nlist_eq(self):
to_torch_tensor(self.atype_ext),
to_torch_tensor(nlist),
)
np.testing.assert_allclose(self.expected_nlist, to_numpy_array(nlist1))
np.testing.assert_equal(self.expected_nlist, to_numpy_array(nlist1))

def test_nlist_st(self):
# n_nnei < nnei
Expand All @@ -464,7 +483,7 @@ def test_nlist_st(self):
to_torch_tensor(self.atype_ext),
to_torch_tensor(nlist),
)
np.testing.assert_allclose(self.expected_nlist, to_numpy_array(nlist1))
np.testing.assert_equal(self.expected_nlist, to_numpy_array(nlist1))

def test_nlist_lt(self):
# n_nnei > nnei
Expand All @@ -481,7 +500,7 @@ def test_nlist_lt(self):
to_torch_tensor(self.atype_ext),
to_torch_tensor(nlist),
)
np.testing.assert_allclose(self.expected_nlist, to_numpy_array(nlist1))
np.testing.assert_equal(self.expected_nlist, to_numpy_array(nlist1))


class TestEnergyModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
Expand Down Expand Up @@ -511,24 +530,29 @@ def test_self_consistency(self):
np.testing.assert_allclose(
to_numpy_array(ret0["atom_energy"]),
to_numpy_array(ret1["atom_energy"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy"]),
to_numpy_array(ret1["energy"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["force"]),
to_numpy_array(ret1["force"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["virial"]),
to_numpy_array(ret1["virial"]),
atol=self.atol,
)
ret0 = md0.forward(*args, do_atomic_virial=True)
ret1 = md1.forward(*args, do_atomic_virial=True)
np.testing.assert_allclose(
to_numpy_array(ret0["atom_virial"]),
to_numpy_array(ret1["atom_virial"]),
atol=self.atol,
)
coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list(
to_torch_tensor(self.coord),
Expand All @@ -545,6 +569,7 @@ def test_self_consistency(self):
np.testing.assert_allclose(
to_numpy_array(ret0["virial"]),
to_numpy_array(ret2["virial"]),
atol=self.atol,
)


Expand Down Expand Up @@ -577,24 +602,29 @@ def test_self_consistency(self):
np.testing.assert_allclose(
to_numpy_array(ret0["atom_energy"]),
to_numpy_array(ret1["atom_energy"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy"]),
to_numpy_array(ret1["energy"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["extended_force"]),
to_numpy_array(ret1["extended_force"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["virial"]),
to_numpy_array(ret1["virial"]),
atol=self.atol,
)
ret0 = md0.forward_lower(*args, do_atomic_virial=True)
ret1 = md1.forward_lower(*args, do_atomic_virial=True)
np.testing.assert_allclose(
to_numpy_array(ret0["extended_virial"]),
to_numpy_array(ret1["extended_virial"]),
atol=self.atol,
)

def test_jit(self):
Expand Down
12 changes: 7 additions & 5 deletions source/tests/pt/model/test_env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def setUp(self):
],
dtype=int,
).reshape([1, self.nloc, sum(self.sel)])
self.rcut = 0.4
self.rcut_smth = 2.2
self.rcut = 2.2
self.rcut_smth = 0.4
# permutations
self.perm = np.array([2, 0, 1, 3], dtype=np.int32)
inv_perm = np.array([1, 2, 0, 3], dtype=np.int32)
Expand All @@ -61,6 +61,7 @@ def setUp(self):
nlist1 = inv_perm[nlist1]
nlist1 = np.where(mask, -1, nlist1)
self.nlist = np.concatenate([self.nlist, nlist1], axis=0)
self.atol = 1e-12


class TestCaseSingleFrameWithoutNlist:
Expand All @@ -79,9 +80,10 @@ def setUp(self):
self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc])
self.cell = 2.0 * np.eye(3).reshape([1, 9])
# sel = [5, 2]
self.sel = [5, 2]
self.rcut = 0.4
self.rcut_smth = 2.2
self.sel = [16, 8]
self.rcut = 2.2
self.rcut_smth = 0.4
self.atol = 1e-12


# to be merged with the tf test case
Expand Down

0 comments on commit 6899200

Please sign in to comment.