Skip to content

Commit

Permalink
update bi g1
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Dec 4, 2024
1 parent 7875b36 commit 0a0ba04
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 1 deletion.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def __init__(
ln_eps: Optional[float] = 1e-5,
use_undirect_g2: bool = False,
use_undirect_a: bool = False,
update_g1_bidirect: bool = False,
) -> None:
r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor.
Expand Down Expand Up @@ -351,6 +352,7 @@ def __init__(
self.ln_eps = ln_eps
self.use_undirect_g2 = use_undirect_g2
self.use_undirect_a = use_undirect_a
self.update_g1_bidirect = update_g1_bidirect

def __getitem__(self, key):
if hasattr(self, key):
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def init_subclass_params(sub_data, sub_class):
angle_only_cos=self.repformer_args.angle_only_cos,
use_undirect_g2=self.repformer_args.use_undirect_g2,
use_undirect_a=self.repformer_args.use_undirect_a,
update_g1_bidirect=self.repformer_args.update_g1_bidirect,
seed=child_seed(seed, 1),
)
self.rcsl_list = [
Expand Down
52 changes: 51 additions & 1 deletion deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ def __init__(
angle_use_self_g2_padding: bool = True,
use_undirect_g2: bool = False,
use_undirect_a: bool = False,
update_g1_bidirect: bool = False,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -536,6 +537,7 @@ def __init__(
self.angle_use_self_g2_padding = angle_use_self_g2_padding
self.use_undirect_g2 = use_undirect_g2
self.use_undirect_a = use_undirect_a
self.update_g1_bidirect = update_g1_bidirect

assert update_residual_init in [
"norm",
Expand Down Expand Up @@ -626,13 +628,22 @@ def __init__(
g1_dim,
precision=precision,
seed=child_seed(seed, 11),
) # need act
) # need act # receive
self.g1_edge_linear2 = MLPLayer(
g1_dim,
g1_dim,
precision=precision,
seed=child_seed(seed, 12),
) # need act
if self.update_g1_bidirect:
self.g1_edge_linear_send = MLPLayer(
self.edge_info_dim,
g1_dim,
precision=precision,
seed=child_seed(seed, 20),
) # need act # send
else:
self.g1_edge_linear_send = None
if self.update_style == "res_residual":
self.g1_residual.append(
get_residual(
Expand All @@ -643,6 +654,16 @@ def __init__(
seed=child_seed(seed, 13),
)
)
if self.update_g1_bidirect:
self.g1_residual.append(
get_residual(
g1_dim,
self.update_residual,
self.update_residual_init,
precision=precision,
seed=child_seed(seed, 21),
)
)

if not self.update_g2_has_edge:
# g2 self mlp
Expand Down Expand Up @@ -996,6 +1017,7 @@ def forward(
angle_nlist: torch.Tensor, # nf x nloc x a_nnei
angle_nlist_mask: torch.Tensor, # nf x nloc x a_nnei
angle_sw: torch.Tensor, # switch func, nf x nloc x a_nnei
nlist_loc: Optional[torch.Tensor] = None,
):
"""
Parameters
Expand Down Expand Up @@ -1100,9 +1122,37 @@ def forward(
assert self.g1_edge_linear1 is not None
assert self.g1_edge_linear2 is not None
# nb x nloc x nnei x ng1
# receive
g1_edge_info = self.act(self.g1_edge_linear1(edge_info)) * sw.unsqueeze(-1)
g1_edge_update = torch.sum(g1_edge_info, dim=-2) / self.nnei
g1_update.append(g1_edge_update)
if self.update_g1_bidirect:
# send message
assert self.g1_edge_linear_send is not None
# nb x nloc x nnei x ng1
g1_edge_info_send = self.act(
self.g1_edge_linear_send(edge_info)
) * sw.unsqueeze(-1)
assert nlist_loc is not None
# nb x (nloc+1) x nnei x ng1
scattered_message = torch.zeros(
size=[nb, nloc + 1, nnei, self.g1_dim],
device=g1_edge_info_send.device,
dtype=g1_edge_info_send.dtype,
)
# nb x nloc x nnei x ng1
scatter_index = nlist_loc.unsqueeze(-1).expand(-1, -1, -1, self.g1_dim)
# nb x nloc x nnei x ng1
scattered_message = torch.scatter_reduce(
scattered_message,
dim=1,
index=scatter_index,
src=g1_edge_info_send,
reduce="sum",
)[:, :-1, :, :]
# nb x nloc x ng1
g1_edge_update_send = torch.sum(scattered_message, dim=-2) / self.nnei
g1_update.append(g1_edge_update_send)

assert self.linear2 is not None
if not self.update_g2_has_edge:
Expand Down
10 changes: 10 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(
angle_only_cos: bool = False,
use_undirect_g2: bool = False,
use_undirect_a: bool = False,
update_g1_bidirect: bool = False,
) -> None:
r"""
The repformer descriptor block.
Expand Down Expand Up @@ -273,6 +274,7 @@ def __init__(
self.angle_only_cos = angle_only_cos
self.use_undirect_g2 = use_undirect_g2
self.use_undirect_a = use_undirect_a
self.update_g1_bidirect = update_g1_bidirect
if num_a % 2 != 1:
raise ValueError(f"{num_a=} must be an odd integer")
circular_harmonics_order = (num_a - 1) // 2
Expand Down Expand Up @@ -364,6 +366,7 @@ def __init__(
g1_out_mlp=self.g1_out_mlp,
use_undirect_g2=self.use_undirect_g2,
use_undirect_a=self.use_undirect_a,
update_g1_bidirect=self.update_g1_bidirect,
seed=child_seed(child_seed(seed, 1), ii),
)
)
Expand Down Expand Up @@ -576,9 +579,15 @@ def forward(
# nb x nall x ng1
if comm_dict is None:
assert mapping is not None
nlist_loc = torch.gather(
mapping, index=nlist.reshape(nframes, -1), dim=1
).reshape(nframes, nloc, self.nnei)
nlist_loc = torch.where(nlist_mask, nlist_loc, nloc)
mapping = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
)
else:
nlist_loc = None
for idx, ll in enumerate(self.layers):
# g1: nb x nloc x ng1
# g1_ext: nb x nall x ng1
Expand Down Expand Up @@ -649,6 +658,7 @@ def forward(
angle_nlist,
angle_nlist_mask,
angle_sw,
nlist_loc=nlist_loc,
)

# nb x nloc x 3 x ng2
Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,12 @@ def dpa2_repformer_args():
optional=True,
default=False,
),
Argument(
"update_g1_bidirect",
bool,
optional=True,
default=False,
),
Argument(
"update_g2_has_a",
bool,
Expand Down

0 comments on commit 0a0ba04

Please sign in to comment.