From 0bc2279e1febb0e761e3fcfa70c98b2efd9c42c7 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 27 Jan 2024 20:51:41 +0800 Subject: [PATCH] pt: change the virial output dim to 9 --- deepmd/pt/model/model/transform_output.py | 24 +++++++++++++---------- source/tests/pt/test_autodiff.py | 8 +++++--- source/tests/pt/test_rot.py | 8 ++++---- source/tests/pt/test_rotation.py | 5 +++-- 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py index 673491d788..a14518e8a0 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -70,6 +70,8 @@ def task_deriv_one( if do_atomic_virial: extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy) extended_virial = extended_virial + extended_virial_corr + # to [...,3,3] -> [...,9] + extended_virial = extended_virial.view(list(extended_virial.shape[:-2]) + [9]) # noqa:RUF005 return extended_force, extended_virial @@ -106,18 +108,18 @@ def take_deriv( split_svv1 = torch.split(svv1, [1] * size, dim=-1) split_ff, split_avir = [], [] for vvi, svvi in zip(split_vv1, split_svv1): - # nf x nloc x 3, nf x nloc x 3 x 3 + # nf x nloc x 3, nf x nloc x 9 ffi, aviri = task_deriv_one( vvi, svvi, coord_ext, do_atomic_virial=do_atomic_virial ) - # nf x nloc x 1 x 3, nf x nloc x 1 x 3 x 3 + # nf x nloc x 1 x 3, nf x nloc x 1 x 9 ffi = ffi.unsqueeze(-2) - aviri = aviri.unsqueeze(-3) + aviri = aviri.unsqueeze(-2) split_ff.append(ffi) split_avir.append(aviri) - # nf x nloc x v_dim x 3, nf x nloc x v_dim x 3 x 3 + # nf x nloc x v_dim x 3, nf x nloc x v_dim x 9 ff = torch.concat(split_ff, dim=-2) - avir = torch.concat(split_avir, dim=-3) + avir = torch.concat(split_avir, dim=-2) return ff, avir @@ -185,7 +187,7 @@ def communicate_extended_output( force = torch.zeros( vldims + derv_r_ext_dims, dtype=vv.dtype, device=vv.device ) - # nf x nloc x 1 x 3 + # nf x nloc x nvar x 3 new_ret[kk_derv_r] = torch.scatter_reduce( force, 1, @@ -193,13 +195,15 @@ def communicate_extended_output( src=model_ret[kk_derv_r], reduce="sum", ) - mapping = mapping.unsqueeze(-1).expand( - [-1] * (len(mldims) + len(derv_r_ext_dims)) + [3] + derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005 + # nf x nloc x nvar x 3 -> nf x nloc x nvar x 9 + mapping = torch.tile( + mapping, [1] * (len(mldims) + len(vdef.shape)) + [3] ) virial = torch.zeros( - vldims + derv_r_ext_dims + [3], dtype=vv.dtype, device=vv.device + vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device ) - # nf x nloc x 1 x 3 + # nf x nloc x nvar x 9 new_ret[kk_derv_c] = torch.scatter_reduce( virial, 1, diff --git a/source/tests/pt/test_autodiff.py b/source/tests/pt/test_autodiff.py index 4f303a8bb3..8840fbdd4c 100644 --- a/source/tests/pt/test_autodiff.py +++ b/source/tests/pt/test_autodiff.py @@ -121,9 +121,11 @@ def np_infer( def ff(bb): return np_infer(bb)["energy"] - fdv = -( - finite_difference(ff, cell, delta=delta).transpose(0, 2, 1) @ cell - ).squeeze() + fdv = ( + -(finite_difference(ff, cell, delta=delta).transpose(0, 2, 1) @ cell) + .squeeze() + .reshape(9) + ) rfv = np_infer(cell)["virial"] np.testing.assert_almost_equal(fdv, rfv, decimal=places) diff --git a/source/tests/pt/test_rot.py b/source/tests/pt/test_rot.py index b5d9d9b64b..7222fd6f69 100644 --- a/source/tests/pt/test_rot.py +++ b/source/tests/pt/test_rot.py @@ -65,8 +65,8 @@ def test( ) if not hasattr(self, "test_virial") or self.test_virial: torch.testing.assert_close( - torch.matmul(rmat.T, torch.matmul(ret0["virial"], rmat)), - ret1["virial"], + torch.matmul(rmat.T, torch.matmul(ret0["virial"].view([3, 3]), rmat)), + ret1["virial"].view([3, 3]), rtol=prec, atol=prec, ) @@ -102,8 +102,8 @@ def test( ) if not hasattr(self, "test_virial") or self.test_virial: torch.testing.assert_close( - torch.matmul(rmat.T, torch.matmul(ret0["virial"], rmat)), - ret1["virial"], + torch.matmul(rmat.T, torch.matmul(ret0["virial"].view([3, 3]), rmat)), + ret1["virial"].view([3, 3]), rtol=prec, atol=prec, ) diff --git a/source/tests/pt/test_rotation.py b/source/tests/pt/test_rotation.py index 4b49377a27..58ec80e0d6 100644 --- a/source/tests/pt/test_rotation.py +++ b/source/tests/pt/test_rotation.py @@ -121,9 +121,10 @@ def test_rotation(self): if "virial" in result1: self.assertTrue( torch.allclose( - result2["virial"][0], + result2["virial"][0].view([3, 3]), torch.matmul( - torch.matmul(rotation, result1["virial"][0].T), rotation.T + torch.matmul(rotation, result1["virial"][0].view([3, 3]).T), + rotation.T, ), ) )