From 8f91aea627128baaa0d78bbd7c996a1a84ae02bd Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Fri, 16 Feb 2024 22:20:31 +0800 Subject: [PATCH] feat: dp and pt: implement exclude types in descriptor se_a (#3280) Co-authored-by: Han Wang --- deepmd/dpmodel/descriptor/exclude_mask.py | 78 ++++++++++++++++++ deepmd/dpmodel/descriptor/se_e2_a.py | 9 ++- deepmd/pt/model/descriptor/descriptor.py | 80 ++++++++++++++++++- deepmd/pt/model/descriptor/hybrid.py | 2 +- deepmd/pt/model/descriptor/repformers.py | 2 +- deepmd/pt/model/descriptor/se_a.py | 32 +++++--- deepmd/pt/model/descriptor/se_atten.py | 2 +- deepmd/tf/descriptor/se_a.py | 3 +- .../common/dpmodel/test_exclusion_mask.py | 34 ++++++++ .../consistent/descriptor/test_se_e2_a.py | 4 +- source/tests/pt/model/test_exclusion_mask.py | 45 +++++++++++ source/tests/pt/model/test_se_e2_a.py | 4 +- 12 files changed, 275 insertions(+), 20 deletions(-) create mode 100644 deepmd/dpmodel/descriptor/exclude_mask.py create mode 100644 source/tests/common/dpmodel/test_exclusion_mask.py create mode 100644 source/tests/pt/model/test_exclusion_mask.py diff --git a/deepmd/dpmodel/descriptor/exclude_mask.py b/deepmd/dpmodel/descriptor/exclude_mask.py new file mode 100644 index 0000000000..ee3edba434 --- /dev/null +++ b/deepmd/dpmodel/descriptor/exclude_mask.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, + Tuple, +) + +import numpy as np + + +class ExcludeMask: + """Computes the atom type exclusion mask.""" + + def __init__( + self, + ntypes: int, + exclude_types: List[Tuple[int, int]] = [], + ): + super().__init__() + self.ntypes = ntypes + self.exclude_types = set() + for tt in exclude_types: + assert len(tt) == 2 + self.exclude_types.add((tt[0], tt[1])) + self.exclude_types.add((tt[1], tt[0])) + # ntypes + 1 for nlist masks + self.type_mask = np.array( + [ + [ + 1 if (tt_i, tt_j) not in self.exclude_types else 0 + for tt_i in range(ntypes + 1) + ] + for tt_j in range(ntypes + 1) + ], + dtype=np.int32, + ) + # (ntypes+1 x ntypes+1) + self.type_mask = self.type_mask.reshape([-1]) + + def build_type_exclude_mask( + self, + nlist: np.ndarray, + atype_ext: np.ndarray, + ): + """Compute type exclusion mask. + + Parameters + ---------- + nlist + The neighbor list. shape: nf x nloc x nnei + atype_ext + The extended aotm types. shape: nf x nall + + Returns + ------- + mask + The type exclusion mask of shape: nf x nloc x nnei. + Element [ff,ii,jj] being 0 if type(ii), type(nlist[ff,ii,jj]) is excluded, + otherwise being 1. + + """ + if len(self.exclude_types) == 0: + # safely return 1 if nothing is excluded. + return np.ones_like(nlist, dtype=np.int32) + nf, nloc, nnei = nlist.shape + nall = atype_ext.shape[1] + # add virtual atom of type ntypes. 
nf x nall+1 + ae = np.concatenate( + [atype_ext, self.ntypes * np.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1 + ) + type_i = atype_ext[:, :nloc].reshape(nf, nloc) * (self.ntypes + 1) + # nf x nloc x nnei + index = np.where(nlist == -1, nall, nlist).reshape(nf, nloc * nnei) + type_j = np.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei) + type_ij = type_i[:, :, None] + type_j + # nf x (nloc x nnei) + type_ij = type_ij.reshape(nf, nloc * nnei) + mask = self.type_mask[type_ij].reshape(nf, nloc, nnei) + return mask diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 78ff83a056..26258f4ac7 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -27,6 +27,9 @@ from .base_descriptor import ( BaseDescriptor, ) +from .exclude_mask import ( + ExcludeMask, +) class DescrptSeA(NativeOP, BaseDescriptor): @@ -140,8 +143,6 @@ def __init__( ## seed, uniform_seed, multi_task, not included. if not type_one_side: raise NotImplementedError("type_one_side == False not implemented") - if exclude_types != []: - raise NotImplementedError("exclude_types is not implemented") if spin is not None: raise NotImplementedError("spin is not implemented") @@ -159,6 +160,7 @@ def __init__( self.activation_function = activation_function self.precision = precision self.spin = spin + self.emask = ExcludeMask(self.ntypes, self.exclude_types) in_dim = 1 # not considiering type embedding self.embeddings = NetworkCollection( @@ -292,8 +294,11 @@ def call( ng = self.neuron[-1] gr = np.zeros([nf, nloc, ng, 4]) + exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) for tt in range(self.ntypes): + mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]] tr = rr[:, :, sec[tt] : sec[tt + 1], :] + tr = tr * mm[:, :, :, None] ss = tr[..., 0:1] gg = self.cal_g(ss, tt) # nf x nloc x ng x 4 diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 177f30d241..63dbe0eb19 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -8,6 +8,8 @@ Callable, List, Optional, + Set, + Tuple, Union, ) @@ -20,6 +22,9 @@ from deepmd.pt.utils.plugin import ( Plugin, ) +from deepmd.pt.utils.utils import ( + to_torch_tensor, +) from .base_descriptor import ( BaseDescriptor, @@ -206,6 +211,32 @@ class DescriptorBlock(torch.nn.Module, ABC): __plugins = Plugin() local_cluster = False + def __init__( + self, + ntypes: int, + exclude_types: List[Tuple[int, int]] = [], + ): + super().__init__() + _exclude_types: Set[Tuple[int, int]] = set() + for tt in exclude_types: + assert len(tt) == 2 + _exclude_types.add((tt[0], tt[1])) + _exclude_types.add((tt[1], tt[0])) + # ntypes + 1 for nlist masks + self.type_mask = np.array( + [ + [ + 1 if (tt_i, tt_j) not in _exclude_types else 0 + for tt_i in range(ntypes + 1) + ] + for tt_j in range(ntypes + 1) + ], + dtype=np.int32, + ) + # (ntypes+1 x ntypes+1) + self.type_mask = to_torch_tensor(self.type_mask).view([-1]) + self.no_exclusion = len(_exclude_types) == 0 + @staticmethod def register(key: str) -> Callable: """Register a DescriptorBlock plugin. @@ -332,7 +363,54 @@ def forward( mapping: Optional[torch.Tensor] = None, ): """Calculate DescriptorBlock.""" - raise NotImplementedError + pass + + # may have a better place for this method... + def build_type_exclude_mask( + self, + nlist: torch.Tensor, + atype_ext: torch.Tensor, + ) -> torch.Tensor: + """Compute type exclusion mask. + + Parameters + ---------- + nlist + The neighbor list. 
shape: nf x nloc x nnei + atype_ext + The extended aotm types. shape: nf x nall + + Returns + ------- + mask + The type exclusion mask of shape: nf x nloc x nnei. + Element [ff,ii,jj] being 0 if type(ii), type(nlist[ff,ii,jj]) is excluded, + otherwise being 1. + + """ + if self.no_exclusion: + # safely return 1 if nothing is excluded. + return torch.ones_like(nlist, dtype=torch.int32, device=nlist.device) + nf, nloc, nnei = nlist.shape + nall = atype_ext.shape[1] + # add virtual atom of type ntypes. nf x nall+1 + ae = torch.cat( + [ + atype_ext, + self.get_ntypes() + * torch.ones([nf, 1], dtype=atype_ext.dtype, device=atype_ext.device), + ], + dim=-1, + ) + type_i = atype_ext[:, :nloc].view(nf, nloc) * (self.get_ntypes() + 1) + # nf x nloc x nnei + index = torch.where(nlist == -1, nall, nlist).view(nf, nloc * nnei) + type_j = torch.gather(ae, 1, index).view(nf, nloc, nnei) + type_ij = type_i[:, :, None] + type_j + # nf x (nloc x nnei) + type_ij = type_ij.view(nf, nloc * nnei) + mask = self.type_mask[type_ij].view(nf, nloc, nnei) + return mask def compute_std(sumv2, sumv, sumn, rcut_r): diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index c5c08c760d..511ac5e79b 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -32,7 +32,7 @@ def __init__( - descriptor_list: list of descriptors. - descriptor_param: descriptor configs. """ - super().__init__() + super().__init__(ntypes) supported_descrpt = ["se_atten", "se_uni"] descriptor_list = [] for descriptor_param_item in list: diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 26467124b8..de2b5f3565 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -89,7 +89,7 @@ def __init__( whether or not add an type embedding to seq_input. If no seq_input is given, it has no effect. """ - super().__init__() + super().__init__(ntypes) del type self.epsilon = 1e-4 # protection of 1./nnei self.rcut = rcut diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index c722c2dc02..da391d3255 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -3,6 +3,7 @@ ClassVar, List, Optional, + Tuple, ) import numpy as np @@ -55,6 +56,7 @@ def __init__( activation_function: str = "tanh", precision: str = "float64", resnet_dt: bool = False, + exclude_types: List[Tuple[int, int]] = [], old_impl: bool = False, **kwargs, ): @@ -63,13 +65,14 @@ def __init__( rcut, rcut_smth, sel, - neuron, - axis_neuron, - set_davg_zero, - activation_function, - precision, - resnet_dt, - old_impl, + neuron=neuron, + axis_neuron=axis_neuron, + set_davg_zero=set_davg_zero, + activation_function=activation_function, + precision=precision, + resnet_dt=resnet_dt, + exclude_types=exclude_types, + old_impl=old_impl, **kwargs, ) @@ -212,6 +215,7 @@ def serialize(self) -> dict: "precision": RESERVED_PRECISON_DICT[obj.prec], "embeddings": obj.filter_layers.serialize(), "env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(), + "exclude_types": obj.exclude_types, "@variables": { "davg": obj["davg"].detach().cpu().numpy(), "dstd": obj["dstd"].detach().cpu().numpy(), @@ -219,7 +223,6 @@ def serialize(self) -> dict: ## to be updated when the options are supported. 
"trainable": True, "type_one_side": True, - "exclude_types": [], "spin": None, } @@ -256,6 +259,7 @@ def __init__( activation_function: str = "tanh", precision: str = "float64", resnet_dt: bool = False, + exclude_types: List[Tuple[int, int]] = [], old_impl: bool = False, **kwargs, ): @@ -268,7 +272,7 @@ def __init__( - filter_neuron: Number of neurons in each hidden layers of the embedding net. - axis_neuron: Number of columns of the sub-matrix of the embedding matrix. """ - super().__init__() + super().__init__(len(sel), exclude_types=exclude_types) self.rcut = rcut self.rcut_smth = rcut_smth self.neuron = neuron @@ -280,8 +284,9 @@ def __init__( self.prec = PRECISION_DICT[self.precision] self.resnet_dt = resnet_dt self.old_impl = old_impl - + self.exclude_types = exclude_types self.ntypes = len(sel) + self.sel = sel self.sec = torch.tensor( np.append([0], np.cumsum(self.sel)), dtype=int, device=env.DEVICE @@ -522,9 +527,16 @@ def forward( xyz_scatter = torch.zeros( [nfnl, 4, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE ) + # nfnl x nnei + exclude_mask = self.build_type_exclude_mask(nlist, extended_atype).view( + nfnl, -1 + ) for ii, ll in enumerate(self.filter_layers.networks): + # nfnl x nt + mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] # nfnl x nt x 4 rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] + rr = rr * mm[:, :, None] ss = rr[:, :, :1] # nfnl x nt x ng gg = ll.forward(ss) diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index d4dc0cd054..e1c9942d92 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -64,7 +64,7 @@ def __init__( - filter_neuron: Number of neurons in each hidden layers of the embedding net. - axis_neuron: Number of columns of the sub-matrix of the embedding matrix. 
""" - super().__init__() + super().__init__(ntypes) del type self.rcut = rcut self.rcut_smth = rcut_smth diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py index 986328479b..0e0cb664a4 100644 --- a/deepmd/tf/descriptor/se_a.py +++ b/deepmd/tf/descriptor/se_a.py @@ -201,6 +201,7 @@ def __init__( self.activation_function_name = activation_function self.filter_precision = get_precision(precision) self.filter_np_precision = get_np_precision(precision) + self.orig_exclude_types = exclude_types self.exclude_types = set() for tt in exclude_types: assert len(tt) == 2 @@ -1425,7 +1426,7 @@ def serialize(self, suffix: str = "") -> dict: "resnet_dt": self.filter_resnet_dt, "trainable": self.trainable, "type_one_side": self.type_one_side, - "exclude_types": list(self.exclude_types), + "exclude_types": list(self.orig_exclude_types), "set_davg_zero": self.set_davg_zero, "activation_function": self.activation_function_name, "precision": self.filter_precision.name, diff --git a/source/tests/common/dpmodel/test_exclusion_mask.py b/source/tests/common/dpmodel/test_exclusion_mask.py new file mode 100644 index 0000000000..dc59c57776 --- /dev/null +++ b/source/tests/common/dpmodel/test_exclusion_mask.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.dpmodel.descriptor.exclude_mask import ( + ExcludeMask, +) + +from .case_single_frame_with_nlist import ( + TestCaseSingleFrameWithNlist, +) + + +# to be merged with the tf test case +class TestExcludeMask(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self): + TestCaseSingleFrameWithNlist.setUp(self) + + def test_build_type_exclude_mask(self): + exclude_types = [[0, 1]] + expected_mask = np.array( + [ + [1, 1, 1, 1, 1, 0, 1], + [1, 1, 1, 1, 1, 0, 1], + [0, 0, 1, 1, 1, 1, 1], + ] + ).reshape(self.nf, self.nloc, sum(self.sel)) + des = ExcludeMask(self.nt, exclude_types=exclude_types) + mask = des.build_type_exclude_mask( + self.nlist, + self.atype_ext, + ) + np.testing.assert_equal(mask, expected_mask) diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index a694a2a20c..a1f829aeea 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -67,7 +67,7 @@ def skip_pt(self) -> bool: type_one_side, excluded_types, ) = self.param - return not type_one_side or excluded_types != [] or CommonTest.skip_pt + return not type_one_side or CommonTest.skip_pt @property def skip_dp(self) -> bool: @@ -76,7 +76,7 @@ def skip_dp(self) -> bool: type_one_side, excluded_types, ) = self.param - return not type_one_side or excluded_types != [] or CommonTest.skip_dp + return not type_one_side or CommonTest.skip_dp tf_class = DescrptSeATF dp_class = DescrptSeADP diff --git a/source/tests/pt/model/test_exclusion_mask.py b/source/tests/pt/model/test_exclusion_mask.py new file mode 100644 index 0000000000..d624a8c178 --- /dev/null +++ b/source/tests/pt/model/test_exclusion_mask.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.pt.model.descriptor.se_a import ( + DescrptBlockSeA, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +from .test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +# to be merged with the tf test case +class TestExcludeMask(unittest.TestCase, 
TestCaseSingleFrameWithNlist):
+    def setUp(self):
+        TestCaseSingleFrameWithNlist.setUp(self)
+
+    def test_build_type_exclude_mask(self):
+        exclude_types = [[0, 1]]
+        expected_mask = np.array(
+            [
+                [1, 1, 1, 1, 1, 0, 1],
+                [1, 1, 1, 1, 1, 0, 1],
+                [0, 0, 1, 1, 1, 1, 1],
+            ]
+        ).reshape(self.nf, self.nloc, sum(self.sel))
+        des = DescrptBlockSeA(
+            self.rcut, self.rcut_smth, self.sel, exclude_types=exclude_types
+        ).to(env.DEVICE)
+        mask = des.build_type_exclude_mask(
+            to_torch_tensor(self.nlist),
+            to_torch_tensor(self.atype_ext),
+        )
+        np.testing.assert_equal(to_numpy_array(mask), expected_mask)
diff --git a/source/tests/pt/model/test_se_e2_a.py b/source/tests/pt/model/test_se_e2_a.py
index 520fdfcdfa..bb15bb423d 100644
--- a/source/tests/pt/model/test_se_e2_a.py
+++ b/source/tests/pt/model/test_se_e2_a.py
@@ -40,9 +40,10 @@ def test_consistency(
         dstd = rng.normal(size=(self.nt, nnei, 4))
         dstd = 0.1 + np.abs(dstd)

-        for idt, prec in itertools.product(
+        for idt, prec, em in itertools.product(
             [False, True],
             ["float64", "float32"],
+            [[], [[0, 1]], [[1, 1]]],
         ):
             dtype = PRECISION_DICT[prec]
             rtol, atol = get_tols(prec)
@@ -55,6 +56,7 @@ def test_consistency(
                 precision=prec,
                 resnet_dt=idt,
                 old_impl=False,
+                exclude_types=em,
             ).to(env.DEVICE)
             dd0.sea.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
             dd0.sea.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
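
Usage sketch (illustrative only, not part of the patch): the snippet below exercises the new ExcludeMask helper added in deepmd/dpmodel/descriptor/exclude_mask.py on a made-up toy frame; the nlist/atype_ext values are invented for the example. Inside the descriptors, the returned mask is multiplied into the sliced environment matrix, so excluded (local type, neighbor type) pairs contribute zero to the embedding.

import numpy as np

from deepmd.dpmodel.descriptor.exclude_mask import ExcludeMask

# hypothetical toy system: 1 frame, 3 local atoms, 4 extended atoms, 3 neighbors
ntypes = 2
emask = ExcludeMask(ntypes, exclude_types=[(0, 1)])  # drop type-0 / type-1 pairs

atype_ext = np.array([[0, 1, 1, 0]])  # nf x nall
nlist = np.array(
    [
        [
            [1, 3, -1],  # neighbors of local atom 0; -1 marks padding
            [0, 2, -1],
            [0, 1, 3],
        ]
    ]
)  # nf x nloc x nnei

mask = emask.build_type_exclude_mask(nlist, atype_ext)
print(mask)
# [[[0 1 1]
#   [0 1 1]
#   [0 1 0]]]
# 0 flags an excluded (type_i, type_j) pair.

Note that padded neighbors (index -1) are redirected to a virtual atom of type ntypes, which is never in the exclusion table, so padding never alters the mask.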