Skip to content

Commit

Permalink
fix: correct exclude_types in descriptors (#3841)
Browse files Browse the repository at this point in the history
1. make `exclude_types` consistent with mask in nlist for all
descriptors. (bugs fixed in dpa1 and dpa2)
2. add universal tests for descriptor. (now only test_exclude_types)
3. `TestCaseSingleFrameWithNlist` in
source/tests/pt/model/test_env_mat.py will be removed in a seperate PR.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced compatibility parameter `ntypes` in several descriptor
classes for improved flexibility.
- Added new test case classes and utility functions for simulation
testing of atomic models.

- **Bug Fixes**
- Enhanced logic for exclusion masks and neighbor lists in various
descriptor methods to improve accuracy.

- **Tests**
- Added comprehensive tests for forward propagation and type exclusion
in atomic models, ensuring consistency and correctness.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd authored Jun 4, 2024
1 parent b5d9b77 commit f8bd3be
Show file tree
Hide file tree
Showing 16 changed files with 450 additions and 12 deletions.
5 changes: 3 additions & 2 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,10 @@ def call(
nf, nloc, nnei, _ = dmatrix.shape
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
# nfnl x nnei
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
# nfnl x nnei
nlist = nlist.reshape(nf * nloc, nnei)
nlist = np.where(exclude_mask, nlist, -1)
# nfnl x nnei x 4
dmatrix = dmatrix.reshape(nf * nloc, nnei, 4)
# nfnl x nnei x 1
Expand All @@ -824,8 +827,6 @@ def call(
nf * nloc, nnei, self.tebd_dim
)
ng = self.neuron[-1]
# nfnl x nnei
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
# nfnl x nnei x 4
rr = dmatrix.reshape(nf * nloc, nnei, 4)
rr = rr * exclude_mask[:, :, None]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def call(
mapping: Optional[np.ndarray] = None,
):
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
nlist = nlist * exclude_mask
nlist = np.where(exclude_mask, nlist, -1)
# nf x nloc x nnei x 4
dmatrix, diff, sw = self.env_mat.call(
coord_ext, atype_ext, nlist, self.mean, self.stddev
Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ class DescrptSeA(NativeOP, BaseDescriptor):
The precision of the embedding net parameters. Supported options are |PRECISION|
spin
The deepspin object.
ntypes : int
Number of element types.
Not used in this descriptor, only to be compat with input.
Limitations
-----------
Expand Down Expand Up @@ -150,9 +153,11 @@ def __init__(
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
spin: Optional[Any] = None,
ntypes: Optional[int] = None, # to be compat with input
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
del ntypes
## seed, uniform_seed, not included.
if spin is not None:
raise NotImplementedError("spin is not implemented")
Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class DescrptSeR(NativeOP, BaseDescriptor):
The precision of the embedding net parameters. Supported options are |PRECISION|
spin
The deepspin object.
ntypes : int
Number of element types.
Not used in this descriptor, only to be compat with input.
Limitations
-----------
Expand Down Expand Up @@ -107,9 +110,11 @@ def __init__(
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
spin: Optional[Any] = None,
ntypes: Optional[int] = None, # to be compat with input
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
del ntypes
## seed, uniform_seed, not included.
if not type_one_side:
raise NotImplementedError("type_one_side == False not implemented")
Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class DescrptSeT(NativeOP, BaseDescriptor):
If the weights of embedding net are trainable.
seed : int, Optional
Random seed for initializing the network parameters.
ntypes : int
Number of element types.
Not used in this descriptor, only to be compat with input.
"""

def __init__(
Expand All @@ -94,7 +97,9 @@ def __init__(
precision: str = DEFAULT_PRECISION,
trainable: bool = True,
seed: Optional[int] = None,
ntypes: Optional[int] = None, # to be compat with input
) -> None:
del ntypes
self.rcut = rcut
self.rcut_smth = rcut_smth
self.sel = sel
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def forward(
atype = extended_atype[:, :nloc]
# nb x nloc x nnei
exclude_mask = self.emask(nlist, extended_atype)
nlist = nlist * exclude_mask
nlist = torch.where(exclude_mask != 0, nlist, -1)
# nb x nloc x nnei x 4, nb x nloc x nnei x 3, nb x nloc x nnei x 1
dmatrix, diff, sw = prod_env_mat(
extended_coord,
Expand Down
9 changes: 6 additions & 3 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,12 @@ def forward(
self.rcut_smth,
protection=self.env_protection,
)
# nb x nloc x nnei
exclude_mask = self.emask(nlist, extended_atype)
nlist = torch.where(exclude_mask != 0, nlist, -1)
nlist_mask = nlist != -1
nlist = torch.where(nlist == -1, 0, nlist)
sw = torch.squeeze(sw, -1)
# beyond the cutoff sw should be 0.0
sw = sw.masked_fill(~nlist_mask, 0.0)
# nf x nloc x nt -> nf x nloc x nnei x nt
atype_tebd = extended_atype_embd[:, :nloc, :]
atype_tebd_nnei = atype_tebd.unsqueeze(2).expand(-1, -1, self.nnei, -1)
Expand All @@ -495,8 +496,10 @@ def forward(
atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index)
# nb x nloc x nnei x nt
atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt)
# beyond the cutoff sw should be 0.0
sw = sw.masked_fill(~nlist_mask, 0.0)
# (nb x nloc) x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nb * nloc, nnei)
exclude_mask = exclude_mask.view(nb * nloc, nnei)
if self.old_impl:
assert self.filter_layers_old is not None
dmatrix = dmatrix.view(
Expand Down
19 changes: 14 additions & 5 deletions deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,12 @@ def _pass_filter(
tf.shape(inputs_i)[0],
self.nei_type_vec, # extra input for atten
)
# (nframes * nloc * nnei, 1)
nei_exclude_mask = tf.slice(
tf.reshape(tf.cast(mask, self.filter_precision), [-1, 4]),
[0, 0],
[-1, 1],
)
if self.smooth:
inputs_i = tf.where(
tf.cast(mask, tf.bool),
Expand All @@ -727,15 +733,18 @@ def _pass_filter(
tf.reshape(self.avg_looked_up, [-1, 1]), [1, self.ndescrpt]
),
)
# (nframes, nloc, nnei)
self.recovered_switch *= tf.reshape(
tf.slice(
tf.reshape(tf.cast(mask, self.filter_precision), [-1, 4]),
[0, 0],
[-1, 1],
),
nei_exclude_mask,
[-1, natoms[0], self.sel_all_a[0]],
)
else:
# (nframes * nloc, 1, nnei)
self.nmask *= tf.reshape(
nei_exclude_mask,
[-1, 1, self.sel_all_a[0]],
)
self.negative_mask = -(2 << 32) * (1.0 - self.nmask)
inputs_i *= mask
if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor:
inputs_i = descrpt2r4(inputs_i, atype)
Expand Down
133 changes: 133 additions & 0 deletions source/tests/universal/common/cases/cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np


# originally copied from source/tests/pt/model/test_env_mat.py
class TestCaseSingleFrameWithNlist:
def setUp(self):
# nloc == 3, nall == 4
self.nloc = 3
self.nall = 4
self.nf, self.nt = 2, 2
self.coord_ext = np.array(
[
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, -2, 0],
],
dtype=np.float64,
).reshape([1, self.nall, 3])
self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall])
self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall])
# sel = [5, 2]
self.sel = [5, 2]
self.sel_mix = [7]
self.natoms = [3, 3, 2, 1]
self.nlist = np.array(
[
[1, 3, -1, -1, -1, 2, -1],
[0, -1, -1, -1, -1, 2, -1],
[0, 1, -1, -1, -1, -1, -1],
],
dtype=int,
).reshape([1, self.nloc, sum(self.sel)])
self.rcut = 2.2
self.rcut_smth = 0.4
# permutations
self.perm = np.array([2, 0, 1, 3], dtype=np.int32)
inv_perm = np.array([1, 2, 0, 3], dtype=np.int32)
# permute the coord and atype
self.coord_ext = np.concatenate(
[self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0
).reshape(self.nf, self.nall * 3)
self.atype_ext = np.concatenate(
[self.atype_ext, self.atype_ext[:, self.perm]], axis=0
)
self.mapping = np.concatenate(
[self.mapping, self.mapping[:, self.perm]], axis=0
)

# permute the nlist
nlist1 = self.nlist[:, self.perm[: self.nloc], :]
mask = nlist1 == -1
nlist1 = inv_perm[nlist1]
nlist1 = np.where(mask, -1, nlist1)
self.nlist = np.concatenate([self.nlist, nlist1], axis=0)
self.atol = 1e-12


class TestCaseSingleFrameWithNlistWithVirtual:
def setUp(self):
# nloc == 3, nall == 4
self.nloc = 4
self.nall = 5
self.nf, self.nt = 2, 2
self.coord_ext = np.array(
[
[0, 0, 0],
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, -2, 0],
],
dtype=np.float64,
).reshape([1, self.nall, 3])
self.atype_ext = np.array([0, -1, 0, 1, 0], dtype=int).reshape([1, self.nall])
# sel = [5, 2]
self.sel = [5, 2]
self.sel_mix = [7]
self.natoms = [3, 3, 2, 1]
self.nlist = np.array(
[
[2, 4, -1, -1, -1, 3, -1],
[-1, -1, -1, -1, -1, -1, -1],
[0, -1, -1, -1, -1, 3, -1],
[0, 2, -1, -1, -1, -1, -1],
],
dtype=int,
).reshape([1, self.nloc, sum(self.sel)])
self.rcut = 2.2
self.rcut_smth = 0.4
# permutations
self.perm = np.array([3, 0, 1, 2, 4], dtype=np.int32)
inv_perm = np.argsort(self.perm)
# permute the coord and atype
self.coord_ext = np.concatenate(
[self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0
).reshape(self.nf, self.nall * 3)
self.atype_ext = np.concatenate(
[self.atype_ext, self.atype_ext[:, self.perm]], axis=0
)
# permute the nlist
nlist1 = self.nlist[:, self.perm[: self.nloc], :]
mask = nlist1 == -1
nlist1 = inv_perm[nlist1]
nlist1 = np.where(mask, -1, nlist1)
self.nlist = np.concatenate([self.nlist, nlist1], axis=0)
self.get_real_mapping = np.array([[0, 2, 3], [0, 1, 3]], dtype=np.int32)
self.atol = 1e-12


class TestCaseSingleFrameWithoutNlist:
def setUp(self):
# nloc == 3, nall == 4
self.nloc = 3
self.nf, self.nt = 1, 2
self.coord = np.array(
[
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
],
dtype=np.float64,
).reshape([1, self.nloc * 3])
self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc])
self.cell = 2.0 * np.eye(3).reshape([1, 9])
# sel = [5, 2]
self.sel = [16, 8]
self.sel_mix = [24]
self.natoms = [3, 3, 2, 1]
self.rcut = 2.2
self.rcut_smth = 0.4
self.atol = 1e-12
1 change: 1 addition & 0 deletions source/tests/universal/common/cases/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
11 changes: 11 additions & 0 deletions source/tests/universal/common/cases/descriptor/descriptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from .utils import (
DescriptorTestCase,
)


class DescriptorTest(DescriptorTestCase):
def setUp(self) -> None:
DescriptorTestCase.setUp(self)
Loading

0 comments on commit f8bd3be

Please sign in to comment.