Skip to content

Commit

Permalink
breaking: pt: unify the output of descriptors. (deepmodeling#3190)
Browse files Browse the repository at this point in the history
Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
wanghan-iapcm and Han Wang authored Jan 29, 2024
1 parent 8900561 commit 1e51a88
Show file tree
Hide file tree
Showing 14 changed files with 170 additions and 89 deletions.
18 changes: 16 additions & 2 deletions deepmd/model_format/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,18 @@ def call(
Returns
-------
descriptor
The descriptor. shape: nf x nloc x ng x axis_neuron
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3
g2
The rotationally invariant pair-partical representation.
this descriptor returns None
h2
The rotationally equivariant pair-partical representation.
this descriptor returns None
sw
The smooth switch function.
"""
# nf x nloc x nnei x 4
rr, ww = self.env_mat.call(coord_ext, atype_ext, nlist, self.davg, self.dstd)
Expand All @@ -238,15 +249,17 @@ def call(
gg = self.cal_g(ss, tt)
# nf x nloc x ng x 4
gr += np.einsum("flni,flnj->flij", gg, tr)
# nf x nloc x ng x 4
gr /= self.nnei
gr1 = gr[:, :, : self.axis_neuron, :]
# nf x nloc x ng x ng1
grrg = np.einsum("flid,fljd->flij", gr, gr1)
# nf x nloc x (ng x ng1)
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron)
return grrg
return grrg, gr[..., 1:], None, None, ww

def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
return {
"rcut": self.rcut,
"rcut_smth": self.rcut_smth,
Expand All @@ -271,6 +284,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "DescrptSeA":
"""Deserialize from dict."""
data = copy.deepcopy(data)
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
Expand Down
35 changes: 33 additions & 2 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,42 @@ def forward(
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
):
"""Compute the descriptor.
Parameters
----------
coord_ext
The extended coordinates of atoms. shape: nf x (nallx3)
atype_ext
The extended aotm types. shape: nf x nall
nlist
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.
Returns
-------
descriptor
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3
g2
The rotationally invariant pair-partical representation.
shape: nf x nloc x nnei x ng
h2
The rotationally equivariant pair-partical representation.
shape: nf x nloc x nnei x 3
sw
The smooth switch function. shape: nf x nloc x nnei
"""
del mapping
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
g1_ext = self.type_embedding(extended_atype)
g1_inp = g1_ext[:, :nloc, :]
g1, env_mat, diff, rot_mat, sw = self.se_atten(
g1, g2, h2, rot_mat, sw = self.se_atten(
nlist,
extended_coord,
extended_atype,
Expand All @@ -149,4 +179,5 @@ def forward(
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
return g1, env_mat, diff, rot_mat, sw

return g1, rot_mat, g2, h2, sw
32 changes: 31 additions & 1 deletion deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,36 @@ def forward(
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
):
"""Compute the descriptor.
Parameters
----------
coord_ext
The extended coordinates of atoms. shape: nf x (nallx3)
atype_ext
The extended aotm types. shape: nf x nall
nlist
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, mapps extended region index to local region.
Returns
-------
descriptor
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3
g2
The rotationally invariant pair-partical representation.
shape: nf x nloc x nnei x ng
h2
The rotationally equivariant pair-partical representation.
shape: nf x nloc x nnei x 3
sw
The smooth switch function. shape: nf x nloc x nnei
"""
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
# nlists
Expand Down Expand Up @@ -372,4 +402,4 @@ def forward(
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
return g1, g2, h2, rot_mat, sw
return g1, rot_mat, g2, h2, sw
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def forward(
# (nb x nloc) x ng2 x 3
rot_mat = torch.permute(h2g2, (0, 1, 3, 2))

return g1, g2, h2, rot_mat.view(-1, self.dim_emb, 3), sw
return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw

def compute_input_stats(self, merged):
"""Update mean and stddev for descriptor elements."""
Expand Down
46 changes: 39 additions & 7 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,42 @@ def get_data_process_key(cls, config):

def forward(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
coord_ext: torch.Tensor,
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
):
return self.sea.forward(nlist, extended_coord, extended_atype, None, mapping)
"""Compute the descriptor.
Parameters
----------
coord_ext
The extended coordinates of atoms. shape: nf x (nallx3)
atype_ext
The extended aotm types. shape: nf x nall
nlist
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.
Returns
-------
descriptor
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3
g2
The rotationally invariant pair-partical representation.
this descriptor returns None
h2
The rotationally equivariant pair-partical representation.
this descriptor returns None
sw
The smooth switch function.
"""
return self.sea.forward(nlist, coord_ext, atype_ext, None, mapping)

def set_stat_mean_and_stddev(
self,
Expand Down Expand Up @@ -389,7 +419,7 @@ def forward(
del extended_atype_embd, mapping
nloc = nlist.shape[1]
atype = extended_atype[:, :nloc]
dmatrix, diff, _ = prod_env_mat_se_a(
dmatrix, diff, sw = prod_env_mat_se_a(
extended_coord,
nlist,
atype,
Expand Down Expand Up @@ -438,12 +468,14 @@ def forward(
result = torch.matmul(
xyz_scatter_1, xyz_scatter_2
) # shape is [nframes*nall, self.filter_neuron[-1], self.axis_neuron]
result = result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron)
rot_mat = rot_mat.view([-1, nloc] + list(rot_mat.shape[1:])) # noqa:RUF005
return (
result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron),
None,
None,
result,
rot_mat,
None,
None,
sw,
)


Expand Down
7 changes: 3 additions & 4 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,8 @@ def forward(
self.rcut,
self.rcut_smth,
)
dmatrix = dmatrix.view(
-1, self.ndescrpt
) # shape is [nframes*nall, self.ndescrpt]
# [nfxnlocxnnei, self.ndescrpt]
dmatrix = dmatrix.view(-1, self.ndescrpt)
nlist_mask = nlist != -1
nlist[nlist == -1] = 0
sw = torch.squeeze(sw, -1)
Expand Down Expand Up @@ -328,7 +327,7 @@ def forward(
return (
result.view(-1, nloc, self.filter_neuron[-1] * self.axis_neuron),
ret.view(-1, nloc, self.nnei, self.filter_neuron[-1]),
diff,
dmatrix.view(-1, nloc, self.nnei, 4)[..., 1:],
rot_mat.view(-1, self.filter_neuron[-1], 3),
sw,
)
Expand Down
69 changes: 15 additions & 54 deletions deepmd/pt/model/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Descriptor,
)
from deepmd.pt.model.task import (
DenoiseNet,
Fitting,
)

Expand Down Expand Up @@ -93,40 +92,20 @@ def __init__(
sampled=sampled,
)

# Fitting
if fitting_net:
fitting_net["type"] = fitting_net.get("type", "ener")
if self.descriptor_type not in ["se_e2_a"]:
fitting_net["ntypes"] = 1
else:
fitting_net["ntypes"] = self.descriptor.get_ntype()
fitting_net["use_tebd"] = False
fitting_net["embedding_width"] = self.descriptor.dim_out

self.grad_force = "direct" not in fitting_net["type"]
if not self.grad_force:
fitting_net["out_dim"] = self.descriptor.dim_emb
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
self.fitting_net = Fitting(**fitting_net)
fitting_net["type"] = fitting_net.get("type", "ener")
if self.descriptor_type not in ["se_e2_a"]:
fitting_net["ntypes"] = 1
else:
self.fitting_net = None
self.grad_force = False
if not self.split_nlist:
self.coord_denoise_net = DenoiseNet(
self.descriptor.dim_out, self.ntypes - 1, self.descriptor.dim_emb
)
elif self.combination:
self.coord_denoise_net = DenoiseNet(
self.descriptor.dim_out,
self.ntypes - 1,
self.descriptor.dim_emb_list,
self.prefactor,
)
else:
self.coord_denoise_net = DenoiseNet(
self.descriptor.dim_out, self.ntypes - 1, self.descriptor.dim_emb
)
fitting_net["ntypes"] = self.descriptor.get_ntype()
fitting_net["use_tebd"] = False
fitting_net["embedding_width"] = self.descriptor.dim_out

self.grad_force = "direct" not in fitting_net["type"]
if not self.grad_force:
fitting_net["out_dim"] = self.descriptor.dim_emb
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
self.fitting_net = Fitting(**fitting_net)

def get_fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
Expand Down Expand Up @@ -178,31 +157,13 @@ def forward_atomic(
atype = extended_atype[:, :nloc]
if self.do_grad():
extended_coord.requires_grad_(True)
descriptor, env_mat, diff, rot_mat, sw = self.descriptor(
descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
)
assert descriptor is not None
# energy, force
if self.fitting_net is not None:
fit_ret = self.fitting_net(
descriptor, atype, atype_tebd=None, rot_mat=rot_mat
)
# denoise
else:
nlist_list = [nlist]
if not self.split_nlist:
nnei_mask = nlist != -1
elif self.combination:
nnei_mask = []
for item in nlist_list:
nnei_mask_item = item != -1
nnei_mask.append(nnei_mask_item)
else:
env_mat = env_mat[-1]
diff = diff[-1]
nnei_mask = nlist_list[-1] != -1
fit_ret = self.coord_denoise_net(env_mat, diff, nnei_mask, descriptor, sw)
fit_ret = self.fitting_net(descriptor, atype, atype_tebd=None, rot_mat=rot_mat)
return fit_ret
3 changes: 3 additions & 0 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,11 @@ def forward(
inputs
) # Shape is [nframes, nloc, m1]
assert list(vec_out.size()) == [nframes, nloc, self.out_dim]
# (nf x nloc) x 1 x od
vec_out = vec_out.view(-1, 1, self.out_dim)
assert rot_mat is not None
# (nf x nloc) x od x 3
rot_mat = rot_mat.view(-1, self.out_dim, 3)
vec_out = (
torch.bmm(vec_out, rot_mat).squeeze(-2).view(nframes, nloc, 3)
) # Shape is [nframes, nloc, 3]
Expand Down
3 changes: 2 additions & 1 deletion source/tests/common/test_model_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,4 +367,5 @@ def test_self_consistency(
em1 = DescrptSeA.deserialize(em0.serialize())
mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist)
mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist)
np.testing.assert_allclose(mm0, mm1)
for ii in [0, 1, 4]:
np.testing.assert_allclose(mm0[ii], mm1[ii])
2 changes: 2 additions & 0 deletions source/tests/pt/test_permutation_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def test(
)


@unittest.skip("support of the denoise is temporally disabled")
class TestDenoiseModelDPA1(unittest.TestCase, PermutationDenoiseTest):
def setUp(self):
model_params = copy.deepcopy(model_dpa1)
Expand All @@ -74,6 +75,7 @@ def setUp(self):
self.model = get_model(model_params, sampled).to(env.DEVICE)


@unittest.skip("support of the denoise is temporally disabled")
class TestDenoiseModelDPA2(unittest.TestCase, PermutationDenoiseTest):
def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
Expand Down
2 changes: 2 additions & 0 deletions source/tests/pt/test_rot_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def test(
)


@unittest.skip("support of the denoise is temporally disabled")
class TestDenoiseModelDPA1(unittest.TestCase, RotDenoiseTest):
def setUp(self):
model_params = copy.deepcopy(model_dpa1)
Expand All @@ -105,6 +106,7 @@ def setUp(self):
self.model = get_model(model_params, sampled).to(env.DEVICE)


@unittest.skip("support of the denoise is temporally disabled")
class TestDenoiseModelDPA2(unittest.TestCase, RotDenoiseTest):
def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
Expand Down
Loading

0 comments on commit 1e51a88

Please sign in to comment.