Skip to content

Commit

Permalink
pt: change the virial output dim to 9
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Jan 27, 2024
1 parent 2631ce2 commit 0bc2279
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 19 deletions.
24 changes: 14 additions & 10 deletions deepmd/pt/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def task_deriv_one(
if do_atomic_virial:
extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy)
extended_virial = extended_virial + extended_virial_corr
# reshape the virial from [..., 3, 3] to [..., 9]
extended_virial = extended_virial.view(list(extended_virial.shape[:-2]) + [9]) # noqa:RUF005

Check warning on line 74 in deepmd/pt/model/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/transform_output.py#L74

Added line #L74 was not covered by tests
return extended_force, extended_virial


Expand Down Expand Up @@ -106,18 +108,18 @@ def take_deriv(
split_svv1 = torch.split(svv1, [1] * size, dim=-1)
split_ff, split_avir = [], []
for vvi, svvi in zip(split_vv1, split_svv1):
# nf x nloc x 3, nf x nloc x 3 x 3
# nf x nloc x 3, nf x nloc x 9
ffi, aviri = task_deriv_one(
vvi, svvi, coord_ext, do_atomic_virial=do_atomic_virial
)
# nf x nloc x 1 x 3, nf x nloc x 1 x 3 x 3
# nf x nloc x 1 x 3, nf x nloc x 1 x 9
ffi = ffi.unsqueeze(-2)
aviri = aviri.unsqueeze(-3)
aviri = aviri.unsqueeze(-2)

Check warning on line 117 in deepmd/pt/model/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/transform_output.py#L117

Added line #L117 was not covered by tests
split_ff.append(ffi)
split_avir.append(aviri)
# nf x nloc x v_dim x 3, nf x nloc x v_dim x 3 x 3
# nf x nloc x v_dim x 3, nf x nloc x v_dim x 9
ff = torch.concat(split_ff, dim=-2)
avir = torch.concat(split_avir, dim=-3)
avir = torch.concat(split_avir, dim=-2)

Check warning on line 122 in deepmd/pt/model/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/transform_output.py#L122

Added line #L122 was not covered by tests
return ff, avir


Expand Down Expand Up @@ -185,21 +187,23 @@ def communicate_extended_output(
force = torch.zeros(
vldims + derv_r_ext_dims, dtype=vv.dtype, device=vv.device
)
# nf x nloc x 1 x 3
# nf x nloc x nvar x 3
new_ret[kk_derv_r] = torch.scatter_reduce(
force,
1,
index=mapping,
src=model_ret[kk_derv_r],
reduce="sum",
)
mapping = mapping.unsqueeze(-1).expand(
[-1] * (len(mldims) + len(derv_r_ext_dims)) + [3]
derv_c_ext_dims = list(vdef.shape) + [9] # noqa:RUF005

Check warning on line 198 in deepmd/pt/model/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/transform_output.py#L198

Added line #L198 was not covered by tests
# nf x nloc x nvar x 3 -> nf x nloc x nvar x 9
mapping = torch.tile(

Check warning on line 200 in deepmd/pt/model/model/transform_output.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/transform_output.py#L200

Added line #L200 was not covered by tests
mapping, [1] * (len(mldims) + len(vdef.shape)) + [3]
)
virial = torch.zeros(
vldims + derv_r_ext_dims + [3], dtype=vv.dtype, device=vv.device
vldims + derv_c_ext_dims, dtype=vv.dtype, device=vv.device
)
# nf x nloc x 1 x 3
# nf x nloc x nvar x 9
new_ret[kk_derv_c] = torch.scatter_reduce(
virial,
1,
Expand Down
8 changes: 5 additions & 3 deletions source/tests/pt/test_autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,11 @@ def np_infer(
def ff(bb):
return np_infer(bb)["energy"]

fdv = -(
finite_difference(ff, cell, delta=delta).transpose(0, 2, 1) @ cell
).squeeze()
fdv = (
-(finite_difference(ff, cell, delta=delta).transpose(0, 2, 1) @ cell)
.squeeze()
.reshape(9)
)
rfv = np_infer(cell)["virial"]
np.testing.assert_almost_equal(fdv, rfv, decimal=places)

Expand Down
8 changes: 4 additions & 4 deletions source/tests/pt/test_rot.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def test(
)
if not hasattr(self, "test_virial") or self.test_virial:
torch.testing.assert_close(
torch.matmul(rmat.T, torch.matmul(ret0["virial"], rmat)),
ret1["virial"],
torch.matmul(rmat.T, torch.matmul(ret0["virial"].view([3, 3]), rmat)),
ret1["virial"].view([3, 3]),
rtol=prec,
atol=prec,
)
Expand Down Expand Up @@ -102,8 +102,8 @@ def test(
)
if not hasattr(self, "test_virial") or self.test_virial:
torch.testing.assert_close(
torch.matmul(rmat.T, torch.matmul(ret0["virial"], rmat)),
ret1["virial"],
torch.matmul(rmat.T, torch.matmul(ret0["virial"].view([3, 3]), rmat)),
ret1["virial"].view([3, 3]),
rtol=prec,
atol=prec,
)
Expand Down
5 changes: 3 additions & 2 deletions source/tests/pt/test_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,10 @@ def test_rotation(self):
if "virial" in result1:
self.assertTrue(
torch.allclose(
result2["virial"][0],
result2["virial"][0].view([3, 3]),
torch.matmul(
torch.matmul(rotation, result1["virial"][0].T), rotation.T
torch.matmul(rotation, result1["virial"][0].view([3, 3]).T),
rotation.T,
),
)
)
Expand Down

0 comments on commit 0bc2279

Please sign in to comment.