diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 0de63bce4a..2d2312c4aa 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -59,6 +59,9 @@ DescrptBlockRepformers, RepformerLayer, ) +from .se_t_tebd import ( + DescrptBlockSeTTebd, +) class RepinitArgs: @@ -75,6 +78,11 @@ def __init__( activation_function="tanh", resnet_dt: bool = False, type_one_side: bool = False, + use_three_body: bool = False, + three_body_neuron: List[int] = [2, 4, 8], + three_body_sel: int = 40, + three_body_rcut: float = 4.0, + three_body_rcut_smth: float = 0.5, ): r"""The constructor for the RepinitArgs class which defines the parameters of the repinit block in DPA2 descriptor. @@ -104,6 +112,19 @@ def __init__( Whether to use a "Timestep" in the skip connection. type_one_side : bool, optional Whether to use one-side type embedding. + use_three_body : bool, optional + Whether to concatenate three-body representation in the output descriptor. + three_body_neuron : list, optional + Number of neurons in each hidden layers of the three-body embedding net. + When two layers are of the same size or one layer is twice as large as the previous layer, + a skip connection is built. + three_body_sel : int, optional + Maximally possible number of selected neighbors in the three-body representation. + three_body_rcut : float, optional + The cut-off radius in the three-body representation. + three_body_rcut_smth : float, optional + Where to start smoothing in the three-body representation. + For example the 1/r term is smoothed from three_body_rcut to three_body_rcut_smth. """ self.rcut = rcut self.rcut_smth = rcut_smth @@ -116,6 +137,11 @@ def __init__( self.activation_function = activation_function self.resnet_dt = resnet_dt self.type_one_side = type_one_side + self.use_three_body = use_three_body + self.three_body_neuron = three_body_neuron + self.three_body_sel = three_body_sel + self.three_body_rcut = three_body_rcut + self.three_body_rcut_smth = three_body_rcut_smth def __getitem__(self, key): if hasattr(self, key): @@ -136,6 +162,11 @@ def serialize(self) -> dict: "activation_function": self.activation_function, "resnet_dt": self.resnet_dt, "type_one_side": self.type_one_side, + "use_three_body": self.use_three_body, + "three_body_neuron": self.three_body_neuron, + "three_body_sel": self.three_body_sel, + "three_body_rcut": self.three_body_rcut, + "three_body_rcut_smth": self.three_body_rcut_smth, } @classmethod @@ -172,6 +203,9 @@ def __init__( update_residual_init: str = "norm", set_davg_zero: bool = True, trainable_ln: bool = True, + use_sqrt_nnei: bool = True, + g1_out_conv: bool = True, + g1_out_mlp: bool = True, ln_eps: Optional[float] = 1e-5, ): r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor. @@ -236,6 +270,12 @@ def __init__( Set the normalization average to zero. trainable_ln : bool, optional Whether to use trainable shift and scale weights in layer normalization. + use_sqrt_nnei : bool, optional + Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly. + g1_out_conv : bool, optional + Whether to put the convolutional update of g1 separately outside the concatenated MLP update. + g1_out_mlp : bool, optional + Whether to put the self MLP update of g1 separately outside the concatenated MLP update. ln_eps : float, optional The epsilon value for layer normalization. """ @@ -265,6 +305,9 @@ def __init__( self.update_residual_init = update_residual_init self.set_davg_zero = set_davg_zero self.trainable_ln = trainable_ln + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp # to keep consistent with default value in this backends if ln_eps is None: ln_eps = 1e-5 @@ -304,6 +347,9 @@ def serialize(self) -> dict: "update_residual_init": self.update_residual_init, "set_davg_zero": self.set_davg_zero, "trainable_ln": self.trainable_ln, + "use_sqrt_nnei": self.use_sqrt_nnei, + "g1_out_conv": self.g1_out_conv, + "g1_out_mlp": self.g1_out_mlp, "ln_eps": self.ln_eps, } @@ -416,6 +462,27 @@ def init_subclass_params(sub_data, sub_class): type_one_side=self.repinit_args.type_one_side, seed=child_seed(seed, 0), ) + self.use_three_body = self.repinit_args.use_three_body + if self.use_three_body: + self.repinit_three_body = DescrptBlockSeTTebd( + self.repinit_args.three_body_rcut, + self.repinit_args.three_body_rcut_smth, + self.repinit_args.three_body_sel, + ntypes, + neuron=self.repinit_args.three_body_neuron, + tebd_dim=self.repinit_args.tebd_dim, + tebd_input_mode=self.repinit_args.tebd_input_mode, + set_davg_zero=self.repinit_args.set_davg_zero, + exclude_types=exclude_types, + env_protection=env_protection, + activation_function=self.repinit_args.activation_function, + precision=precision, + resnet_dt=self.repinit_args.resnet_dt, + smooth=smooth, + seed=child_seed(seed, 5), + ) + else: + self.repinit_three_body = None self.repformers = DescrptBlockRepformers( self.repformer_args.rcut, self.repformer_args.rcut_smth, @@ -448,9 +515,27 @@ def init_subclass_params(sub_data, sub_class): env_protection=env_protection, precision=precision, trainable_ln=self.repformer_args.trainable_ln, + use_sqrt_nnei=self.repformer_args.use_sqrt_nnei, + g1_out_conv=self.repformer_args.g1_out_conv, + g1_out_mlp=self.repformer_args.g1_out_mlp, ln_eps=self.repformer_args.ln_eps, seed=child_seed(seed, 1), ) + self.rcsl_list = [ + (self.repformers.get_rcut(), self.repformers.get_nsel()), + (self.repinit.get_rcut(), self.repinit.get_nsel()), + ] + if self.use_three_body: + self.rcsl_list.append( + (self.repinit_three_body.get_rcut(), self.repinit_three_body.get_nsel()) + ) + self.rcsl_list.sort() + for ii in range(1, len(self.rcsl_list)): + assert ( + self.rcsl_list[ii - 1][1] <= self.rcsl_list[ii][1] + ), "rcut and sel are not in the same order" + self.rcut_list = [ii[0] for ii in self.rcsl_list] + self.nsel_list = [ii[1] for ii in self.rcsl_list] self.use_econf_tebd = use_econf_tebd self.use_tebd_bias = use_tebd_bias self.type_map = type_map @@ -473,11 +558,16 @@ def init_subclass_params(sub_data, sub_class): self.trainable = trainable self.add_tebd_to_repinit_out = add_tebd_to_repinit_out - if self.repinit.dim_out == self.repformers.dim_in: + self.repinit_out_dim = self.repinit.dim_out + if self.repinit_args.use_three_body: + assert self.repinit_three_body is not None + self.repinit_out_dim += self.repinit_three_body.dim_out + + if self.repinit_out_dim == self.repformers.dim_in: self.g1_shape_tranform = Identity() else: self.g1_shape_tranform = NativeLayer( - self.repinit.dim_out, + self.repinit_out_dim, self.repformers.dim_in, bias=False, precision=precision, @@ -585,6 +675,7 @@ def change_type_map( self.ntypes = len(type_map) repinit = self.repinit repformers = self.repformers + repinit_three_body = self.repinit_three_body if has_new_type: # the avg and std of new types need to be updated extend_descrpt_stat( @@ -601,6 +692,14 @@ def change_type_map( if model_with_new_type_stat is not None else None, ) + if self.use_three_body: + extend_descrpt_stat( + repinit_three_body, + type_map, + des_with_stat=model_with_new_type_stat.repinit_three_body + if model_with_new_type_stat is not None + else None, + ) repinit.ntypes = self.ntypes repformers.ntypes = self.ntypes repinit.reinit_exclude(self.exclude_types) @@ -609,6 +708,11 @@ def change_type_map( repinit["dstd"] = repinit["dstd"][remap_index] repformers["davg"] = repformers["davg"][remap_index] repformers["dstd"] = repformers["dstd"][remap_index] + if self.use_three_body: + repinit_three_body.ntypes = self.ntypes + repinit_three_body.reinit_exclude(self.exclude_types) + repinit_three_body["davg"] = repinit_three_body["davg"][remap_index] + repinit_three_body["dstd"] = repinit_three_body["dstd"][remap_index] @property def dim_out(self): @@ -677,14 +781,15 @@ def call( The smooth switch function. shape: nf x nloc x nnei """ + use_three_body = self.use_three_body nframes, nloc, nnei = nlist.shape nall = coord_ext.reshape(nframes, -1).shape[1] // 3 # nlists nlist_dict = build_multiple_neighbor_list( coord_ext, nlist, - [self.repformers.get_rcut(), self.repinit.get_rcut()], - [self.repformers.get_nsel(), self.repinit.get_nsel()], + self.rcut_list, + self.nsel_list, ) # repinit g1_ext = self.type_embedding.call()[atype_ext] @@ -698,6 +803,21 @@ def call( g1_ext, mapping, ) + if use_three_body: + assert self.repinit_three_body is not None + g1_three_body, __, __, __, __ = self.repinit_three_body( + nlist_dict[ + get_multiple_nlist_key( + self.repinit_three_body.get_rcut(), + self.repinit_three_body.get_nsel(), + ) + ], + coord_ext, + atype_ext, + g1_ext, + mapping, + ) + g1 = np.concatenate([g1, g1_three_body], axis=-1) # linear to change shape g1 = self.g1_shape_tranform(g1) if self.add_tebd_to_repinit_out: @@ -726,10 +846,11 @@ def call( def serialize(self) -> dict: repinit = self.repinit repformers = self.repformers + repinit_three_body = self.repinit_three_body data = { "@class": "Descriptor", "type": "dpa2", - "@version": 2, + "@version": 3, "ntypes": self.ntypes, "repinit_args": self.repinit_args.serialize(), "repformer_args": self.repformer_args.serialize(), @@ -779,20 +900,53 @@ def serialize(self) -> dict: "repformers_variable": repformers_variable, } ) + if self.use_three_body: + repinit_three_body_variable = { + "embeddings": repinit_three_body.embeddings.serialize(), + "env_mat": EnvMat( + repinit_three_body.rcut, repinit_three_body.rcut_smth + ).serialize(), + "@variables": { + "davg": repinit_three_body["davg"], + "dstd": repinit_three_body["dstd"], + }, + } + if repinit_three_body.tebd_input_mode in ["strip"]: + repinit_three_body_variable.update( + { + "embeddings_strip": repinit_three_body.embeddings_strip.serialize() + } + ) + data.update( + { + "repinit_three_body_variable": repinit_three_body_variable, + } + ) return data @classmethod def deserialize(cls, data: dict) -> "DescrptDPA2": data = data.copy() - check_version_compatibility(data.pop("@version"), 2, 1) + version = data.pop("@version") + check_version_compatibility(version, 3, 1) data.pop("@class") data.pop("type") repinit_variable = data.pop("repinit_variable").copy() repformers_variable = data.pop("repformers_variable").copy() + repinit_three_body_variable = ( + data.pop("repinit_three_body_variable").copy() + if "repinit_three_body_variable" in data + else None + ) type_embedding = data.pop("type_embedding") g1_shape_tranform = data.pop("g1_shape_tranform") tebd_transform = data.pop("tebd_transform", None) add_tebd_to_repinit_out = data["add_tebd_to_repinit_out"] + if version < 3: + # compat with old version + data["repformer_args"]["use_sqrt_nnei"] = False + data["repformer_args"]["g1_out_conv"] = False + data["repformer_args"]["g1_out_mlp"] = False data["repinit"] = RepinitArgs(**data.pop("repinit_args")) data["repformer"] = RepformerArgs(**data.pop("repformer_args")) # compat with version 1 @@ -820,6 +974,21 @@ def deserialize(cls, data: dict) -> "DescrptDPA2": obj.repinit["davg"] = statistic_repinit["davg"] obj.repinit["dstd"] = statistic_repinit["dstd"] + if data["repinit"].use_three_body: + # deserialize repinit_three_body + statistic_repinit_three_body = repinit_three_body_variable.pop("@variables") + env_mat = repinit_three_body_variable.pop("env_mat") + tebd_input_mode = data["repinit"].tebd_input_mode + obj.repinit_three_body.embeddings = NetworkCollection.deserialize( + repinit_three_body_variable.pop("embeddings") + ) + if tebd_input_mode in ["strip"]: + obj.repinit_three_body.embeddings_strip = NetworkCollection.deserialize( + repinit_three_body_variable.pop("embeddings_strip") + ) + obj.repinit_three_body["davg"] = statistic_repinit_three_body["davg"] + obj.repinit_three_body["dstd"] = statistic_repinit_three_body["dstd"] + # deserialize repformers statistic_repformers = repformers_variable.pop("@variables") env_mat = repformers_variable.pop("env_mat") diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index bb84816d3d..7254f0bc3d 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -118,6 +118,12 @@ class DescrptBlockRepformers(NativeOP, DescriptorBlock): For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. trainable_ln : bool, optional Whether to use trainable shift and scale weights in layer normalization. + use_sqrt_nnei : bool, optional + Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly. + g1_out_conv : bool, optional + Whether to put the convolutional update of g1 separately outside the concatenated MLP update. + g1_out_mlp : bool, optional + Whether to put the self MLP update of g1 separately outside the concatenated MLP update. ln_eps : float, optional The epsilon value for layer normalization. seed : int, optional @@ -157,6 +163,9 @@ def __init__( env_protection: float = 0.0, precision: str = "float64", trainable_ln: bool = True, + use_sqrt_nnei: bool = True, + g1_out_conv: bool = True, + g1_out_mlp: bool = True, ln_eps: Optional[float] = 1e-5, seed: Optional[Union[int, List[int]]] = None, ): @@ -200,6 +209,9 @@ def __init__( self.env_protection = env_protection self.precision = precision self.trainable_ln = trainable_ln + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp self.ln_eps = ln_eps self.epsilon = 1e-4 @@ -238,6 +250,9 @@ def __init__( trainable_ln=self.trainable_ln, ln_eps=self.ln_eps, precision=precision, + use_sqrt_nnei=self.use_sqrt_nnei, + g1_out_conv=self.g1_out_conv, + g1_out_mlp=self.g1_out_mlp, seed=child_seed(child_seed(seed, 1), ii), ) ) @@ -392,7 +407,15 @@ def call( ) # nf x nloc x 3 x ng2 - h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon) + h2g2 = _cal_hg( + g2, + h2, + nlist_mask, + sw, + smooth=self.smooth, + epsilon=self.epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, + ) # (nf x nloc) x ng2 x 3 rot_mat = np.transpose(h2g2, (0, 1, 3, 2)) return g1, g2, h2, rot_mat.reshape(nf, nloc, self.dim_emb, 3), sw @@ -521,6 +544,7 @@ def _cal_hg( sw: np.ndarray, smooth: bool = True, epsilon: float = 1e-4, + use_sqrt_nnei: bool = True, ) -> np.ndarray: """ Calculate the transposed rotation matrix. @@ -540,6 +564,8 @@ def _cal_hg( Whether to use smoothness in processes such as attention weights calculation. epsilon Protection of 1./nnei. + use_sqrt_nnei : bool, optional + Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly. Returns ------- @@ -555,12 +581,20 @@ def _cal_hg( g = _apply_nlist_mask(g, nlist_mask) if not smooth: # nf x nloc - invnnei = 1.0 / (epsilon + np.sum(nlist_mask, axis=-1)) + if not use_sqrt_nnei: + invnnei = 1.0 / (epsilon + np.sum(nlist_mask, axis=-1)) + else: + invnnei = 1.0 / (epsilon + np.sqrt(np.sum(nlist_mask, axis=-1))) # nf x nloc x 1 x 1 invnnei = invnnei[:, :, np.newaxis, np.newaxis] else: g = _apply_switch(g, sw) - invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1, 1), dtype=g.dtype) + if not use_sqrt_nnei: + invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1, 1), dtype=g.dtype) + else: + invnnei = (1.0 / (float(nnei) ** 0.5)) * np.ones( + (nf, nloc, 1, 1), dtype=g.dtype + ) # nf x nloc x 3 x ng hg = np.matmul(np.transpose(h, axes=(0, 1, 3, 2)), g) * invnnei return hg @@ -601,6 +635,7 @@ def symmetrization_op( axis_neuron: int, smooth: bool = True, epsilon: float = 1e-4, + use_sqrt_nnei: bool = True, ) -> np.ndarray: """ Symmetrization operator to obtain atomic invariant rep. @@ -622,6 +657,8 @@ def symmetrization_op( Whether to use smoothness in processes such as attention weights calculation. epsilon Protection of 1./nnei. + use_sqrt_nnei : bool, optional + Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly. Returns ------- @@ -633,7 +670,15 @@ def symmetrization_op( # msk: nf x nloc x nnei nf, nloc, nnei, _ = g.shape # nf x nloc x 3 x ng - hg = _cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon) + hg = _cal_hg( + g, + h, + nlist_mask, + sw, + smooth=smooth, + epsilon=epsilon, + use_sqrt_nnei=use_sqrt_nnei, + ) # nf x nloc x (axis_neuron x ng) grrg = _cal_grrg(hg, axis_neuron) return grrg @@ -1083,6 +1128,9 @@ def __init__( smooth: bool = True, precision: str = "float64", trainable_ln: bool = True, + use_sqrt_nnei: bool = True, + g1_out_conv: bool = True, + g1_out_mlp: bool = True, ln_eps: Optional[float] = 1e-5, seed: Optional[Union[int, List[int]]] = None, ): @@ -1120,6 +1168,9 @@ def __init__( self.g1_dim = g1_dim self.g2_dim = g2_dim self.trainable_ln = trainable_ln + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp self.ln_eps = ln_eps self.precision = precision @@ -1177,14 +1228,52 @@ def __init__( seed=child_seed(seed, 3), ) ) - if self.update_g1_has_conv: - self.proj_g1g2 = NativeLayer( + if self.g1_out_mlp: + self.g1_self_mlp = NativeLayer( + g1_dim, g1_dim, - g2_dim, - bias=False, precision=precision, - seed=child_seed(seed, 4), + seed=child_seed(seed, 15), ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 16), + ) + ) + else: + self.g1_self_mlp = None + if self.update_g1_has_conv: + if not self.g1_out_conv: + self.proj_g1g2 = NativeLayer( + g1_dim, + g2_dim, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + else: + self.proj_g1g2 = NativeLayer( + g2_dim, + g1_dim, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 17), + ) + ) if self.update_g2_has_g1g1: self.proj_g1g1g2 = NativeLayer( g1_dim, @@ -1270,12 +1359,12 @@ def __init__( ) def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: - ret = g1d + ret = g1d if not self.g1_out_mlp else 0 if self.update_g1_has_grrg: ret += g2d * ax if self.update_g1_has_drrd: ret += g1d * ax - if self.update_g1_has_conv: + if self.update_g1_has_conv and not self.g1_out_conv: ret += g2d return ret @@ -1325,9 +1414,13 @@ def _update_g1_conv( nf, nloc, nnei, _ = g2.shape ng1 = gg1.shape[-1] ng2 = g2.shape[-1] - # gg1 : nf x nloc x nnei x ng2 - gg1 = self.proj_g1g2(gg1).reshape(nf, nloc, nnei, ng2) - # nf x nloc x nnei x ng2 + if not self.g1_out_conv: + # gg1 : nf x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).reshape(nf, nloc, nnei, ng2) + else: + # gg1 : nf x nloc x nnei x ng1 + gg1 = gg1.reshape(nf, nloc, nnei, ng1) + # nf x nloc x nnei x ng2/ng1 gg1 = _apply_nlist_mask(gg1, nlist_mask) if not self.smooth: # normalized by number of neighbors, not smooth @@ -1338,8 +1431,14 @@ def _update_g1_conv( else: gg1 = _apply_switch(gg1, sw) invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1), dtype=gg1.dtype) - # nf x nloc x ng2 - g1_11 = np.sum(g2 * gg1, axis=2) * invnnei + if not self.g1_out_conv: + # nf x nloc x ng2 + g1_11 = np.sum(g2 * gg1, axis=2) * invnnei + else: + # nf x nloc x ng1 + g2 = self.proj_g1g2(g2).reshape(nf, nloc, nnei, ng1) + # nb x nloc x ng1 + g1_11 = np.sum(g2 * gg1, axis=2) * invnnei return g1_11 def _update_g2_g1g1( @@ -1412,7 +1511,11 @@ def call( g2_update: List[np.ndarray] = [g2] h2_update: List[np.ndarray] = [h2] g1_update: List[np.ndarray] = [g1] - g1_mlp: List[np.ndarray] = [g1] + g1_mlp: List[np.ndarray] = [g1] if not self.g1_out_mlp else [] + if self.g1_out_mlp: + assert self.g1_self_mlp is not None + g1_self_mlp = self.act(self.g1_self_mlp(g1)) + g1_update.append(g1_self_mlp) if cal_gg1: gg1 = _make_nei_g1(g1_ext, nlist) @@ -1454,7 +1557,11 @@ def call( if self.update_g1_has_conv: assert gg1 is not None - g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw)) + g1_conv = self._update_g1_conv(gg1, g2, nlist_mask, sw) + if not self.g1_out_conv: + g1_mlp.append(g1_conv) + else: + g1_update.append(g1_conv) if self.update_g1_has_grrg: g1_mlp.append( @@ -1466,6 +1573,7 @@ def call( self.axis_neuron, smooth=self.smooth, epsilon=self.epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, ) ) @@ -1480,6 +1588,7 @@ def call( self.axis_neuron, smooth=self.smooth, epsilon=self.epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, ) ) @@ -1586,6 +1695,9 @@ def serialize(self) -> dict: "smooth": self.smooth, "precision": self.precision, "trainable_ln": self.trainable_ln, + "use_sqrt_nnei": self.use_sqrt_nnei, + "g1_out_conv": self.g1_out_conv, + "g1_out_mlp": self.g1_out_mlp, "ln_eps": self.ln_eps, "linear1": self.linear1.serialize(), } @@ -1633,6 +1745,12 @@ def serialize(self) -> dict: "loc_attn": self.loc_attn.serialize(), } ) + if self.g1_out_mlp: + data.update( + { + "g1_self_mlp": self.g1_self_mlp.serialize(), + } + ) if self.update_style == "res_residual": data.update( { @@ -1663,6 +1781,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer": update_h2 = data["update_h2"] update_g1_has_attn = data["update_g1_has_attn"] update_style = data["update_style"] + g1_out_mlp = data["g1_out_mlp"] linear2 = data.pop("linear2", None) proj_g1g2 = data.pop("proj_g1g2", None) @@ -1672,6 +1791,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer": attn2_lm = data.pop("attn2_lm", None) attn2_ev_apply = data.pop("attn2_ev_apply", None) loc_attn = data.pop("loc_attn", None) + g1_self_mlp = data.pop("g1_self_mlp", None) g1_residual = data.pop("g1_residual", []) g2_residual = data.pop("g2_residual", []) h2_residual = data.pop("h2_residual", []) @@ -1701,6 +1821,9 @@ def deserialize(cls, data: dict) -> "RepformerLayer": if update_g1_has_attn: assert isinstance(loc_attn, dict) obj.loc_attn = LocalAtten.deserialize(loc_attn) + if g1_out_mlp: + assert isinstance(g1_self_mlp, dict) + obj.g1_self_mlp = NativeLayer.deserialize(g1_self_mlp) if update_style == "res_residual": obj.g1_residual = g1_residual obj.g2_residual = g2_residual diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 7e5262e275..cabbdae175 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -69,6 +69,9 @@ from .se_atten import ( DescrptBlockSeAtten, ) +from .se_t_tebd import ( + DescrptBlockSeTTebd, +) @BaseDescriptor.register("dpa2") @@ -177,6 +180,27 @@ def init_subclass_params(sub_data, sub_class): type_one_side=self.repinit_args.type_one_side, seed=child_seed(seed, 0), ) + self.use_three_body = self.repinit_args.use_three_body + if self.use_three_body: + self.repinit_three_body = DescrptBlockSeTTebd( + self.repinit_args.three_body_rcut, + self.repinit_args.three_body_rcut_smth, + self.repinit_args.three_body_sel, + ntypes, + neuron=self.repinit_args.three_body_neuron, + tebd_dim=self.repinit_args.tebd_dim, + tebd_input_mode=self.repinit_args.tebd_input_mode, + set_davg_zero=self.repinit_args.set_davg_zero, + exclude_types=exclude_types, + env_protection=env_protection, + activation_function=self.repinit_args.activation_function, + precision=precision, + resnet_dt=self.repinit_args.resnet_dt, + smooth=smooth, + seed=child_seed(seed, 5), + ) + else: + self.repinit_three_body = None self.repformers = DescrptBlockRepformers( self.repformer_args.rcut, self.repformer_args.rcut_smth, @@ -210,9 +234,27 @@ def init_subclass_params(sub_data, sub_class): precision=precision, trainable_ln=self.repformer_args.trainable_ln, ln_eps=self.repformer_args.ln_eps, + use_sqrt_nnei=self.repformer_args.use_sqrt_nnei, + g1_out_conv=self.repformer_args.g1_out_conv, + g1_out_mlp=self.repformer_args.g1_out_mlp, seed=child_seed(seed, 1), old_impl=old_impl, ) + self.rcsl_list = [ + (self.repformers.get_rcut(), self.repformers.get_nsel()), + (self.repinit.get_rcut(), self.repinit.get_nsel()), + ] + if self.use_three_body: + self.rcsl_list.append( + (self.repinit_three_body.get_rcut(), self.repinit_three_body.get_nsel()) + ) + self.rcsl_list.sort() + for ii in range(1, len(self.rcsl_list)): + assert ( + self.rcsl_list[ii - 1][1] <= self.rcsl_list[ii][1] + ), "rcut and sel are not in the same order" + self.rcut_list = [ii[0] for ii in self.rcsl_list] + self.nsel_list = [ii[1] for ii in self.rcsl_list] self.use_econf_tebd = use_econf_tebd self.use_tebd_bias = use_tebd_bias self.type_map = type_map @@ -233,11 +275,16 @@ def init_subclass_params(sub_data, sub_class): self.trainable = trainable self.add_tebd_to_repinit_out = add_tebd_to_repinit_out - if self.repinit.dim_out == self.repformers.dim_in: + self.repinit_out_dim = self.repinit.dim_out + if self.repinit_args.use_three_body: + assert self.repinit_three_body is not None + self.repinit_out_dim += self.repinit_three_body.dim_out + + if self.repinit_out_dim == self.repformers.dim_in: self.g1_shape_tranform = Identity() else: self.g1_shape_tranform = MLPLayer( - self.repinit.dim_out, + self.repinit_out_dim, self.repformers.dim_in, bias=False, precision=precision, @@ -383,6 +430,7 @@ def change_type_map( self.ntypes = len(type_map) repinit = self.repinit repformers = self.repformers + repinit_three_body = self.repinit_three_body if has_new_type: # the avg and std of new types need to be updated extend_descrpt_stat( @@ -399,6 +447,14 @@ def change_type_map( if model_with_new_type_stat is not None else None, ) + if self.use_three_body: + extend_descrpt_stat( + repinit_three_body, + type_map, + des_with_stat=model_with_new_type_stat.repinit_three_body + if model_with_new_type_stat is not None + else None, + ) repinit.ntypes = self.ntypes repformers.ntypes = self.ntypes repinit.reinit_exclude(self.exclude_types) @@ -407,6 +463,11 @@ def change_type_map( repinit["dstd"] = repinit["dstd"][remap_index] repformers["davg"] = repformers["davg"][remap_index] repformers["dstd"] = repformers["dstd"][remap_index] + if self.use_three_body: + repinit_three_body.ntypes = self.ntypes + repinit_three_body.reinit_exclude(self.exclude_types) + repinit_three_body["davg"] = repinit_three_body["davg"][remap_index] + repinit_three_body["dstd"] = repinit_three_body["dstd"][remap_index] @property def dim_out(self): @@ -461,10 +522,11 @@ def get_stat_mean_and_stddev(self) -> Tuple[List[torch.Tensor], List[torch.Tenso def serialize(self) -> dict: repinit = self.repinit repformers = self.repformers + repinit_three_body = self.repinit_three_body data = { "@class": "Descriptor", "type": "dpa2", - "@version": 2, + "@version": 3, "ntypes": self.ntypes, "repinit_args": self.repinit_args.serialize(), "repformer_args": self.repformer_args.serialize(), @@ -514,20 +576,53 @@ def serialize(self) -> dict: "repformers_variable": repformers_variable, } ) + if self.use_three_body: + repinit_three_body_variable = { + "embeddings": repinit_three_body.filter_layers.serialize(), + "env_mat": DPEnvMat( + repinit_three_body.rcut, repinit_three_body.rcut_smth + ).serialize(), + "@variables": { + "davg": to_numpy_array(repinit_three_body["davg"]), + "dstd": to_numpy_array(repinit_three_body["dstd"]), + }, + } + if repinit_three_body.tebd_input_mode in ["strip"]: + repinit_three_body_variable.update( + { + "embeddings_strip": repinit_three_body.filter_layers_strip.serialize() + } + ) + data.update( + { + "repinit_three_body_variable": repinit_three_body_variable, + } + ) return data @classmethod def deserialize(cls, data: dict) -> "DescrptDPA2": data = data.copy() - check_version_compatibility(data.pop("@version"), 2, 1) + version = data.pop("@version") + check_version_compatibility(version, 3, 1) data.pop("@class") data.pop("type") repinit_variable = data.pop("repinit_variable").copy() repformers_variable = data.pop("repformers_variable").copy() + repinit_three_body_variable = ( + data.pop("repinit_three_body_variable").copy() + if "repinit_three_body_variable" in data + else None + ) type_embedding = data.pop("type_embedding") g1_shape_tranform = data.pop("g1_shape_tranform") tebd_transform = data.pop("tebd_transform", None) add_tebd_to_repinit_out = data["add_tebd_to_repinit_out"] + if version < 3: + # compat with old version + data["repformer_args"]["use_sqrt_nnei"] = False + data["repformer_args"]["g1_out_conv"] = False + data["repformer_args"]["g1_out_mlp"] = False data["repinit"] = RepinitArgs(**data.pop("repinit_args")) data["repformer"] = RepformerArgs(**data.pop("repformer_args")) # compat with version 1 @@ -560,6 +655,23 @@ def t_cvt(xx): obj.repinit["davg"] = t_cvt(statistic_repinit["davg"]) obj.repinit["dstd"] = t_cvt(statistic_repinit["dstd"]) + if data["repinit"].use_three_body: + # deserialize repinit_three_body + statistic_repinit_three_body = repinit_three_body_variable.pop("@variables") + env_mat = repinit_three_body_variable.pop("env_mat") + tebd_input_mode = data["repinit"].tebd_input_mode + obj.repinit_three_body.filter_layers = NetworkCollection.deserialize( + repinit_three_body_variable.pop("embeddings") + ) + if tebd_input_mode in ["strip"]: + obj.repinit_three_body.filter_layers_strip = ( + NetworkCollection.deserialize( + repinit_three_body_variable.pop("embeddings_strip") + ) + ) + obj.repinit_three_body["davg"] = t_cvt(statistic_repinit_three_body["davg"]) + obj.repinit_three_body["dstd"] = t_cvt(statistic_repinit_three_body["dstd"]) + # deserialize repformers statistic_repformers = repformers_variable.pop("@variables") env_mat = repformers_variable.pop("env_mat") @@ -614,14 +726,15 @@ def forward( The smooth switch function. shape: nf x nloc x nnei """ + use_three_body = self.use_three_body nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 # nlists nlist_dict = build_multiple_neighbor_list( extended_coord, nlist, - [self.repformers.get_rcut(), self.repinit.get_rcut()], - [self.repformers.get_nsel(), self.repinit.get_nsel()], + self.rcut_list, + self.nsel_list, ) # repinit g1_ext = self.type_embedding(extended_atype) @@ -635,6 +748,21 @@ def forward( g1_ext, mapping, ) + if use_three_body: + assert self.repinit_three_body is not None + g1_three_body, __, __, __, __ = self.repinit_three_body( + nlist_dict[ + get_multiple_nlist_key( + self.repinit_three_body.get_rcut(), + self.repinit_three_body.get_nsel(), + ) + ], + extended_coord, + extended_atype, + g1_ext, + mapping, + ) + g1 = torch.cat([g1, g1_three_body], dim=-1) # linear to change shape g1 = self.g1_shape_tranform(g1) if self.add_tebd_to_repinit_out: diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 85a9800c73..579dc0c81e 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -599,6 +599,9 @@ def __init__( precision: str = "float64", trainable_ln: bool = True, ln_eps: Optional[float] = 1e-5, + use_sqrt_nnei: bool = True, + g1_out_conv: bool = True, + g1_out_mlp: bool = True, seed: Optional[Union[int, List[int]]] = None, ): super().__init__() @@ -638,6 +641,9 @@ def __init__( self.ln_eps = ln_eps self.precision = precision self.seed = seed + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp assert update_residual_init in [ "norm", @@ -693,14 +699,52 @@ def __init__( seed=child_seed(seed, 3), ) ) - if self.update_g1_has_conv: - self.proj_g1g2 = MLPLayer( + if self.g1_out_mlp: + self.g1_self_mlp = MLPLayer( + g1_dim, g1_dim, - g2_dim, - bias=False, precision=precision, - seed=child_seed(seed, 4), + seed=child_seed(seed, 15), ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 16), + ) + ) + else: + self.g1_self_mlp = None + if self.update_g1_has_conv: + if not self.g1_out_conv: + self.proj_g1g2 = MLPLayer( + g1_dim, + g2_dim, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + else: + self.proj_g1g2 = MLPLayer( + g2_dim, + g1_dim, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 17), + ) + ) if self.update_g2_has_g1g1: self.proj_g1g1g2 = MLPLayer( g1_dim, @@ -790,12 +834,12 @@ def __init__( self.h2_residual = nn.ParameterList(self.h2_residual) def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: - ret = g1d + ret = g1d if not self.g1_out_mlp else 0 if self.update_g1_has_grrg: ret += g2d * ax if self.update_g1_has_drrd: ret += g1d * ax - if self.update_g1_has_conv: + if self.update_g1_has_conv and not self.g1_out_conv: ret += g2d return ret @@ -845,9 +889,12 @@ def _update_g1_conv( nb, nloc, nnei, _ = g2.shape ng1 = gg1.shape[-1] ng2 = g2.shape[-1] - # gg1 : nb x nloc x nnei x ng2 - gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) - # nb x nloc x nnei x ng2 + if not self.g1_out_conv: + # gg1 : nb x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) + else: + gg1 = gg1.view(nb, nloc, nnei, ng1) + # nb x nloc x nnei x ng2/ng1 gg1 = _apply_nlist_mask(gg1, nlist_mask) if not self.smooth: # normalized by number of neighbors, not smooth @@ -861,8 +908,13 @@ def _update_g1_conv( invnnei = (1.0 / float(nnei)) * torch.ones( (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device ) - # nb x nloc x ng2 - g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei + if not self.g1_out_conv: + # nb x nloc x ng2 + g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei + else: + g2 = self.proj_g1g2(g2).view(nb, nloc, nnei, ng1) + # nb x nloc x ng1 + g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei return g1_11 @staticmethod @@ -873,6 +925,7 @@ def _cal_hg( sw: torch.Tensor, smooth: bool = True, epsilon: float = 1e-4, + use_sqrt_nnei: bool = True, ) -> torch.Tensor: """ Calculate the transposed rotation matrix. @@ -908,14 +961,25 @@ def _cal_hg( if not smooth: # nb x nloc # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy - invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1)) + if not use_sqrt_nnei: + invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1)) + else: + invnnei = 1.0 / ( + epsilon + torch.sqrt(torch.sum(nlist_mask.type_as(g2), dim=-1)) + ) # nb x nloc x 1 x 1 invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) else: g2 = _apply_switch(g2, sw) - invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device - ) + if not use_sqrt_nnei: + invnnei = (1.0 / float(nnei)) * torch.ones( + (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device + ) + else: + invnnei = torch.rsqrt( + float(nnei) + * torch.ones((nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device) + ) # nb x nloc x 3 x ng2 h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei return h2g2 @@ -988,7 +1052,15 @@ def symmetrization_op( # msk: nb x nloc x nnei nb, nloc, nnei, _ = g2.shape # nb x nloc x 3 x ng2 - h2g2 = self._cal_hg(g2, h2, nlist_mask, sw, smooth=smooth, epsilon=epsilon) + h2g2 = self._cal_hg( + g2, + h2, + nlist_mask, + sw, + smooth=smooth, + epsilon=epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, + ) # nb x nloc x (axisxng2) g1_13 = self._cal_grrg(h2g2, axis_neuron) return g1_13 @@ -1063,7 +1135,11 @@ def forward( g2_update: List[torch.Tensor] = [g2] h2_update: List[torch.Tensor] = [h2] g1_update: List[torch.Tensor] = [g1] - g1_mlp: List[torch.Tensor] = [g1] + g1_mlp: List[torch.Tensor] = [g1] if not self.g1_out_mlp else [] + if self.g1_out_mlp: + assert self.g1_self_mlp is not None + g1_self_mlp = self.act(self.g1_self_mlp(g1)) + g1_update.append(g1_self_mlp) if cal_gg1: gg1 = _make_nei_g1(g1_ext, nlist) @@ -1105,7 +1181,11 @@ def forward( if self.update_g1_has_conv: assert gg1 is not None - g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw)) + g1_conv = self._update_g1_conv(gg1, g2, nlist_mask, sw) + if not self.g1_out_conv: + g1_mlp.append(g1_conv) + else: + g1_update.append(g1_conv) if self.update_g1_has_grrg: g1_mlp.append( @@ -1242,6 +1322,9 @@ def serialize(self) -> dict: "smooth": self.smooth, "precision": self.precision, "trainable_ln": self.trainable_ln, + "use_sqrt_nnei": self.use_sqrt_nnei, + "g1_out_conv": self.g1_out_conv, + "g1_out_mlp": self.g1_out_mlp, "ln_eps": self.ln_eps, "linear1": self.linear1.serialize(), } @@ -1289,6 +1372,12 @@ def serialize(self) -> dict: "loc_attn": self.loc_attn.serialize(), } ) + if self.g1_out_mlp: + data.update( + { + "g1_self_mlp": self.g1_self_mlp.serialize(), + } + ) if self.update_style == "res_residual": data.update( { @@ -1319,6 +1408,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer": update_h2 = data["update_h2"] update_g1_has_attn = data["update_g1_has_attn"] update_style = data["update_style"] + g1_out_mlp = data["g1_out_mlp"] linear2 = data.pop("linear2", None) proj_g1g2 = data.pop("proj_g1g2", None) @@ -1328,6 +1418,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer": attn2_lm = data.pop("attn2_lm", None) attn2_ev_apply = data.pop("attn2_ev_apply", None) loc_attn = data.pop("loc_attn", None) + g1_self_mlp = data.pop("g1_self_mlp", None) g1_residual = data.pop("g1_residual", []) g2_residual = data.pop("g2_residual", []) h2_residual = data.pop("h2_residual", []) @@ -1357,6 +1448,9 @@ def deserialize(cls, data: dict) -> "RepformerLayer": if update_g1_has_attn: assert isinstance(loc_attn, dict) obj.loc_attn = LocalAtten.deserialize(loc_attn) + if g1_out_mlp: + assert isinstance(g1_self_mlp, dict) + obj.g1_self_mlp = MLPLayer.deserialize(g1_self_mlp) if update_style == "res_residual": for ii, t in enumerate(obj.g1_residual): t.data = to_torch_tensor(g1_residual[ii]) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index bc8c331ec3..a9e4ef7893 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -105,6 +105,9 @@ def __init__( trainable_ln: bool = True, ln_eps: Optional[float] = 1e-5, seed: Optional[Union[int, List[int]]] = None, + use_sqrt_nnei: bool = True, + g1_out_conv: bool = True, + g1_out_mlp: bool = True, old_impl: bool = False, ): r""" @@ -182,6 +185,12 @@ def __init__( For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. trainable_ln : bool, optional Whether to use trainable shift and scale weights in layer normalization. + use_sqrt_nnei : bool, optional + Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly. + g1_out_conv : bool, optional + Whether to put the convolutional update of g1 separately outside the concatenated MLP update. + g1_out_mlp : bool, optional + Whether to put the self MLP update of g1 separately outside the concatenated MLP update. ln_eps : float, optional The epsilon value for layer normalization. seed : int, optional @@ -222,6 +231,9 @@ def __init__( self.direct_dist = direct_dist self.act = ActivationFn(activation_function) self.smooth = smooth + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) self.env_protection = env_protection @@ -296,6 +308,9 @@ def __init__( trainable_ln=self.trainable_ln, ln_eps=self.ln_eps, precision=precision, + use_sqrt_nnei=self.use_sqrt_nnei, + g1_out_conv=self.g1_out_conv, + g1_out_mlp=self.g1_out_mlp, seed=child_seed(child_seed(seed, 1), ii), ) ) @@ -500,7 +515,13 @@ def forward( # nb x nloc x 3 x ng2 h2g2 = RepformerLayer._cal_hg( - g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon + g2, + h2, + nlist_mask, + sw, + smooth=self.smooth, + epsilon=self.epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, ) # (nb x nloc) x ng2 x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) diff --git a/deepmd/pt/model/descriptor/se_t_tebd.py b/deepmd/pt/model/descriptor/se_t_tebd.py index 18569d2f18..774a9154de 100644 --- a/deepmd/pt/model/descriptor/se_t_tebd.py +++ b/deepmd/pt/model/descriptor/se_t_tebd.py @@ -557,7 +557,6 @@ def __init__( else: self.embd_input_dim = 1 - self.filter_layers_old = None self.filter_layers = None self.filter_layers_strip = None filter_layers = NetworkCollection( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index ddcfc4a863..c2f483e715 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -965,6 +965,17 @@ def dpa2_repinit_args(): doc_activation_function = f"The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." doc_type_one_side = r"If true, the embedding network parameters vary by types of neighbor atoms only, so there will be $N_\text{types}$ sets of embedding network parameters. Otherwise, the embedding network parameters vary by types of centric atoms and types of neighbor atoms, so there will be $N_\text{types}^2$ sets of embedding network parameters." doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection.' + doc_use_three_body = ( + "Whether to concatenate three-body representation in the output descriptor." + ) + doc_three_body_neuron = ( + "Number of neurons in each hidden layers of the three-body embedding net." + "When two layers are of the same size or one layer is twice as large as the previous layer, " + "a skip connection is built." + ) + doc_three_body_sel = "Maximally possible number of selected neighbors in the three-body representation." + doc_three_body_rcut = "The cut-off radius in the three-body representation." + doc_three_body_rcut_smth = "Where to start smoothing in the three-body representation. For example the 1/r term is smoothed from `three_body_rcut` to `three_body_rcut_smth`." return [ # repinit args @@ -1027,6 +1038,37 @@ def dpa2_repinit_args(): default=False, doc=doc_resnet_dt, ), + Argument( + "use_three_body", + bool, + optional=True, + default=False, + doc=doc_use_three_body, + ), + Argument( + "three_body_neuron", + list, + optional=True, + default=[2, 4, 8], + doc=doc_three_body_neuron, + ), + Argument( + "three_body_rcut", + float, + optional=True, + default=4.0, + doc=doc_three_body_rcut, + ), + Argument( + "three_body_rcut_smth", + float, + optional=True, + default=0.5, + doc=doc_three_body_rcut_smth, + ), + Argument( + "three_body_sel", int, optional=True, default=40, doc=doc_three_body_sel + ), ] @@ -1047,6 +1089,9 @@ def dpa2_repformer_args(): doc_update_g1_has_attn = "Update the g1 rep with the localized self-attention." doc_update_g2_has_g1g1 = "Update the g2 rep with the g1xg1 term." doc_update_g2_has_attn = "Update the g2 rep with the gated self-attention." + doc_use_sqrt_nnei = "Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly." + doc_g1_out_conv = "Whether to put the convolutional update of g1 separately outside the concatenated MLP update." + doc_g1_out_mlp = "Whether to put the self MLP update of g1 separately outside the concatenated MLP update." doc_update_h2 = "Update the h2 rep." doc_attn1_hidden = ( "The hidden dimension of localized self-attention to update the g1 rep." @@ -1167,6 +1212,27 @@ def dpa2_repformer_args(): default=True, doc=doc_update_g2_has_attn, ), + Argument( + "use_sqrt_nnei", + bool, + optional=True, + default=True, + doc=doc_use_sqrt_nnei, + ), + Argument( + "g1_out_conv", + bool, + optional=True, + default=True, + doc=doc_g1_out_conv, + ), + Argument( + "g1_out_mlp", + bool, + optional=True, + default=True, + doc=doc_g1_out_mlp, + ), Argument( "update_h2", bool, diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 9b88b4238a..d12e1094f0 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -37,7 +37,7 @@ RepinitArgs, ) from deepmd.utils.argcheck import ( - descrpt_se_atten_args, + descrpt_dpa2_args, ) @@ -45,6 +45,7 @@ ("concat", "strip"), # repinit_tebd_input_mode (True,), # repinit_set_davg_zero (False,), # repinit_type_one_side + (True, False), # repinit_use_three_body (True, False), # repformer_direct_dist (True,), # repformer_update_g1_has_conv (True,), # repformer_update_g1_has_drrd @@ -59,6 +60,9 @@ (True,), # repformer_set_davg_zero (True,), # repformer_trainable_ln (1e-5,), # repformer_ln_eps + (True,), # repformer_use_sqrt_nnei + (True,), # repformer_g1_out_conv + (True,), # repformer_g1_out_mlp (True, False), # smooth ([], [[0, 1]]), # exclude_types ("float64",), # precision @@ -73,6 +77,7 @@ def data(self) -> dict: repinit_tebd_input_mode, repinit_set_davg_zero, repinit_type_one_side, + repinit_use_three_body, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -87,6 +92,9 @@ def data(self) -> dict: repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, @@ -109,6 +117,10 @@ def data(self) -> dict: "set_davg_zero": repinit_set_davg_zero, "activation_function": "tanh", "type_one_side": repinit_type_one_side, + "use_three_body": repinit_use_three_body, + "three_body_sel": 8, + "three_body_rcut": 4.0, + "three_body_rcut_smth": 3.5, } ), # kwargs for repformer @@ -141,6 +153,9 @@ def data(self) -> dict: "set_davg_zero": True, "trainable_ln": repformer_trainable_ln, "ln_eps": repformer_ln_eps, + "use_sqrt_nnei": repformer_use_sqrt_nnei, + "g1_out_conv": repformer_g1_out_conv, + "g1_out_mlp": repformer_g1_out_mlp, } ), # kwargs for descriptor @@ -162,6 +177,7 @@ def skip_pt(self) -> bool: repinit_tebd_input_mode, repinit_set_davg_zero, repinit_type_one_side, + repinit_use_three_body, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -176,6 +192,9 @@ def skip_pt(self) -> bool: repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, @@ -191,6 +210,7 @@ def skip_dp(self) -> bool: repinit_tebd_input_mode, repinit_set_davg_zero, repinit_type_one_side, + repinit_use_three_body, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -205,6 +225,9 @@ def skip_dp(self) -> bool: repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, @@ -220,6 +243,7 @@ def skip_tf(self) -> bool: repinit_tebd_input_mode, repinit_set_davg_zero, repinit_type_one_side, + repinit_use_three_body, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -234,6 +258,9 @@ def skip_tf(self) -> bool: repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, @@ -246,7 +273,7 @@ def skip_tf(self) -> bool: tf_class = DescrptDPA2TF dp_class = DescrptDPA2DP pt_class = DescrptDPA2PT - args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False)) + args = descrpt_dpa2_args().append(Argument("ntypes", int, optional=False)) def setUp(self): CommonTest.setUp(self) @@ -285,6 +312,7 @@ def setUp(self): repinit_tebd_input_mode, repinit_set_davg_zero, repinit_type_one_side, + repinit_use_three_body, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -299,6 +327,9 @@ def setUp(self): repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, @@ -347,6 +378,7 @@ def rtol(self) -> float: repinit_tebd_input_mode, repinit_set_davg_zero, repinit_type_one_side, + repinit_use_three_body, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -361,6 +393,9 @@ def rtol(self) -> float: repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, @@ -382,6 +417,7 @@ def atol(self) -> float: repinit_tebd_input_mode, repinit_set_davg_zero, repinit_type_one_side, + repinit_use_three_body, repformer_update_g1_has_conv, repformer_direct_dist, repformer_update_g1_has_drrd, @@ -396,6 +432,9 @@ def atol(self) -> float: repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, diff --git a/source/tests/pt/model/models/dpa2.json b/source/tests/pt/model/models/dpa2.json index ca1948492a..7495f5d78a 100644 --- a/source/tests/pt/model/models/dpa2.json +++ b/source/tests/pt/model/models/dpa2.json @@ -37,7 +37,10 @@ "update_g1_has_attn": true, "update_g2_has_g1g1": true, "update_g2_has_attn": true, - "attn2_has_gate": true + "attn2_has_gate": true, + "use_sqrt_nnei": false, + "g1_out_conv": false, + "g1_out_mlp": false }, "add_tebd_to_repinit_out": false }, diff --git a/source/tests/pt/model/test_dpa2.py b/source/tests/pt/model/test_dpa2.py index 6d3b6e182d..f11be532cb 100644 --- a/source/tests/pt/model/test_dpa2.py +++ b/source/tests/pt/model/test_dpa2.py @@ -62,6 +62,7 @@ def test_consistency( sm, prec, ect, + ns, ) in itertools.product( ["concat", "strip"], # repinit_tebd_input_mode [ @@ -70,8 +71,12 @@ def test_consistency( [True, False], # repformer_update_g1_has_conv [True, False], # repformer_update_g1_has_drrd [True, False], # repformer_update_g1_has_grrg - [True, False], # repformer_update_g1_has_attn - [True, False], # repformer_update_g2_has_g1g1 + [ + False, + ], # repformer_update_g1_has_attn + [ + False, + ], # repformer_update_g2_has_g1g1 [True, False], # repformer_update_g2_has_attn [ False, @@ -83,10 +88,18 @@ def test_consistency( [ True, ], # repformer_set_davg_zero - [True, False], # smooth + [ + True, + ], # smooth ["float64"], # precision [False, True], # use_econf_tebd + [ + False, + True, + ], # new sub-structures (use_sqrt_nnei, g1_out_conv, g1_out_mlp) ): + if ns and not rp1d and not rp1g: + continue dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) if prec == "float64": @@ -121,6 +134,9 @@ def test_consistency( attn2_has_gate=rp2gate, update_style=rus, set_davg_zero=rpz, + use_sqrt_nnei=ns, + g1_out_conv=ns, + g1_out_mlp=ns, ) # dpa2 new impl @@ -174,7 +190,7 @@ def test_consistency( atol=atol, ) # old impl - if prec == "float64" and rus == "res_avg" and ect is False: + if prec == "float64" and rus == "res_avg" and ect is False and ns is False: dd3 = DescrptDPA2( self.nt, repinit=repinit, @@ -239,6 +255,7 @@ def test_jit( sm, prec, ect, + ns, ) in itertools.product( ["concat", "strip"], # repinit_tebd_input_mode [ @@ -277,6 +294,7 @@ def test_jit( ], # smooth ["float64"], # precision [False, True], # use_econf_tebd + [True], # new sub-structures (use_sqrt_nnei, g1_out_conv, g1_out_mlp) ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -310,6 +328,9 @@ def test_jit( attn2_has_gate=rp2gate, update_style=rus, set_davg_zero=rpz, + use_sqrt_nnei=ns, + g1_out_conv=ns, + g1_out_mlp=ns, ) # dpa2 new impl diff --git a/source/tests/universal/dpmodel/descriptor/test_descriptor.py b/source/tests/universal/dpmodel/descriptor/test_descriptor.py index 424dd2ea39..256bea74f8 100644 --- a/source/tests/universal/dpmodel/descriptor/test_descriptor.py +++ b/source/tests/universal/dpmodel/descriptor/test_descriptor.py @@ -323,6 +323,7 @@ def DescriptorParamDPA2( repinit_tebd_input_mode="concat", repinit_set_davg_zero=False, repinit_type_one_side=False, + repinit_use_three_body=False, repformer_direct_dist=False, repformer_update_g1_has_conv=True, repformer_update_g1_has_drrd=True, @@ -337,6 +338,9 @@ def DescriptorParamDPA2( repformer_set_davg_zero=False, repformer_trainable_ln=True, repformer_ln_eps=1e-5, + repformer_use_sqrt_nnei=False, + repformer_g1_out_conv=False, + repformer_g1_out_mlp=False, smooth=True, add_tebd_to_repinit_out=True, use_econf_tebd=False, @@ -360,6 +364,10 @@ def DescriptorParamDPA2( "set_davg_zero": repinit_set_davg_zero, "activation_function": "tanh", "type_one_side": repinit_type_one_side, + "use_three_body": repinit_use_three_body, + "three_body_sel": min(sum(sel) // 2, 10), + "three_body_rcut": rcut / 2, + "three_body_rcut_smth": rcut_smth / 2, } ), # kwargs for repformer @@ -392,6 +400,9 @@ def DescriptorParamDPA2( "set_davg_zero": repformer_set_davg_zero, "trainable_ln": repformer_trainable_ln, "ln_eps": repformer_ln_eps, + "use_sqrt_nnei": repformer_use_sqrt_nnei, + "g1_out_conv": repformer_g1_out_conv, + "g1_out_mlp": repformer_g1_out_mlp, } ), # kwargs for descriptor @@ -417,6 +428,7 @@ def DescriptorParamDPA2( "repinit_tebd_input_mode": ("concat", "strip"), "repinit_set_davg_zero": (True,), "repinit_type_one_side": (False,), + "repinit_use_three_body": (True, False), "repformer_direct_dist": (False,), "repformer_update_g1_has_conv": (True,), "repformer_update_g1_has_drrd": (True,), @@ -431,6 +443,9 @@ def DescriptorParamDPA2( "repformer_set_davg_zero": (True,), "repformer_trainable_ln": (True,), "repformer_ln_eps": (1e-5,), + "repformer_use_sqrt_nnei": (True,), + "repformer_g1_out_conv": (True,), + "repformer_g1_out_mlp": (True,), "smooth": (True, False), "exclude_types": ([], [[0, 1]]), "precision": ("float64",),