Skip to content

Commit

Permalink
add node local attention
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Jan 16, 2025
1 parent bc35e33 commit 8926664
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 0 deletions.
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def __init__(
h1_dim: int = 16,
skip_stat: bool = False,
a_compress_use_split: bool = False,
update_n_has_attn: bool = False,
n_attn_hidden: int = 64,
n_attn_head: int = 4,
) -> None:
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
Expand Down Expand Up @@ -109,6 +112,9 @@ def __init__(
self.h1_message_idc = h1_message_idc
self.h1_message_only_nei = h1_message_only_nei
self.h1_dim = h1_dim
self.update_n_has_attn = update_n_has_attn
self.n_attn_hidden = n_attn_hidden
self.n_attn_head = n_attn_head

def __getitem__(self, key):
if hasattr(self, key):
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ def init_subclass_params(sub_data, sub_class):
h1_message_sub_axis=self.repflow_args.h1_message_sub_axis,
h1_message_idc=self.repflow_args.h1_message_idc,
h1_message_only_nei=self.repflow_args.h1_message_only_nei,
update_n_has_attn=self.repflow_args.update_n_has_attn,
n_attn_hidden=self.repflow_args.n_attn_hidden,
n_attn_head=self.repflow_args.n_attn_head,
h1_dim=self.repflow_args.h1_dim,
skip_stat=self.repflow_args.skip_stat,
exclude_types=exclude_types,
Expand Down
35 changes: 35 additions & 0 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
child_seed,
)
from deepmd.pt.model.descriptor.repformer_layer import (
LocalAtten,
_apply_nlist_mask,
_apply_switch,
_make_nei_g1,
Expand Down Expand Up @@ -59,6 +60,9 @@ def __init__(
h1_message_idc: bool = False,
h1_message_only_nei: bool = False,
h1_dim: int = 16,
update_n_has_attn: bool = False,
n_attn_hidden: int = 64,
n_attn_head: int = 4,
activation_function: str = "silu",
update_style: str = "res_residual",
update_residual: float = 0.1,
Expand Down Expand Up @@ -107,6 +111,9 @@ def __init__(
self.h1_message_sub_axis = h1_message_sub_axis
self.h1_message_idc = h1_message_idc
self.h1_message_only_nei = h1_message_only_nei
self.update_n_has_attn = update_n_has_attn
self.n_attn_hidden = n_attn_hidden
self.n_attn_head = n_attn_head
self.has_h1 = self.update_n_has_h1 or self.update_e_has_h1
self.precision = precision
self.seed = seed
Expand Down Expand Up @@ -195,6 +202,29 @@ def __init__(
)
)

# node local attention
if self.update_n_has_attn:
self.node_attn = LocalAtten(
self.n_dim,
self.n_attn_hidden,
self.n_attn_head,
True,
precision=precision,
seed=child_seed(seed, 6),
)
if self.update_style == "res_residual":
self.n_residual.append(
get_residual(
n_dim,
self.update_residual,
self.update_residual_init,
precision=precision,
seed=child_seed(seed, 3),
)
)
else:
self.node_attn = None

# h1 message
if self.has_h1:
self.h1_linear = MLPLayer(
Expand Down Expand Up @@ -582,6 +612,11 @@ def forward(
else:
n_update_list.append(node_edge_update)

# node local attn
if self.update_n_has_attn:
assert self.node_attn is not None
n_update_list.append(self.node_attn(node_ebd, nei_node_ebd, nlist_mask, sw))

# h1 message
if self.has_h1:
assert h1_ext is not None
Expand Down
9 changes: 9 additions & 0 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def __init__(
h1_message_idc: bool = False,
h1_message_only_nei: bool = False,
h1_dim: int = 16,
update_n_has_attn: bool = False,
n_attn_hidden: int = 64,
n_attn_head: int = 4,
set_davg_zero: bool = True,
exclude_types: list[tuple[int, int]] = [],
env_protection: float = 0.0,
Expand Down Expand Up @@ -209,6 +212,9 @@ def __init__(
self.h1_message_idc = h1_message_idc
self.h1_message_only_nei = h1_message_only_nei
self.h1_dim = h1_dim
self.update_n_has_attn = update_n_has_attn
self.n_attn_hidden = n_attn_hidden
self.n_attn_head = n_attn_head

self.n_dim = n_dim
self.e_dim = e_dim
Expand Down Expand Up @@ -275,6 +281,9 @@ def __init__(
h1_message_idc=self.h1_message_idc,
h1_message_only_nei=self.h1_message_only_nei,
h1_dim=self.h1_dim,
update_n_has_attn=self.update_n_has_attn,
n_attn_hidden=self.n_attn_hidden,
n_attn_head=self.n_attn_head,
activation_function=self.activation_function,
update_style=self.update_style,
update_residual=self.update_residual,
Expand Down
18 changes: 18 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,24 @@ def dpa3_repflow_args():
optional=True,
default=False,
),
Argument(
"update_n_has_attn",
bool,
optional=True,
default=False,
),
Argument(
"n_attn_hidden",
int,
optional=True,
default=64,
),
Argument(
"n_attn_head",
int,
optional=True,
default=4,
),
]


Expand Down

0 comments on commit 8926664

Please sign in to comment.