From a3acd36e94f6ab0ff5fdf1fd1bd69b5ea658783c Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 20 Dec 2024 21:39:10 +0800 Subject: [PATCH] add mulg1 message --- deepmd/dpmodel/descriptor/dpa2.py | 2 ++ deepmd/pt/model/descriptor/dpa2.py | 1 + deepmd/pt/model/descriptor/repformer_layer.py | 36 +++++++++++++++++++ deepmd/pt/model/descriptor/repformers.py | 3 ++ deepmd/utils/argcheck.py | 6 ++++ 5 files changed, 48 insertions(+) diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 8035bc8c78..9b08554af7 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -239,6 +239,7 @@ def __init__( update_g1_has_ar: bool = False, update_g2_has_arra: bool = False, compress_a: bool = False, + g1_bi_message: bool = False, ) -> None: r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor. @@ -369,6 +370,7 @@ def __init__( self.pipeline_update = pipeline_update self.g1_mess_mulmlp = g1_mess_mulmlp self.compress_a = compress_a + self.g1_bi_message = g1_bi_message def __getitem__(self, key): if hasattr(self, key): diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index ab67462a16..7c49dc92af 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -264,6 +264,7 @@ def init_subclass_params(sub_data, sub_class): update_g1_has_ar=self.repformer_args.update_g1_has_ar, update_g2_has_arra=self.repformer_args.update_g2_has_arra, compress_a=self.repformer_args.compress_a, + g1_bi_message=self.repformer_args.g1_bi_message, seed=child_seed(seed, 1), ) self.no_repinit = self.repformer_args.no_repinit diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 18169791d0..d42e5898fb 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -488,6 +488,7 @@ def __init__( update_g1_has_ar: bool = False, update_g2_has_arra: bool = False, compress_a: bool = False, + g1_bi_message: bool = False, seed: Optional[Union[int, list[int]]] = None, ) -> None: super().__init__() @@ -552,6 +553,7 @@ def __init__( self.update_g1_has_ar = update_g1_has_ar self.update_g2_has_arra = update_g2_has_arra self.compress_a = compress_a + self.g1_bi_message = g1_bi_message self.prec = PRECISION_DICT[precision] self.g1_layernorm = None self.g2_layernorm = None @@ -680,6 +682,17 @@ def __init__( ) # need act # send else: self.g1_edge_linear_send = None + + if self.g1_bi_message: + self.g1_edge_linear_receive_head2 = MLPLayer( + self.edge_info_dim, + g1_dim, + precision=precision, + seed=child_seed(seed, 22), + ) # need act # receive 2 + else: + self.g1_edge_linear_receive_head2 = None + if self.update_style == "res_residual": self.g1_residual.append( get_residual( @@ -700,6 +713,16 @@ def __init__( seed=child_seed(seed, 21), ) ) + if self.g1_bi_message: + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 23), + ) + ) # angle for g1 if self.has_angle and self.update_g1_has_ar: @@ -1241,6 +1264,19 @@ def forward( g1_edge_info = g1_edge_info * sw.unsqueeze(-1) g1_edge_update = torch.sum(g1_edge_info, dim=-2) / self.nnei g1_update.append(g1_edge_update) + + if self.g1_bi_message: + # reveive multihead + assert self.g1_edge_linear_receive_head2 is not None + g1_edge_info_reveive2 = self.act( + self.g1_edge_linear_receive_head2(edge_info) + ) + g1_edge_info_reveive2 = g1_edge_info_reveive2 * sw.unsqueeze(-1) + g1_edge_update_reveive2 = ( + torch.sum(g1_edge_info_reveive2, dim=-2) / self.nnei + ) + g1_update.append(g1_edge_update_reveive2) + if self.update_g1_bidirect: # send message assert self.g1_edge_linear_send is not None diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 0f0a229c4c..8b7240e690 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -138,6 +138,7 @@ def __init__( update_g1_has_ar: bool = False, update_g2_has_arra: bool = False, compress_a: bool = False, + g1_bi_message: bool = False, ) -> None: r""" The repformer descriptor block. @@ -289,6 +290,7 @@ def __init__( self.update_g1_has_ar = update_g1_has_ar self.update_g2_has_arra = update_g2_has_arra self.compress_a = compress_a + self.g1_bi_message = g1_bi_message if num_a % 2 != 1: raise ValueError(f"{num_a=} must be an odd integer") circular_harmonics_order = (num_a - 1) // 2 @@ -396,6 +398,7 @@ def __init__( update_g1_has_ar=self.update_g1_has_ar, update_g2_has_arra=self.update_g2_has_arra, compress_a=self.compress_a, + g1_bi_message=self.g1_bi_message, seed=child_seed(child_seed(seed, 1), ii), ) ) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 7fe749a18e..74cc62e830 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1210,6 +1210,12 @@ def dpa2_repformer_args(): optional=True, default=False, ), + Argument( + "g1_bi_message", + bool, + optional=True, + default=False, + ), Argument( "a_dim", int,