Skip to content

Commit

Permalink
add mulg1 message
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Dec 20, 2024
1 parent a5ec73f commit a3acd36
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def __init__(
update_g1_has_ar: bool = False,
update_g2_has_arra: bool = False,
compress_a: bool = False,
g1_bi_message: bool = False,
) -> None:
r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor.
Expand Down Expand Up @@ -369,6 +370,7 @@ def __init__(
self.pipeline_update = pipeline_update
self.g1_mess_mulmlp = g1_mess_mulmlp
self.compress_a = compress_a
self.g1_bi_message = g1_bi_message

def __getitem__(self, key):
if hasattr(self, key):
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def init_subclass_params(sub_data, sub_class):
update_g1_has_ar=self.repformer_args.update_g1_has_ar,
update_g2_has_arra=self.repformer_args.update_g2_has_arra,
compress_a=self.repformer_args.compress_a,
g1_bi_message=self.repformer_args.g1_bi_message,
seed=child_seed(seed, 1),
)
self.no_repinit = self.repformer_args.no_repinit
Expand Down
36 changes: 36 additions & 0 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def __init__(
update_g1_has_ar: bool = False,
update_g2_has_arra: bool = False,
compress_a: bool = False,
g1_bi_message: bool = False,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -552,6 +553,7 @@ def __init__(
self.update_g1_has_ar = update_g1_has_ar
self.update_g2_has_arra = update_g2_has_arra
self.compress_a = compress_a
self.g1_bi_message = g1_bi_message
self.prec = PRECISION_DICT[precision]
self.g1_layernorm = None
self.g2_layernorm = None
Expand Down Expand Up @@ -680,6 +682,17 @@ def __init__(
) # need act # send
else:
self.g1_edge_linear_send = None

if self.g1_bi_message:
self.g1_edge_linear_receive_head2 = MLPLayer(
self.edge_info_dim,
g1_dim,
precision=precision,
seed=child_seed(seed, 22),
) # need act # receive 2
else:
self.g1_edge_linear_receive_head2 = None

if self.update_style == "res_residual":
self.g1_residual.append(
get_residual(
Expand All @@ -700,6 +713,16 @@ def __init__(
seed=child_seed(seed, 21),
)
)
if self.g1_bi_message:
self.g1_residual.append(
get_residual(
g1_dim,
self.update_residual,
self.update_residual_init,
precision=precision,
seed=child_seed(seed, 23),
)
)

# angle for g1
if self.has_angle and self.update_g1_has_ar:
Expand Down Expand Up @@ -1241,6 +1264,19 @@ def forward(
g1_edge_info = g1_edge_info * sw.unsqueeze(-1)
g1_edge_update = torch.sum(g1_edge_info, dim=-2) / self.nnei
g1_update.append(g1_edge_update)

if self.g1_bi_message:
# reveive multihead
assert self.g1_edge_linear_receive_head2 is not None
g1_edge_info_reveive2 = self.act(
self.g1_edge_linear_receive_head2(edge_info)
)
g1_edge_info_reveive2 = g1_edge_info_reveive2 * sw.unsqueeze(-1)
g1_edge_update_reveive2 = (
torch.sum(g1_edge_info_reveive2, dim=-2) / self.nnei
)
g1_update.append(g1_edge_update_reveive2)

if self.update_g1_bidirect:
# send message
assert self.g1_edge_linear_send is not None
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def __init__(
update_g1_has_ar: bool = False,
update_g2_has_arra: bool = False,
compress_a: bool = False,
g1_bi_message: bool = False,
) -> None:
r"""
The repformer descriptor block.
Expand Down Expand Up @@ -289,6 +290,7 @@ def __init__(
self.update_g1_has_ar = update_g1_has_ar
self.update_g2_has_arra = update_g2_has_arra
self.compress_a = compress_a
self.g1_bi_message = g1_bi_message
if num_a % 2 != 1:
raise ValueError(f"{num_a=} must be an odd integer")
circular_harmonics_order = (num_a - 1) // 2
Expand Down Expand Up @@ -396,6 +398,7 @@ def __init__(
update_g1_has_ar=self.update_g1_has_ar,
update_g2_has_arra=self.update_g2_has_arra,
compress_a=self.compress_a,
g1_bi_message=self.g1_bi_message,
seed=child_seed(child_seed(seed, 1), ii),
)
)
Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,12 @@ def dpa2_repformer_args():
optional=True,
default=False,
),
Argument(
"g1_bi_message",
bool,
optional=True,
default=False,
),
Argument(
"a_dim",
int,
Expand Down

0 comments on commit a3acd36

Please sign in to comment.