Skip to content

Commit

Permalink
add update_g2_has_ar and update_g1_has_ar
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Dec 16, 2024
1 parent 08452e5 commit 2aa6ca2
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 1 deletion.
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ def __init__(
pre_ln: bool = False,
no_repinit: bool = False,
g1_mess_mulmlp: bool = False,
update_g2_has_ar: bool = False,
update_g1_has_ar: bool = False,
update_g2_has_arra: bool = False,
) -> None:
r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor.
Expand Down Expand Up @@ -352,6 +355,9 @@ def __init__(
self.angle_only_cos = angle_only_cos
self.pre_ln = pre_ln
self.no_repinit = no_repinit
self.update_g2_has_ar = update_g2_has_ar
self.update_g1_has_ar = update_g1_has_ar
self.update_g2_has_arra = update_g2_has_arra
# to keep consistent with default value in this backends
if ln_eps is None:
ln_eps = 1e-5
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ def init_subclass_params(sub_data, sub_class):
pipeline_update=self.repformer_args.pipeline_update,
pre_ln=self.repformer_args.pre_ln,
g1_mess_mulmlp=self.repformer_args.g1_mess_mulmlp,
update_g2_has_ar=self.repformer_args.update_g2_has_ar,
update_g1_has_ar=self.repformer_args.update_g1_has_ar,
update_g2_has_arra=self.repformer_args.update_g2_has_arra,
seed=child_seed(seed, 1),
)
self.no_repinit = self.repformer_args.no_repinit
Expand Down
102 changes: 102 additions & 0 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,9 @@ def __init__(
pipeline_update: bool = False,
pre_ln: bool = False,
g1_mess_mulmlp: bool = False,
update_g2_has_ar: bool = False,
update_g1_has_ar: bool = False,
update_g2_has_arra: bool = False,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -544,6 +547,9 @@ def __init__(
self.update_g1_bidirect = update_g1_bidirect
self.pipeline_update = pipeline_update
self.g1_mess_mulmlp = g1_mess_mulmlp
self.update_g2_has_ar = update_g2_has_ar
self.update_g1_has_ar = update_g1_has_ar
self.update_g2_has_arra = update_g2_has_arra
self.prec = PRECISION_DICT[precision]
self.g1_layernorm = None
self.g2_layernorm = None
Expand Down Expand Up @@ -693,6 +699,27 @@ def __init__(
)
)

# angle for g1
if self.has_angle and self.update_g1_has_ar:
self.g1_angle_linear = MLPLayer(
self.a_dim,
g1_dim,
precision=precision,
seed=child_seed(seed, 13),
) # need act
if self.update_style == "res_residual":
self.g1_residual.append(
get_residual(
g1_dim,
self.update_residual,
self.update_residual_init,
precision=precision,
seed=child_seed(seed, 14),
)
)
else:
self.g1_angle_linear = None

if not self.update_g2_has_edge:
# g2 self mlp
self.linear2 = MLPLayer(
Expand Down Expand Up @@ -752,6 +779,27 @@ def __init__(
)
)

# angle for g2
if self.has_angle and self.update_g2_has_ar:
self.g2_angle_linear_ar = MLPLayer(
self.a_dim,
g2_dim,
precision=precision,
seed=child_seed(seed, 21),
) # need act
if self.update_style == "res_residual":
self.g2_residual.append(
get_residual(
g2_dim,
self.update_residual,
self.update_residual_init,
precision=precision,
seed=child_seed(seed, 22),
)
)
else:
self.g2_angle_linear_ar = None

if self.has_angle:
if self.update_style == "res_layer":
self.angle_layernorm = nn.LayerNorm(
Expand Down Expand Up @@ -1053,6 +1101,7 @@ def forward(
angle_nlist_mask: torch.Tensor, # nf x nloc x a_nnei
angle_sw: torch.Tensor, # switch func, nf x nloc x a_nnei
nlist_loc: Optional[torch.Tensor] = None,
cosine_ij: Optional[torch.Tensor] = None,
):
"""
Parameters
Expand Down Expand Up @@ -1198,6 +1247,24 @@ def forward(
g1_edge_update_send = torch.sum(scattered_message, dim=-2) / self.nnei
g1_update.append(g1_edge_update_send)

if self.has_angle and (self.update_g1_has_ar or self.update_g2_has_ar):
assert cosine_ij is not None
assert angle_embed is not None
# nb x nloc x a_nnei x a_nnei x a
angle_ar = cosine_ij.unsqueeze(-1) * angle_embed
else:
angle_ar = None

# angle for g1
if self.has_angle and self.update_g1_has_ar:
assert self.g1_angle_linear is not None
assert angle_ar is not None
# nb x nloc x g1_dim
g1_ar = self.act(self.g1_angle_linear(angle_ar)).sum(-2).sum(-2) / (
float(self.a_sel) * float(self.a_sel)
)
g1_update.append(g1_ar)

# update g1 for pipeline_update
g1_new = self.list_update(g1_update, "g1")
if self.pipeline_update:
Expand Down Expand Up @@ -1242,6 +1309,41 @@ def forward(
g2_2 = self.attn2_lm(g2_2)
g2_update.append(g2_2)

# angle for g2
if self.has_angle and self.update_g2_has_ar:
assert self.g2_angle_linear_ar is not None
assert angle_ar is not None
# nb x nloc x a_nnei x g2_dim
g2_ar = self.act(self.g2_angle_linear_ar(angle_ar)).sum(-2) / float(
self.a_sel
)
# nb x nloc x nnei x g2
padding_g2_ar = torch.concat(
[
g2_ar,
torch.zeros(
[nb, nloc, self.nnei - self.a_sel, self.g2_dim],
dtype=g2.dtype,
device=g2.device,
),
],
dim=2,
)
if self.angle_use_self_g2_padding:
full_mask = torch.concat(
[
angle_nlist_mask,
torch.zeros(
[nb, nloc, self.nnei - self.a_sel],
dtype=angle_nlist_mask.dtype,
device=angle_nlist_mask.device,
),
],
dim=-1,
)
padding_g2_ar = torch.where(full_mask.unsqueeze(-1), padding_g2_ar, g2)
g2_update.append(padding_g2_ar)

if self.has_angle:
if self.pre_ln:
assert self.angle_layernorm is not None
Expand Down
12 changes: 11 additions & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def __init__(
pipeline_update: bool = False,
pre_ln: bool = False,
g1_mess_mulmlp: bool = False,
update_g2_has_ar: bool = False,
update_g1_has_ar: bool = False,
update_g2_has_arra: bool = False,
) -> None:
r"""
The repformer descriptor block.
Expand Down Expand Up @@ -281,6 +284,9 @@ def __init__(
self.pipeline_update = pipeline_update
self.pre_ln = pre_ln
self.g1_mess_mulmlp = g1_mess_mulmlp
self.update_g2_has_ar = update_g2_has_ar
self.update_g1_has_ar = update_g1_has_ar
self.update_g2_has_arra = update_g2_has_arra
if num_a % 2 != 1:
raise ValueError(f"{num_a=} must be an odd integer")
circular_harmonics_order = (num_a - 1) // 2
Expand Down Expand Up @@ -384,6 +390,9 @@ def __init__(
pipeline_update=self.pipeline_update,
pre_ln=self.pre_ln,
g1_mess_mulmlp=self.g1_mess_mulmlp,
update_g2_has_ar=self.update_g2_has_ar,
update_g1_has_ar=self.update_g1_has_ar,
update_g2_has_arra=self.update_g2_has_arra,
seed=child_seed(child_seed(seed, 1), ii),
)
)
Expand Down Expand Up @@ -556,7 +565,7 @@ def forward(
]
angle_nlist = nlist[:, :, : self.a_sel]
angle_nlist = torch.where(a_dist_mask, angle_nlist, -1)
_, angle_diff, angle_sw = prod_env_mat(
amatrix, angle_diff, angle_sw = prod_env_mat(
extended_coord,
angle_nlist,
atype,
Expand Down Expand Up @@ -676,6 +685,7 @@ def forward(
angle_nlist_mask,
angle_sw,
nlist_loc=nlist_loc,
cosine_ij=cosine_ij,
)

# nb x nloc x 3 x ng2
Expand Down
18 changes: 18 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,24 @@ def dpa2_repformer_args():
optional=True,
default=False,
),
Argument(
"update_g2_has_ar",
bool,
optional=True,
default=False,
),
Argument(
"update_g1_has_ar",
bool,
optional=True,
default=False,
),
Argument(
"update_g2_has_arra",
bool,
optional=True,
default=False,
),
Argument(
"pipeline_update",
bool,
Expand Down

0 comments on commit 2aa6ca2

Please sign in to comment.