
Commit

revert 'nf' to 'nb'
iProzd committed May 9, 2024
1 parent d2bcdbf commit 385e1f7
Showing 3 changed files with 79 additions and 79 deletions.
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/dpa2.py
@@ -255,7 +255,7 @@ def __init__(
         Returns
         -------
         descriptor: torch.Tensor
-            the descriptor of shape nf x nloc x g1_dim.
+            the descriptor of shape nb x nloc x g1_dim.
             invariant single-atom representation.
         g2: torch.Tensor
             invariant pair-atom representation.
144 changes: 72 additions & 72 deletions deepmd/pt/model/descriptor/repformer_layer.py
@@ -78,26 +78,26 @@ def _make_nei_g1(
     Parameters
     ----------
     g1_ext
-        Extended atomic invariant rep, with shape nf x nall x ng1.
+        Extended atomic invariant rep, with shape nb x nall x ng1.
     nlist
-        Neighbor list, with shape nf x nloc x nnei.
+        Neighbor list, with shape nb x nloc x nnei.
     Returns
     -------
     gg1: torch.Tensor
-        Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1.
+        Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1.
     """
-    # nlist: nf x nloc x nnei
-    nf, nloc, nnei = nlist.shape
-    # g1_ext: nf x nall x ng1
+    # nlist: nb x nloc x nnei
+    nb, nloc, nnei = nlist.shape
+    # g1_ext: nb x nall x ng1
     ng1 = g1_ext.shape[-1]
-    # index: nf x (nloc x nnei) x ng1
-    index = nlist.reshape(nf, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1)
-    # gg1 : nf x (nloc x nnei) x ng1
+    # index: nb x (nloc x nnei) x ng1
+    index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1)
+    # gg1 : nb x (nloc x nnei) x ng1
     gg1 = torch.gather(g1_ext, dim=1, index=index)
-    # gg1 : nf x nloc x nnei x ng1
-    gg1 = gg1.view(nf, nloc, nnei, ng1)
+    # gg1 : nb x nloc x nnei x ng1
+    gg1 = gg1.view(nb, nloc, nnei, ng1)
     return gg1
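
As an aside, a minimal standalone sketch of the gather-based neighbor lookup that _make_nei_g1 performs; the sizes (nb=2 frames, nall=5, nloc=3, nnei=4, ng1=8) are illustrative assumptions, not values from this commit:

import torch

nb, nall, nloc, nnei, ng1 = 2, 5, 3, 4, 8
g1_ext = torch.randn(nb, nall, ng1)               # extended atomic rep
nlist = torch.randint(0, nall, (nb, nloc, nnei))  # neighbor indices into nall

# broadcast each neighbor index over the ng1 channels: nb x (nloc*nnei) x ng1
index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1)
# pick the neighbor rows out of g1_ext, then restore the neighbor axis
gg1 = torch.gather(g1_ext, dim=1, index=index).view(nb, nloc, nnei, ng1)
assert gg1.shape == (nb, nloc, nnei, ng1)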


@@ -291,34 +291,34 @@ def __init__(

     def forward(
         self,
-        g2: torch.Tensor,  # nf x nloc x nnei x ng2
-        h2: torch.Tensor,  # nf x nloc x nnei x 3
-        nlist_mask: torch.Tensor,  # nf x nloc x nnei
-        sw: torch.Tensor,  # nf x nloc x nnei
+        g2: torch.Tensor,  # nb x nloc x nnei x ng2
+        h2: torch.Tensor,  # nb x nloc x nnei x 3
+        nlist_mask: torch.Tensor,  # nb x nloc x nnei
+        sw: torch.Tensor,  # nb x nloc x nnei
     ) -> torch.Tensor:
         (
-            nf,
+            nb,
             nloc,
             nnei,
             _,
         ) = g2.shape
         nd, nh = self.hidden_dim, self.head_num
-        # nf x nloc x nnei x nd x (nh x 2)
-        g2qk = self.mapqk(g2).view(nf, nloc, nnei, nd, nh * 2)
-        # nf x nloc x (nh x 2) x nnei x nd
+        # nb x nloc x nnei x nd x (nh x 2)
+        g2qk = self.mapqk(g2).view(nb, nloc, nnei, nd, nh * 2)
+        # nb x nloc x (nh x 2) x nnei x nd
         g2qk = torch.permute(g2qk, (0, 1, 4, 2, 3))
-        # nf x nloc x nh x nnei x nd
+        # nb x nloc x nh x nnei x nd
         g2q, g2k = torch.split(g2qk, nh, dim=2)
         # g2q = torch.nn.functional.normalize(g2q, dim=-1)
         # g2k = torch.nn.functional.normalize(g2k, dim=-1)
-        # nf x nloc x nh x nnei x nnei
+        # nb x nloc x nh x nnei x nnei
         attnw = torch.matmul(g2q, torch.transpose(g2k, -1, -2)) / nd**0.5
         if self.has_gate:
             gate = torch.matmul(h2, torch.transpose(h2, -1, -2)).unsqueeze(-3)
             attnw = attnw * gate
-        # mask the attenmap, nf x nloc x 1 x 1 x nnei
+        # mask the attenmap, nb x nloc x 1 x 1 x nnei
         attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2)
-        # mask the attenmap, nf x nloc x 1 x nnei x 1
+        # mask the attenmap, nb x nloc x 1 x nnei x 1
         attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1)
         if self.smooth:
             attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[
@@ -334,19 +334,19 @@ def forward(
             attnw_mask,
             0.0,
         )
-        # nf x nloc x nh x nnei x nnei
+        # nb x nloc x nh x nnei x nnei
         attnw = attnw.masked_fill(
             attnw_mask_c,
             0.0,
         )
         if self.smooth:
             attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :]
-        # nf x nloc x nnei x nnei
+        # nb x nloc x nnei x nnei
         h2h2t = torch.matmul(h2, torch.transpose(h2, -1, -2)) / 3.0**0.5
-        # nf x nloc x nh x nnei x nnei
+        # nb x nloc x nh x nnei x nnei
         ret = attnw * h2h2t[:, :, None, :, :]
         # ret = torch.softmax(g2qk, dim=-1)
-        # nf x nloc x nnei x nnei x nh
+        # nb x nloc x nnei x nnei x nh
         ret = torch.permute(ret, (0, 1, 3, 4, 2))
         return ret
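
The double masking above is easy to misread, so here is a small sketch of the pattern in the non-smooth branch, assuming the usual masked-softmax sequence that sits in the folded lines between the two hunks; all shapes are illustrative:

import torch

nb, nloc, nh, nnei = 1, 2, 2, 4
attnw = torch.randn(nb, nloc, nh, nnei, nnei)
nlist_mask = torch.ones(nb, nloc, nnei, dtype=torch.bool)
nlist_mask[..., -1] = False  # pretend the last neighbor slot is padding

# nb x nloc x 1 x 1 x nnei: mask attention *to* padded neighbors before softmax
attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2)
# nb x nloc x 1 x nnei x 1: mask attention *from* padded neighbors afterwards
attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1)

attnw = attnw.masked_fill(attnw_mask, float("-inf"))
attnw = torch.softmax(attnw, dim=-1)
attnw = attnw.masked_fill(attnw_mask, 0.0)    # padded columns become exact zeros
attnw = attnw.masked_fill(attnw_mask_c, 0.0)  # rows of padded slots zeroed too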

Expand Down Expand Up @@ -561,32 +561,32 @@ def __init__(

     def forward(
         self,
-        g1: torch.Tensor,  # nf x nloc x ng1
-        gg1: torch.Tensor,  # nf x nloc x nnei x ng1
-        nlist_mask: torch.Tensor,  # nf x nloc x nnei
-        sw: torch.Tensor,  # nf x nloc x nnei
+        g1: torch.Tensor,  # nb x nloc x ng1
+        gg1: torch.Tensor,  # nb x nloc x nnei x ng1
+        nlist_mask: torch.Tensor,  # nb x nloc x nnei
+        sw: torch.Tensor,  # nb x nloc x nnei
     ) -> torch.Tensor:
-        nf, nloc, nnei = nlist_mask.shape
+        nb, nloc, nnei = nlist_mask.shape
         ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num
         assert ni == g1.shape[-1]
         assert ni == gg1.shape[-1]
-        # nf x nloc x nd x nh
-        g1q = self.mapq(g1).view(nf, nloc, nd, nh)
-        # nf x nloc x nh x nd
+        # nb x nloc x nd x nh
+        g1q = self.mapq(g1).view(nb, nloc, nd, nh)
+        # nb x nloc x nh x nd
         g1q = torch.permute(g1q, (0, 1, 3, 2))
-        # nf x nloc x nnei x (nd+ni) x nh
-        gg1kv = self.mapkv(gg1).view(nf, nloc, nnei, nd + ni, nh)
+        # nb x nloc x nnei x (nd+ni) x nh
+        gg1kv = self.mapkv(gg1).view(nb, nloc, nnei, nd + ni, nh)
         gg1kv = torch.permute(gg1kv, (0, 1, 4, 2, 3))
-        # nf x nloc x nh x nnei x nd, nf x nloc x nh x nnei x ng1
+        # nb x nloc x nh x nnei x nd, nb x nloc x nh x nnei x ng1
         gg1k, gg1v = torch.split(gg1kv, [nd, ni], dim=-1)

-        # nf x nloc x nh x 1 x nnei
+        # nb x nloc x nh x 1 x nnei
         attnw = torch.matmul(g1q.unsqueeze(-2), torch.transpose(gg1k, -1, -2)) / nd**0.5
-        # nf x nloc x nh x nnei
+        # nb x nloc x nh x nnei
         attnw = attnw.squeeze(-2)
-        # mask the attenmap, nf x nloc x 1 x nnei
+        # mask the attenmap, nb x nloc x 1 x nnei
         attnw_mask = ~nlist_mask.unsqueeze(-2)
-        # nf x nloc x nh x nnei
+        # nb x nloc x nh x nnei
         if self.smooth:
             attnw = (attnw + self.attnw_shift) * sw.unsqueeze(-2) - self.attnw_shift
         else:
@@ -602,11 +602,11 @@ def forward(
         if self.smooth:
             attnw = attnw * sw.unsqueeze(-2)

-        # nf x nloc x nh x ng1
+        # nb x nloc x nh x ng1
         ret = (
-            torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nf, nloc, nh * ni)
+            torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni)
         )
-        # nf x nloc x ng1
+        # nb x nloc x ng1
         ret = self.head_map(ret)
         return ret
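
For orientation, a self-contained sketch of this single-query attention: one query per local atom (from g1) attends over its neighbors (keys and values from gg1). All dimensions are illustrative assumptions:

import torch

nb, nloc, nnei, nh, nd, ni = 1, 2, 4, 2, 5, 6
g1q = torch.randn(nb, nloc, nh, nd)         # one query per atom and head
gg1k = torch.randn(nb, nloc, nh, nnei, nd)  # keys, one per neighbor
gg1v = torch.randn(nb, nloc, nh, nnei, ni)  # values, one per neighbor

# nb x nloc x nh x 1 x nnei: scaled dot-product scores
attnw = torch.matmul(g1q.unsqueeze(-2), gg1k.transpose(-1, -2)) / nd**0.5
attnw = torch.softmax(attnw.squeeze(-2), dim=-1)  # nb x nloc x nh x nnei
# weighted sum over neighbors, heads flattened: nb x nloc x (nh * ni)
ret = torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).reshape(nb, nloc, nh * ni)
assert ret.shape == (nb, nloc, nh * ni)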

@@ -878,63 +878,63 @@ def _update_g1_conv(
         Parameters
         ----------
         gg1
-            Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1.
+            Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1.
         g2
-            Pair invariant rep, with shape nf x nloc x nnei x ng2.
+            Pair invariant rep, with shape nb x nloc x nnei x ng2.
         nlist_mask
-            Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei.
+            Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.
         sw
             The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
-            and remains 0 beyond rcut, with shape nf x nloc x nnei.
+            and remains 0 beyond rcut, with shape nb x nloc x nnei.
         """
         assert self.proj_g1g2 is not None
-        nf, nloc, nnei, _ = g2.shape
+        nb, nloc, nnei, _ = g2.shape
         ng1 = gg1.shape[-1]
         ng2 = g2.shape[-1]
-        # gg1 : nf x nloc x nnei x ng2
-        gg1 = self.proj_g1g2(gg1).view(nf, nloc, nnei, ng2)
-        # nf x nloc x nnei x ng2
+        # gg1 : nb x nloc x nnei x ng2
+        gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2)
+        # nb x nloc x nnei x ng2
         gg1 = _apply_nlist_mask(gg1, nlist_mask)
         if not self.smooth:
             # normalized by number of neighbors, not smooth
-            # nf x nloc x 1
+            # nb x nloc x 1
             # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy
             invnnei = 1.0 / (
                 self.epsilon + torch.sum(nlist_mask.type_as(gg1), dim=-1)
             ).unsqueeze(-1)
         else:
             gg1 = _apply_switch(gg1, sw)
             invnnei = (1.0 / float(nnei)) * torch.ones(
-                (nf, nloc, 1), dtype=gg1.dtype, device=gg1.device
+                (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device
             )
-        # nf x nloc x ng2
+        # nb x nloc x ng2
         g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei
         return g1_11
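
A sketch of the two normalizations used here: the non-smooth branch divides by the actual neighbor count recovered from the mask, while the smooth branch divides by the fixed capacity nnei. Shapes and the epsilon value are illustrative assumptions:

import torch

nb, nloc, nnei, ng2 = 1, 2, 4, 6
epsilon = 1e-4
g2 = torch.randn(nb, nloc, nnei, ng2)
gg1 = torch.randn(nb, nloc, nnei, ng2)  # already projected to ng2 channels
nlist_mask = torch.ones(nb, nloc, nnei, dtype=torch.bool)
nlist_mask[..., -1] = False             # last slot is padding

gg1 = gg1 * nlist_mask.unsqueeze(-1).type_as(gg1)  # zero padded slots

# non-smooth: 1 / (true neighbor count); bool -> float via type_as
invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(gg1), dim=-1)).unsqueeze(-1)
# smooth branch would instead use the constant:
# invnnei = (1.0 / float(nnei)) * torch.ones((nb, nloc, 1), dtype=gg1.dtype)

g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei  # nb x nloc x ng2
assert g1_11.shape == (nb, nloc, ng2)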

     def _update_g2_g1g1(
         self,
-        g1: torch.Tensor,  # nf x nloc x ng1
-        gg1: torch.Tensor,  # nf x nloc x nnei x ng1
-        nlist_mask: torch.Tensor,  # nf x nloc x nnei
-        sw: torch.Tensor,  # nf x nloc x nnei
+        g1: torch.Tensor,  # nb x nloc x ng1
+        gg1: torch.Tensor,  # nb x nloc x nnei x ng1
+        nlist_mask: torch.Tensor,  # nb x nloc x nnei
+        sw: torch.Tensor,  # nb x nloc x nnei
     ) -> torch.Tensor:
         """
         Update the g2 using element-wise dot g1_i * g1_j.
         Parameters
         ----------
         g1
-            Atomic invariant rep, with shape nf x nloc x ng1.
+            Atomic invariant rep, with shape nb x nloc x ng1.
         gg1
-            Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1.
+            Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1.
         nlist_mask
-            Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei.
+            Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.
         sw
             The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
-            and remains 0 beyond rcut, with shape nf x nloc x nnei.
+            and remains 0 beyond rcut, with shape nb x nloc x nnei.
         """
         ret = g1.unsqueeze(-2) * gg1
-        # nf x nloc x nnei x ng1
+        # nb x nloc x nnei x ng1
         ret = _apply_nlist_mask(ret, nlist_mask)
         if self.smooth:
             ret = _apply_switch(ret, sw)
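
The broadcasting in that element-wise product is compact enough to warrant a tiny demonstration; shapes are illustrative:

import torch

nb, nloc, nnei, ng1 = 1, 2, 3, 4
g1 = torch.randn(nb, nloc, ng1)         # rep of atom i
gg1 = torch.randn(nb, nloc, nnei, ng1)  # reps of its neighbors j

# g1 gains a neighbor axis of size 1, then broadcasts against gg1
ret = g1.unsqueeze(-2) * gg1            # nb x nloc x nnei x ng1
assert ret.shape == (nb, nloc, nnei, ng1)
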
Expand Down Expand Up @@ -972,11 +972,11 @@ def forward(
             or self.update_g2_has_g1g1
         )

-        nf, nloc, nnei, _ = g2.shape
+        nb, nloc, nnei, _ = g2.shape
         nall = g1_ext.shape[1]
         g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1)
-        assert (nf, nloc) == g1.shape[:2]
-        assert (nf, nloc, nnei) == h2.shape[:3]
+        assert (nb, nloc) == g1.shape[:2]
+        assert (nb, nloc, nnei) == h2.shape[:3]

         g2_update: List[torch.Tensor] = [g2]
         h2_update: List[torch.Tensor] = [h2]
@@ -991,7 +991,7 @@ def forward(
         if self.update_chnnl_2:
             # mlp(g2)
             assert self.linear2 is not None
-            # nf x nloc x nnei x ng2
+            # nb x nloc x nnei x ng2
             g2_1 = self.act(self.linear2(g2))
             g2_update.append(g2_1)

@@ -1006,13 +1006,13 @@ def forward(
         if self.update_g2_has_attn or self.update_h2:
            # gated_attention(g2, h2)
             assert self.attn2g_map is not None
-            # nf x nloc x nnei x nnei x nh
+            # nb x nloc x nnei x nnei x nh
             AAg = self.attn2g_map(g2, h2, nlist_mask, sw)

             if self.update_g2_has_attn:
                 assert self.attn2_mh_apply is not None
                 assert self.attn2_lm is not None
-                # nf x nloc x nnei x ng2
+                # nb x nloc x nnei x ng2
                 g2_2 = self.attn2_mh_apply(AAg, g2)
                 g2_2 = self.attn2_lm(g2_2)
                 g2_update.append(g2_2)
@@ -1052,7 +1052,7 @@ def forward(
                 )
             )

-        # nf x nloc x [ng1+ng2+(axisxng2)+(axisxng1)]
+        # nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)]
         # conv grrg drrd
         g1_1 = self.act(self.linear1(torch.cat(g1_mlp, dim=-1)))
         g1_update.append(g1_1)
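
A hedged sketch of this fused update: the candidate channels are concatenated along the feature axis and mixed by a single linear layer. The channel sizes and the tanh activation below are stand-ins for the configured ones, not values from the diff:

import torch

nb, nloc, ng1, ng2, axis = 1, 2, 8, 6, 3
g1_mlp = [
    torch.randn(nb, nloc, ng1),         # current g1
    torch.randn(nb, nloc, ng2),         # conv channel
    torch.randn(nb, nloc, axis * ng2),  # grrg channel
    torch.randn(nb, nloc, axis * ng1),  # drrd channel
]
linear1 = torch.nn.Linear(ng1 + ng2 + axis * ng2 + axis * ng1, ng1)
g1_1 = torch.tanh(linear1(torch.cat(g1_mlp, dim=-1)))  # nb x nloc x ng1
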
12 changes: 6 additions & 6 deletions deepmd/pt/model/descriptor/repformers.py
@@ -428,14 +428,14 @@ def forward(
             atype_embd = extended_atype_embd
         assert isinstance(atype_embd, torch.Tensor)  # for jit
         g1 = self.act(atype_embd)
-        # nf x nloc x nnei x 1, nf x nloc x nnei x 3
+        # nb x nloc x nnei x 1, nb x nloc x nnei x 3
         if not self.direct_dist:
             g2, h2 = torch.split(dmatrix, [1, 3], dim=-1)
         else:
             g2, h2 = torch.linalg.norm(diff, dim=-1, keepdim=True), diff
             g2 = g2 / self.rcut
             h2 = h2 / self.rcut
-        # nf x nloc x nnei x ng2
+        # nb x nloc x nnei x ng2
         g2 = self.act(self.g2_embd(g2))

         # set all padding positions to index of 0
@@ -448,8 +448,8 @@ def forward(
                 mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
             )
         for idx, ll in enumerate(self.layers):
-            # g1: nf x nloc x ng1
-            # g1_ext: nf x nall x ng1
+            # g1: nb x nloc x ng1
+            # g1_ext: nb x nall x ng1
             if comm_dict is None:
                 assert mapping is not None
                 g1_ext = torch.gather(g1, 1, mapping)
@@ -485,9 +485,9 @@ def forward(
             sw,
         )

-        # nf x nloc x 3 x ng2
+        # nb x nloc x 3 x ng2
         h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon)
-        # (nf x nloc) x ng2 x 3
+        # (nb x nloc) x ng2 x 3
         rot_mat = torch.permute(h2g2, (0, 1, 3, 2))

         return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw
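
_cal_hg itself is outside this diff; what follows is only a plausible sketch, assuming it contracts h2 and g2 over the neighbor axis with a 1/nnei normalization, which would match the commented shapes above:

import torch

nb, nloc, nnei, ng2 = 1, 2, 4, 5
h2 = torch.randn(nb, nloc, nnei, 3)    # pair directions
g2 = torch.randn(nb, nloc, nnei, ng2)  # pair features

# nb x nloc x 3 x ng2: sum over neighbors of h2_j (outer) g2_j, averaged
h2g2 = torch.matmul(h2.transpose(-1, -2), g2) / nnei
# (nb x nloc) x ng2 x 3, the rot_mat layout used above
rot_mat = h2g2.transpose(-1, -2).reshape(nb * nloc, ng2, 3)
assert rot_mat.shape == (nb * nloc, ng2, 3)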
