From 0a0ba0408c1e8f26cdd5029f4074a870a9b4ade4 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 5 Dec 2024 00:59:12 +0800 Subject: [PATCH] update bi g1 --- deepmd/dpmodel/descriptor/dpa2.py | 2 + deepmd/pt/model/descriptor/dpa2.py | 1 + deepmd/pt/model/descriptor/repformer_layer.py | 52 ++++++++++++++++++- deepmd/pt/model/descriptor/repformers.py | 10 ++++ deepmd/utils/argcheck.py | 6 +++ 5 files changed, 70 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 9c1ad1277f..4bc4b59c39 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -230,6 +230,7 @@ def __init__( ln_eps: Optional[float] = 1e-5, use_undirect_g2: bool = False, use_undirect_a: bool = False, + update_g1_bidirect: bool = False, ) -> None: r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor. @@ -351,6 +352,7 @@ def __init__( self.ln_eps = ln_eps self.use_undirect_g2 = use_undirect_g2 self.use_undirect_a = use_undirect_a + self.update_g1_bidirect = update_g1_bidirect def __getitem__(self, key): if hasattr(self, key): diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 3278495146..c59317ea49 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -256,6 +256,7 @@ def init_subclass_params(sub_data, sub_class): angle_only_cos=self.repformer_args.angle_only_cos, use_undirect_g2=self.repformer_args.use_undirect_g2, use_undirect_a=self.repformer_args.use_undirect_a, + update_g1_bidirect=self.repformer_args.update_g1_bidirect, seed=child_seed(seed, 1), ) self.rcsl_list = [ diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 04b322002a..1fb87edaa5 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -480,6 +480,7 @@ def __init__( angle_use_self_g2_padding: bool = True, use_undirect_g2: bool = False, use_undirect_a: bool = False, + update_g1_bidirect: bool = False, seed: Optional[Union[int, list[int]]] = None, ) -> None: super().__init__() @@ -536,6 +537,7 @@ def __init__( self.angle_use_self_g2_padding = angle_use_self_g2_padding self.use_undirect_g2 = use_undirect_g2 self.use_undirect_a = use_undirect_a + self.update_g1_bidirect = update_g1_bidirect assert update_residual_init in [ "norm", @@ -626,13 +628,22 @@ def __init__( g1_dim, precision=precision, seed=child_seed(seed, 11), - ) # need act + ) # need act # receive self.g1_edge_linear2 = MLPLayer( g1_dim, g1_dim, precision=precision, seed=child_seed(seed, 12), ) # need act + if self.update_g1_bidirect: + self.g1_edge_linear_send = MLPLayer( + self.edge_info_dim, + g1_dim, + precision=precision, + seed=child_seed(seed, 20), + ) # need act # send + else: + self.g1_edge_linear_send = None if self.update_style == "res_residual": self.g1_residual.append( get_residual( @@ -643,6 +654,16 @@ def __init__( seed=child_seed(seed, 13), ) ) + if self.update_g1_bidirect: + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 21), + ) + ) if not self.update_g2_has_edge: # g2 self mlp @@ -996,6 +1017,7 @@ def forward( angle_nlist: torch.Tensor, # nf x nloc x a_nnei angle_nlist_mask: torch.Tensor, # nf x nloc x a_nnei angle_sw: torch.Tensor, # switch func, nf x nloc x a_nnei + nlist_loc: Optional[torch.Tensor] = None, ): """ Parameters @@ -1100,9 +1122,37 @@ def forward( assert self.g1_edge_linear1 is not None assert self.g1_edge_linear2 is not None # nb x nloc x nnei x ng1 + # receive g1_edge_info = self.act(self.g1_edge_linear1(edge_info)) * sw.unsqueeze(-1) g1_edge_update = torch.sum(g1_edge_info, dim=-2) / self.nnei g1_update.append(g1_edge_update) + if self.update_g1_bidirect: + # send message + assert self.g1_edge_linear_send is not None + # nb x nloc x nnei x ng1 + g1_edge_info_send = self.act( + self.g1_edge_linear_send(edge_info) + ) * sw.unsqueeze(-1) + assert nlist_loc is not None + # nb x (nloc+1) x nnei x ng1 + scattered_message = torch.zeros( + size=[nb, nloc + 1, nnei, self.g1_dim], + device=g1_edge_info_send.device, + dtype=g1_edge_info_send.dtype, + ) + # nb x nloc x nnei x ng1 + scatter_index = nlist_loc.unsqueeze(-1).expand(-1, -1, -1, self.g1_dim) + # nb x nloc x nnei x ng1 + scattered_message = torch.scatter_reduce( + scattered_message, + dim=1, + index=scatter_index, + src=g1_edge_info_send, + reduce="sum", + )[:, :-1, :, :] + # nb x nloc x ng1 + g1_edge_update_send = torch.sum(scattered_message, dim=-2) / self.nnei + g1_update.append(g1_edge_update_send) assert self.linear2 is not None if not self.update_g2_has_edge: diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index e58854bc72..470728263d 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -130,6 +130,7 @@ def __init__( angle_only_cos: bool = False, use_undirect_g2: bool = False, use_undirect_a: bool = False, + update_g1_bidirect: bool = False, ) -> None: r""" The repformer descriptor block. @@ -273,6 +274,7 @@ def __init__( self.angle_only_cos = angle_only_cos self.use_undirect_g2 = use_undirect_g2 self.use_undirect_a = use_undirect_a + self.update_g1_bidirect = update_g1_bidirect if num_a % 2 != 1: raise ValueError(f"{num_a=} must be an odd integer") circular_harmonics_order = (num_a - 1) // 2 @@ -364,6 +366,7 @@ def __init__( g1_out_mlp=self.g1_out_mlp, use_undirect_g2=self.use_undirect_g2, use_undirect_a=self.use_undirect_a, + update_g1_bidirect=self.update_g1_bidirect, seed=child_seed(child_seed(seed, 1), ii), ) ) @@ -576,9 +579,15 @@ def forward( # nb x nall x ng1 if comm_dict is None: assert mapping is not None + nlist_loc = torch.gather( + mapping, index=nlist.reshape(nframes, -1), dim=1 + ).reshape(nframes, nloc, self.nnei) + nlist_loc = torch.where(nlist_mask, nlist_loc, nloc) mapping = ( mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim) ) + else: + nlist_loc = None for idx, ll in enumerate(self.layers): # g1: nb x nloc x ng1 # g1_ext: nb x nall x ng1 @@ -649,6 +658,7 @@ def forward( angle_nlist, angle_nlist_mask, angle_sw, + nlist_loc=nlist_loc, ) # nb x nloc x 3 x ng2 diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 1ee5618a4b..4e56ffa155 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1252,6 +1252,12 @@ def dpa2_repformer_args(): optional=True, default=False, ), + Argument( + "update_g1_bidirect", + bool, + optional=True, + default=False, + ), Argument( "update_g2_has_a", bool,