From f2b84ff89b98b2fd1f5546d5ba1bf22a7e270713 Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Sat, 17 Feb 2024 22:30:47 +0800 Subject: [PATCH] feat: dp and pt: implement fitting exclude types (#3282) - implement fitting exclude types - pt: refactorize the pair exclusion masks as torch modules. --------- Co-authored-by: Han Wang --- deepmd/dpmodel/descriptor/se_e2_a.py | 6 +- deepmd/dpmodel/fitting/invar_fitting.py | 11 ++ deepmd/dpmodel/utils/__init__.py | 6 + .../{descriptor => utils}/exclude_mask.py | 49 ++++++- deepmd/pt/model/descriptor/descriptor.py | 78 ----------- deepmd/pt/model/descriptor/hybrid.py | 2 +- deepmd/pt/model/descriptor/repformers.py | 2 +- deepmd/pt/model/descriptor/se_a.py | 10 +- deepmd/pt/model/descriptor/se_atten.py | 2 +- deepmd/pt/model/task/ener.py | 2 + deepmd/pt/model/task/fitting.py | 16 ++- deepmd/pt/utils/__init__.py | 10 ++ deepmd/pt/utils/exclude_mask.py | 131 ++++++++++++++++++ .../common/dpmodel/test_exclusion_mask.py | 32 ++++- .../dpmodel/test_fitting_invar_fitting.py | 28 ++++ source/tests/pt/model/test_ener_fitting.py | 8 +- source/tests/pt/model/test_exclusion_mask.py | 38 +++-- 17 files changed, 320 insertions(+), 111 deletions(-) rename deepmd/dpmodel/{descriptor => utils}/exclude_mask.py (64%) create mode 100644 deepmd/pt/utils/exclude_mask.py diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 26258f4ac7..4e26afa729 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -22,14 +22,12 @@ EmbeddingNet, EnvMat, NetworkCollection, + PairExcludeMask, ) from .base_descriptor import ( BaseDescriptor, ) -from .exclude_mask import ( - ExcludeMask, -) class DescrptSeA(NativeOP, BaseDescriptor): @@ -160,7 +158,7 @@ def __init__( self.activation_function = activation_function self.precision = precision self.spin = spin - self.emask = ExcludeMask(self.ntypes, self.exclude_types) + self.emask = PairExcludeMask(self.ntypes, self.exclude_types) in_dim = 1 # not considiering type embedding self.embeddings = NetworkCollection( diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py index 58607a9f26..0c2a6006cc 100644 --- a/deepmd/dpmodel/fitting/invar_fitting.py +++ b/deepmd/dpmodel/fitting/invar_fitting.py @@ -19,6 +19,7 @@ fitting_check_output, ) from deepmd.dpmodel.utils import ( + AtomExcludeMask, FittingNet, NetworkCollection, ) @@ -126,6 +127,7 @@ def __init__( use_aparam_as_mask: bool = False, spin: Any = None, distinguish_types: bool = False, + exclude_types: List[int] = [], ): # seed, uniform_seed are not included if tot_ener_zero: @@ -159,8 +161,10 @@ def __init__( self.use_aparam_as_mask = use_aparam_as_mask self.spin = spin self.distinguish_types = distinguish_types + self.exclude_types = exclude_types if self.spin is not None: raise NotImplementedError("spin is not supported") + self.emask = AtomExcludeMask(self.ntypes, exclude_types=self.exclude_types) # init constants self.bias_atom_e = np.zeros([self.ntypes, self.dim_out]) @@ -260,6 +264,7 @@ def serialize(self) -> dict: "precision": self.precision, "distinguish_types": self.distinguish_types, "nets": self.nets.serialize(), + "exclude_types": self.exclude_types, "@variables": { "bias_atom_e": self.bias_atom_e, "fparam_avg": self.fparam_avg, @@ -370,4 +375,10 @@ def call( outs = outs + atom_energy # Shape is [nframes, natoms[0], 1] else: outs = self.nets[()](xx) + self.bias_atom_e[atype] + + # nf x nloc + exclude_mask = self.emask.build_type_exclude_mask(atype) + # nf x nloc x nod + outs = outs * exclude_mask[:, :, None] + return {self.var_name: outs} diff --git a/deepmd/dpmodel/utils/__init__.py b/deepmd/dpmodel/utils/__init__.py index d3c31ae246..60a4486d52 100644 --- a/deepmd/dpmodel/utils/__init__.py +++ b/deepmd/dpmodel/utils/__init__.py @@ -2,6 +2,10 @@ from .env_mat import ( EnvMat, ) +from .exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) from .network import ( EmbeddingNet, FittingNet, @@ -53,4 +57,6 @@ "inter2phys", "phys2inter", "to_face_distance", + "AtomExcludeMask", + "PairExcludeMask", ] diff --git a/deepmd/dpmodel/descriptor/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py similarity index 64% rename from deepmd/dpmodel/descriptor/exclude_mask.py rename to deepmd/dpmodel/utils/exclude_mask.py index ee3edba434..83e3c7a363 100644 --- a/deepmd/dpmodel/descriptor/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -7,15 +7,54 @@ import numpy as np -class ExcludeMask: - """Computes the atom type exclusion mask.""" +class AtomExcludeMask: + """Computes the type exclusion mask for atoms.""" + + def __init__( + self, + ntypes: int, + exclude_types: List[int] = [], + ): + self.ntypes = ntypes + self.exclude_types = exclude_types + self.type_mask = np.array( + [1 if tt_i not in self.exclude_types else 0 for tt_i in range(ntypes)], + dtype=np.int32, + ) + # (ntypes) + self.type_mask = self.type_mask.reshape([-1]) + + def build_type_exclude_mask( + self, + atype: np.ndarray, + ): + """Compute type exclusion mask for atoms. + + Parameters + ---------- + atype + The extended aotm types. shape: nf x natom + + Returns + ------- + mask + The type exclusion mask for atoms. shape: nf x natom + Element [ff,ii] being 0 if type(ii) is excluded, + otherwise being 1. + + """ + nf, natom = atype.shape + return self.type_mask[atype].reshape(nf, natom) + + +class PairExcludeMask: + """Computes the type exclusion mask for atom pairs.""" def __init__( self, ntypes: int, exclude_types: List[Tuple[int, int]] = [], ): - super().__init__() self.ntypes = ntypes self.exclude_types = set() for tt in exclude_types: @@ -41,7 +80,7 @@ def build_type_exclude_mask( nlist: np.ndarray, atype_ext: np.ndarray, ): - """Compute type exclusion mask. + """Compute type exclusion mask for atom pairs. Parameters ---------- @@ -53,7 +92,7 @@ def build_type_exclude_mask( Returns ------- mask - The type exclusion mask of shape: nf x nloc x nnei. + The type exclusion mask for pair atoms of shape: nf x nloc x nnei. Element [ff,ii,jj] being 0 if type(ii), type(nlist[ff,ii,jj]) is excluded, otherwise being 1. diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 63dbe0eb19..bd6839834e 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -8,8 +8,6 @@ Callable, List, Optional, - Set, - Tuple, Union, ) @@ -22,9 +20,6 @@ from deepmd.pt.utils.plugin import ( Plugin, ) -from deepmd.pt.utils.utils import ( - to_torch_tensor, -) from .base_descriptor import ( BaseDescriptor, @@ -211,32 +206,6 @@ class DescriptorBlock(torch.nn.Module, ABC): __plugins = Plugin() local_cluster = False - def __init__( - self, - ntypes: int, - exclude_types: List[Tuple[int, int]] = [], - ): - super().__init__() - _exclude_types: Set[Tuple[int, int]] = set() - for tt in exclude_types: - assert len(tt) == 2 - _exclude_types.add((tt[0], tt[1])) - _exclude_types.add((tt[1], tt[0])) - # ntypes + 1 for nlist masks - self.type_mask = np.array( - [ - [ - 1 if (tt_i, tt_j) not in _exclude_types else 0 - for tt_i in range(ntypes + 1) - ] - for tt_j in range(ntypes + 1) - ], - dtype=np.int32, - ) - # (ntypes+1 x ntypes+1) - self.type_mask = to_torch_tensor(self.type_mask).view([-1]) - self.no_exclusion = len(_exclude_types) == 0 - @staticmethod def register(key: str) -> Callable: """Register a DescriptorBlock plugin. @@ -365,53 +334,6 @@ def forward( """Calculate DescriptorBlock.""" pass - # may have a better place for this method... - def build_type_exclude_mask( - self, - nlist: torch.Tensor, - atype_ext: torch.Tensor, - ) -> torch.Tensor: - """Compute type exclusion mask. - - Parameters - ---------- - nlist - The neighbor list. shape: nf x nloc x nnei - atype_ext - The extended aotm types. shape: nf x nall - - Returns - ------- - mask - The type exclusion mask of shape: nf x nloc x nnei. - Element [ff,ii,jj] being 0 if type(ii), type(nlist[ff,ii,jj]) is excluded, - otherwise being 1. - - """ - if self.no_exclusion: - # safely return 1 if nothing is excluded. - return torch.ones_like(nlist, dtype=torch.int32, device=nlist.device) - nf, nloc, nnei = nlist.shape - nall = atype_ext.shape[1] - # add virtual atom of type ntypes. nf x nall+1 - ae = torch.cat( - [ - atype_ext, - self.get_ntypes() - * torch.ones([nf, 1], dtype=atype_ext.dtype, device=atype_ext.device), - ], - dim=-1, - ) - type_i = atype_ext[:, :nloc].view(nf, nloc) * (self.get_ntypes() + 1) - # nf x nloc x nnei - index = torch.where(nlist == -1, nall, nlist).view(nf, nloc * nnei) - type_j = torch.gather(ae, 1, index).view(nf, nloc, nnei) - type_ij = type_i[:, :, None] + type_j - # nf x (nloc x nnei) - type_ij = type_ij.view(nf, nloc * nnei) - mask = self.type_mask[type_ij].view(nf, nloc, nnei) - return mask - def compute_std(sumv2, sumv, sumn, rcut_r): """Compute standard deviation.""" diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index 511ac5e79b..c5c08c760d 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -32,7 +32,7 @@ def __init__( - descriptor_list: list of descriptors. - descriptor_param: descriptor configs. """ - super().__init__(ntypes) + super().__init__() supported_descrpt = ["se_atten", "se_uni"] descriptor_list = [] for descriptor_param_item in list: diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index de2b5f3565..26467124b8 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -89,7 +89,7 @@ def __init__( whether or not add an type embedding to seq_input. If no seq_input is given, it has no effect. """ - super().__init__(ntypes) + super().__init__() del type self.epsilon = 1e-4 # protection of 1./nnei self.rcut = rcut diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index da391d3255..c086fe1cc2 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -38,6 +38,9 @@ from deepmd.pt.model.network.network import ( TypeFilter, ) +from deepmd.pt.utils.exclude_mask import ( + PairExcludeMask, +) from deepmd.pt.utils.nlist import ( extend_input_and_build_neighbor_list, ) @@ -272,7 +275,7 @@ def __init__( - filter_neuron: Number of neurons in each hidden layers of the embedding net. - axis_neuron: Number of columns of the sub-matrix of the embedding matrix. """ - super().__init__(len(sel), exclude_types=exclude_types) + super().__init__() self.rcut = rcut self.rcut_smth = rcut_smth self.neuron = neuron @@ -286,6 +289,7 @@ def __init__( self.old_impl = old_impl self.exclude_types = exclude_types self.ntypes = len(sel) + self.emask = PairExcludeMask(len(sel), exclude_types=exclude_types) self.sel = sel self.sec = torch.tensor( @@ -528,9 +532,7 @@ def forward( [nfnl, 4, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE ) # nfnl x nnei - exclude_mask = self.build_type_exclude_mask(nlist, extended_atype).view( - nfnl, -1 - ) + exclude_mask = self.emask(nlist, extended_atype).view(nfnl, -1) for ii, ll in enumerate(self.filter_layers.networks): # nfnl x nt mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index e1c9942d92..d4dc0cd054 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -64,7 +64,7 @@ def __init__( - filter_neuron: Number of neurons in each hidden layers of the embedding net. - axis_neuron: Number of columns of the sub-matrix of the embedding matrix. """ - super().__init__(ntypes) + super().__init__() del type self.rcut = rcut self.rcut_smth = rcut_smth diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index f1dad4c58d..b6ca12b9d8 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -89,6 +89,7 @@ def __init__( distinguish_types: bool = False, rcond: Optional[float] = None, seed: Optional[int] = None, + exclude_types: List[int] = [], **kwargs, ): super().__init__( @@ -106,6 +107,7 @@ def __init__( distinguish_types=distinguish_types, rcond=rcond, seed=seed, + exclude_types=exclude_types, **kwargs, ) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index b2d8c875ce..db8daff802 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -39,6 +39,9 @@ DEVICE, PRECISION_DICT, ) +from deepmd.pt.utils.exclude_mask import ( + AtomExcludeMask, +) from deepmd.pt.utils.plugin import ( Plugin, ) @@ -396,6 +399,7 @@ def __init__( distinguish_types: bool = False, rcond: Optional[float] = None, seed: Optional[int] = None, + exclude_types: List[int] = [], **kwargs, ): super().__init__() @@ -413,6 +417,9 @@ def __init__( self.precision = precision self.prec = PRECISION_DICT[self.precision] self.rcond = rcond + self.exclude_types = exclude_types + + self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) # init constants if bias_atom_e is None: @@ -504,6 +511,7 @@ def serialize(self) -> dict: "distinguish_types": self.distinguish_types, "nets": self.filter_layers.serialize(), "rcond": self.rcond, + "exclude_types": self.exclude_types, "@variables": { "bias_atom_e": to_numpy_array(self.bias_atom_e), "fparam_avg": to_numpy_array(self.fparam_avg), @@ -511,7 +519,6 @@ def serialize(self) -> dict: "aparam_avg": to_numpy_array(self.aparam_avg), "aparam_inv_std": to_numpy_array(self.aparam_inv_std), }, - # "rcond": self.rcond , # "tot_ener_zero": self.tot_ener_zero , # "trainable": self.trainable , # "atom_ener": self.atom_ener , @@ -657,7 +664,6 @@ def _forward_common( atom_property = atom_property + self.bias_atom_e[type_i] atom_property = atom_property * mask.unsqueeze(-1) outs = outs + atom_property # Shape is [nframes, natoms[0], 1] - return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} else: if self.use_tebd: atom_property = ( @@ -673,4 +679,8 @@ def _forward_common( atom_property = atom_property + self.bias_atom_e[type_i] atom_property = atom_property * mask outs = outs + atom_property # Shape is [nframes, natoms[0], 1] - return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} + # nf x nloc + mask = self.emask(atype) + # nf x nloc x nod + outs = outs * mask[:, :, None] + return {self.var_name: outs.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/deepmd/pt/utils/__init__.py b/deepmd/pt/utils/__init__.py index 6ceb116d85..7e1043eda4 100644 --- a/deepmd/pt/utils/__init__.py +++ b/deepmd/pt/utils/__init__.py @@ -1 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later + +from .exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) + +__all__ = [ + "PairExcludeMask", + "AtomExcludeMask", +] diff --git a/deepmd/pt/utils/exclude_mask.py b/deepmd/pt/utils/exclude_mask.py new file mode 100644 index 0000000000..74b1d8dc41 --- /dev/null +++ b/deepmd/pt/utils/exclude_mask.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + List, + Set, + Tuple, +) + +import numpy as np +import torch + +from deepmd.pt.utils.utils import ( + to_torch_tensor, +) + + +class AtomExcludeMask(torch.nn.Module): + """Computes the type exclusion mask for atoms.""" + + def __init__( + self, + ntypes: int, + exclude_types: List[int] = [], + ): + super().__init__() + self.ntypes = ntypes + self.exclude_types = exclude_types + self.type_mask = np.array( + [1 if tt_i not in self.exclude_types else 0 for tt_i in range(ntypes)], + dtype=np.int32, + ) + self.type_mask = to_torch_tensor(self.type_mask).view([-1]) + + def forward( + self, + atype: torch.Tensor, + ) -> torch.Tensor: + """Compute type exclusion mask for atoms. + + Parameters + ---------- + atype + The extended aotm types. shape: nf x natom + + Returns + ------- + mask + The type exclusion mask for atoms. shape: nf x natom + Element [ff,ii] being 0 if type(ii) is excluded, + otherwise being 1. + + """ + nf, natom = atype.shape + return self.type_mask[atype].view(nf, natom) + + +class PairExcludeMask(torch.nn.Module): + """Computes the type exclusion mask for atom pairs.""" + + def __init__( + self, + ntypes: int, + exclude_types: List[Tuple[int, int]] = [], + ): + super().__init__() + self.ntypes = ntypes + self._exclude_types: Set[Tuple[int, int]] = set() + for tt in exclude_types: + assert len(tt) == 2 + self._exclude_types.add((tt[0], tt[1])) + self._exclude_types.add((tt[1], tt[0])) + # ntypes + 1 for nlist masks + self.type_mask = np.array( + [ + [ + 1 if (tt_i, tt_j) not in self._exclude_types else 0 + for tt_i in range(ntypes + 1) + ] + for tt_j in range(ntypes + 1) + ], + dtype=np.int32, + ) + # (ntypes+1 x ntypes+1) + self.type_mask = to_torch_tensor(self.type_mask).view([-1]) + self.no_exclusion = len(self._exclude_types) == 0 + + # may have a better place for this method... + def forward( + self, + nlist: torch.Tensor, + atype_ext: torch.Tensor, + ) -> torch.Tensor: + """Compute type exclusion mask. + + Parameters + ---------- + nlist + The neighbor list. shape: nf x nloc x nnei + atype_ext + The extended aotm types. shape: nf x nall + + Returns + ------- + mask + The type exclusion mask of shape: nf x nloc x nnei. + Element [ff,ii,jj] being 0 if type(ii), type(nlist[ff,ii,jj]) is excluded, + otherwise being 1. + + """ + if self.no_exclusion: + # safely return 1 if nothing is excluded. + return torch.ones_like(nlist, dtype=torch.int32, device=nlist.device) + nf, nloc, nnei = nlist.shape + nall = atype_ext.shape[1] + # add virtual atom of type ntypes. nf x nall+1 + ae = torch.cat( + [ + atype_ext, + self.ntypes + * torch.ones([nf, 1], dtype=atype_ext.dtype, device=atype_ext.device), + ], + dim=-1, + ) + type_i = atype_ext[:, :nloc].view(nf, nloc) * (self.ntypes + 1) + # nf x nloc x nnei + index = torch.where(nlist == -1, nall, nlist).view(nf, nloc * nnei) + type_j = torch.gather(ae, 1, index).view(nf, nloc, nnei) + type_ij = type_i[:, :, None] + type_j + # nf x (nloc x nnei) + type_ij = type_ij.view(nf, nloc * nnei) + mask = self.type_mask[type_ij].view(nf, nloc, nnei) + return mask diff --git a/source/tests/common/dpmodel/test_exclusion_mask.py b/source/tests/common/dpmodel/test_exclusion_mask.py index dc59c57776..89727ec6c3 100644 --- a/source/tests/common/dpmodel/test_exclusion_mask.py +++ b/source/tests/common/dpmodel/test_exclusion_mask.py @@ -3,8 +3,9 @@ import numpy as np -from deepmd.dpmodel.descriptor.exclude_mask import ( - ExcludeMask, +from deepmd.dpmodel.utils.exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, ) from .case_single_frame_with_nlist import ( @@ -12,8 +13,31 @@ ) +class TestAtomExcludeMask(unittest.TestCase): + def test_build_type_exclude_mask(self): + nf = 2 + nt = 3 + exclude_types = [0, 2] + atype = np.array( + [ + [0, 2, 1, 2, 0, 1, 0], + [1, 2, 0, 0, 2, 2, 1], + ], + dtype=np.int32, + ).reshape([nf, -1]) + expected_mask = np.array( + [ + [0, 0, 1, 0, 0, 1, 0], + [1, 0, 0, 0, 0, 0, 1], + ] + ).reshape([nf, -1]) + des = AtomExcludeMask(nt, exclude_types=exclude_types) + mask = des.build_type_exclude_mask(atype) + np.testing.assert_equal(mask, expected_mask) + + # to be merged with the tf test case -class TestExcludeMask(unittest.TestCase, TestCaseSingleFrameWithNlist): +class TestPairExcludeMask(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self): TestCaseSingleFrameWithNlist.setUp(self) @@ -26,7 +50,7 @@ def test_build_type_exclude_mask(self): [0, 0, 1, 1, 1, 1, 1], ] ).reshape(self.nf, self.nloc, sum(self.sel)) - des = ExcludeMask(self.nt, exclude_types=exclude_types) + des = PairExcludeMask(self.nt, exclude_types=exclude_types) mask = des.build_type_exclude_mask( self.nlist, self.atype_ext, diff --git a/source/tests/common/dpmodel/test_fitting_invar_fitting.py b/source/tests/common/dpmodel/test_fitting_invar_fitting.py index ea70e98f7c..77d3b429ec 100644 --- a/source/tests/common/dpmodel/test_fitting_invar_fitting.py +++ b/source/tests/common/dpmodel/test_fitting_invar_fitting.py @@ -34,11 +34,13 @@ def test_self_consistency( od, nfp, nap, + et, ) in itertools.product( [True, False], [1, 2], [0, 3], [0, 4], + [[], [0], [1]], ): ifn0 = InvarFitting( "energy", @@ -48,6 +50,7 @@ def test_self_consistency( numb_fparam=nfp, numb_aparam=nap, distinguish_types=distinguish_types, + exclude_types=et, ) ifn1 = InvarFitting.deserialize(ifn0.serialize()) if nfp > 0: @@ -62,6 +65,31 @@ def test_self_consistency( ret1 = ifn1(dd[0], atype, fparam=ifp, aparam=iap) np.testing.assert_allclose(ret0["energy"], ret1["energy"]) + def test_mask(self): + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + od = 2 + distinguish_types = False + # exclude type 1 + et = [1] + ifn0 = InvarFitting( + "energy", + self.nt, + ds.dim_out, + od, + distinguish_types=distinguish_types, + exclude_types=et, + ) + ret0 = ifn0(dd[0], atype) + # atom index 2 is of type 1 that is excluded + zero_idx = 2 + np.testing.assert_allclose( + ret0["energy"][:, zero_idx, :], + np.zeros_like(ret0["energy"][:, zero_idx, :]), + ) + def test_self_exception( self, ): diff --git a/source/tests/pt/model/test_ener_fitting.py b/source/tests/pt/model/test_ener_fitting.py index 42aeeff16a..9e5ec0b903 100644 --- a/source/tests/pt/model/test_ener_fitting.py +++ b/source/tests/pt/model/test_ener_fitting.py @@ -44,11 +44,12 @@ def test_consistency( ) atype = torch.tensor(self.atype_ext[:, :nloc], dtype=int, device=env.DEVICE) - for od, distinguish_types, nfp, nap in itertools.product( + for od, distinguish_types, nfp, nap, et in itertools.product( [1, 3], [True, False], [0, 3], [0, 4], + [[], [0], [1]], ): ft0 = InvarFitting( "foo", @@ -58,6 +59,7 @@ def test_consistency( numb_fparam=nfp, numb_aparam=nap, use_tebd=(not distinguish_types), + exclude_types=et, ).to(env.DEVICE) ft1 = DPInvarFitting.deserialize(ft0.serialize()) ft2 = InvarFitting.deserialize(ft0.serialize()) @@ -144,11 +146,12 @@ def test_new_old( def test_jit( self, ): - for od, distinguish_types, nfp, nap in itertools.product( + for od, distinguish_types, nfp, nap, et in itertools.product( [1, 3], [True, False], [0, 3], [0, 4], + [[], [0]], ): ft0 = InvarFitting( "foo", @@ -158,6 +161,7 @@ def test_jit( numb_fparam=nfp, numb_aparam=nap, use_tebd=(not distinguish_types), + exclude_types=et, ).to(env.DEVICE) torch.jit.script(ft0) diff --git a/source/tests/pt/model/test_exclusion_mask.py b/source/tests/pt/model/test_exclusion_mask.py index d624a8c178..18ab56be49 100644 --- a/source/tests/pt/model/test_exclusion_mask.py +++ b/source/tests/pt/model/test_exclusion_mask.py @@ -3,12 +3,13 @@ import numpy as np -from deepmd.pt.model.descriptor.se_a import ( - DescrptBlockSeA, -) from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) from deepmd.pt.utils.utils import ( to_numpy_array, to_torch_tensor, @@ -21,8 +22,31 @@ dtype = env.GLOBAL_PT_FLOAT_PRECISION +class TestAtomExcludeMask(unittest.TestCase): + def test_build_type_exclude_mask(self): + nf = 2 + nt = 3 + exclude_types = [0, 2] + atype = np.array( + [ + [0, 2, 1, 2, 0, 1, 0], + [1, 2, 0, 0, 2, 2, 1], + ], + dtype=np.int32, + ).reshape([nf, -1]) + expected_mask = np.array( + [ + [0, 0, 1, 0, 0, 1, 0], + [1, 0, 0, 0, 0, 0, 1], + ] + ).reshape([nf, -1]) + des = AtomExcludeMask(nt, exclude_types=exclude_types) + mask = des(to_torch_tensor(atype)) + np.testing.assert_equal(to_numpy_array(mask), expected_mask) + + # to be merged with the tf test case -class TestExcludeMask(unittest.TestCase, TestCaseSingleFrameWithNlist): +class TestPairExcludeMask(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self): TestCaseSingleFrameWithNlist.setUp(self) @@ -35,10 +59,8 @@ def test_build_type_exclude_mask(self): [0, 0, 1, 1, 1, 1, 1], ] ).reshape(self.nf, self.nloc, sum(self.sel)) - des = DescrptBlockSeA( - self.rcut, self.rcut_smth, self.sel, exclude_types=exclude_types - ).to(env.DEVICE) - mask = des.build_type_exclude_mask( + des = PairExcludeMask(self.nt, exclude_types=exclude_types).to(env.DEVICE) + mask = des( to_torch_tensor(self.nlist), to_torch_tensor(self.atype_ext), )