diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py
index f23456f8e9..2e9e228bc9 100644
--- a/deepmd/pt/model/descriptor/dpa2.py
+++ b/deepmd/pt/model/descriptor/dpa2.py
@@ -255,7 +255,7 @@ def __init__(
         Returns
         -------
         descriptor: torch.Tensor
-            the descriptor of shape nf x nloc x g1_dim.
+            the descriptor of shape nb x nloc x g1_dim.
             invariant single-atom representation.
         g2: torch.Tensor
             invariant pair-atom representation.
diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py
index 3f304cc7a0..8af81520dd 100644
--- a/deepmd/pt/model/descriptor/repformer_layer.py
+++ b/deepmd/pt/model/descriptor/repformer_layer.py
@@ -78,26 +78,26 @@ def _make_nei_g1(
     Parameters
     ----------
     g1_ext
-        Extended atomic invariant rep, with shape nf x nall x ng1.
+        Extended atomic invariant rep, with shape nb x nall x ng1.
     nlist
-        Neighbor list, with shape nf x nloc x nnei.
+        Neighbor list, with shape nb x nloc x nnei.
     Returns
     -------
     gg1: torch.Tensor
-        Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1.
+        Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1.
     """
-    # nlist: nf x nloc x nnei
-    nf, nloc, nnei = nlist.shape
-    # g1_ext: nf x nall x ng1
+    # nlist: nb x nloc x nnei
+    nb, nloc, nnei = nlist.shape
+    # g1_ext: nb x nall x ng1
     ng1 = g1_ext.shape[-1]
-    # index: nf x (nloc x nnei) x ng1
-    index = nlist.reshape(nf, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1)
-    # gg1 : nf x (nloc x nnei) x ng1
+    # index: nb x (nloc x nnei) x ng1
+    index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, ng1)
+    # gg1 : nb x (nloc x nnei) x ng1
     gg1 = torch.gather(g1_ext, dim=1, index=index)
-    # gg1 : nf x nloc x nnei x ng1
-    gg1 = gg1.view(nf, nloc, nnei, ng1)
+    # gg1 : nb x nloc x nnei x ng1
+    gg1 = gg1.view(nb, nloc, nnei, ng1)
     return gg1
@@ -291,34 +291,34 @@ def __init__(
     def forward(
         self,
-        g2: torch.Tensor,  # nf x nloc x nnei x ng2
-        h2: torch.Tensor,  # nf x nloc x nnei x 3
-        nlist_mask: torch.Tensor,  # nf x nloc x nnei
-        sw: torch.Tensor,  # nf x nloc x nnei
+        g2: torch.Tensor,  # nb x nloc x nnei x ng2
+        h2: torch.Tensor,  # nb x nloc x nnei x 3
+        nlist_mask: torch.Tensor,  # nb x nloc x nnei
+        sw: torch.Tensor,  # nb x nloc x nnei
     ) -> torch.Tensor:
         (
-            nf,
+            nb,
             nloc,
             nnei,
             _,
         ) = g2.shape
         nd, nh = self.hidden_dim, self.head_num
-        # nf x nloc x nnei x nd x (nh x 2)
-        g2qk = self.mapqk(g2).view(nf, nloc, nnei, nd, nh * 2)
-        # nf x nloc x (nh x 2) x nnei x nd
+        # nb x nloc x nnei x nd x (nh x 2)
+        g2qk = self.mapqk(g2).view(nb, nloc, nnei, nd, nh * 2)
+        # nb x nloc x (nh x 2) x nnei x nd
         g2qk = torch.permute(g2qk, (0, 1, 4, 2, 3))
-        # nf x nloc x nh x nnei x nd
+        # nb x nloc x nh x nnei x nd
         g2q, g2k = torch.split(g2qk, nh, dim=2)
         # g2q = torch.nn.functional.normalize(g2q, dim=-1)
         # g2k = torch.nn.functional.normalize(g2k, dim=-1)
-        # nf x nloc x nh x nnei x nnei
+        # nb x nloc x nh x nnei x nnei
         attnw = torch.matmul(g2q, torch.transpose(g2k, -1, -2)) / nd**0.5
         if self.has_gate:
             gate = torch.matmul(h2, torch.transpose(h2, -1, -2)).unsqueeze(-3)
             attnw = attnw * gate
-        # mask the attenmap, nf x nloc x 1 x 1 x nnei
+        # mask the attenmap, nb x nloc x 1 x 1 x nnei
         attnw_mask = ~nlist_mask.unsqueeze(2).unsqueeze(2)
-        # mask the attenmap, nf x nloc x 1 x nnei x 1
+        # mask the attenmap, nb x nloc x 1 x nnei x 1
         attnw_mask_c = ~nlist_mask.unsqueeze(2).unsqueeze(-1)
         if self.smooth:
             attnw = (attnw + self.attnw_shift) * sw[:, :, None, :, None] * sw[
@@ -334,19 +334,19 @@ def forward(
             attnw_mask,
             0.0,
         )
-        # nf x nloc x nh x nnei x nnei
+        # nb x nloc x nh x nnei x nnei
         attnw = attnw.masked_fill(
             attnw_mask_c,
             0.0,
         )
         if self.smooth:
             attnw = attnw * sw[:, :, None, :, None] * sw[:, :, None, None, :]
-        # nf x nloc x nnei x nnei
+        # nb x nloc x nnei x nnei
         h2h2t = torch.matmul(h2, torch.transpose(h2, -1, -2)) / 3.0**0.5
-        # nf x nloc x nh x nnei x nnei
+        # nb x nloc x nh x nnei x nnei
         ret = attnw * h2h2t[:, :, None, :, :]
         # ret = torch.softmax(g2qk, dim=-1)
-        # nf x nloc x nnei x nnei x nh
+        # nb x nloc x nnei x nnei x nh
         ret = torch.permute(ret, (0, 1, 3, 4, 2))
         return ret
@@ -561,32 +561,32 @@ def __init__(
     def forward(
         self,
-        g1: torch.Tensor,  # nf x nloc x ng1
-        gg1: torch.Tensor,  # nf x nloc x nnei x ng1
-        nlist_mask: torch.Tensor,  # nf x nloc x nnei
-        sw: torch.Tensor,  # nf x nloc x nnei
+        g1: torch.Tensor,  # nb x nloc x ng1
+        gg1: torch.Tensor,  # nb x nloc x nnei x ng1
+        nlist_mask: torch.Tensor,  # nb x nloc x nnei
+        sw: torch.Tensor,  # nb x nloc x nnei
     ) -> torch.Tensor:
-        nf, nloc, nnei = nlist_mask.shape
+        nb, nloc, nnei = nlist_mask.shape
         ni, nd, nh = self.input_dim, self.hidden_dim, self.head_num
         assert ni == g1.shape[-1]
         assert ni == gg1.shape[-1]
-        # nf x nloc x nd x nh
-        g1q = self.mapq(g1).view(nf, nloc, nd, nh)
-        # nf x nloc x nh x nd
+        # nb x nloc x nd x nh
+        g1q = self.mapq(g1).view(nb, nloc, nd, nh)
+        # nb x nloc x nh x nd
         g1q = torch.permute(g1q, (0, 1, 3, 2))
-        # nf x nloc x nnei x (nd+ni) x nh
-        gg1kv = self.mapkv(gg1).view(nf, nloc, nnei, nd + ni, nh)
+        # nb x nloc x nnei x (nd+ni) x nh
+        gg1kv = self.mapkv(gg1).view(nb, nloc, nnei, nd + ni, nh)
         gg1kv = torch.permute(gg1kv, (0, 1, 4, 2, 3))
-        # nf x nloc x nh x nnei x nd, nf x nloc x nh x nnei x ng1
+        # nb x nloc x nh x nnei x nd, nb x nloc x nh x nnei x ng1
         gg1k, gg1v = torch.split(gg1kv, [nd, ni], dim=-1)
-        # nf x nloc x nh x 1 x nnei
+        # nb x nloc x nh x 1 x nnei
         attnw = torch.matmul(g1q.unsqueeze(-2), torch.transpose(gg1k, -1, -2)) / nd**0.5
-        # nf x nloc x nh x nnei
+        # nb x nloc x nh x nnei
         attnw = attnw.squeeze(-2)
-        # mask the attenmap, nf x nloc x 1 x nnei
+        # mask the attenmap, nb x nloc x 1 x nnei
         attnw_mask = ~nlist_mask.unsqueeze(-2)
-        # nf x nloc x nh x nnei
+        # nb x nloc x nh x nnei
         if self.smooth:
             attnw = (attnw + self.attnw_shift) * sw.unsqueeze(-2) - self.attnw_shift
         else:
@@ -602,11 +602,11 @@ def forward(
         if self.smooth:
             attnw = attnw * sw.unsqueeze(-2)
-        # nf x nloc x nh x ng1
+        # nb x nloc x nh x ng1
         ret = (
-            torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nf, nloc, nh * ni)
+            torch.matmul(attnw.unsqueeze(-2), gg1v).squeeze(-2).view(nb, nloc, nh * ni)
         )
-        # nf x nloc x ng1
+        # nb x nloc x ng1
         ret = self.head_map(ret)
         return ret
@@ -878,26 +878,26 @@ def _update_g1_conv(
         Parameters
         ----------
         gg1
-            Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1.
+            Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1.
         g2
-            Pair invariant rep, with shape nf x nloc x nnei x ng2.
+            Pair invariant rep, with shape nb x nloc x nnei x ng2.
         nlist_mask
-            Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei.
+            Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.
         sw
            The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
-            and remains 0 beyond rcut, with shape nf x nloc x nnei.
+            and remains 0 beyond rcut, with shape nb x nloc x nnei.
""" assert self.proj_g1g2 is not None - nf, nloc, nnei, _ = g2.shape + nb, nloc, nnei, _ = g2.shape ng1 = gg1.shape[-1] ng2 = g2.shape[-1] - # gg1 : nf x nloc x nnei x ng2 - gg1 = self.proj_g1g2(gg1).view(nf, nloc, nnei, ng2) - # nf x nloc x nnei x ng2 + # gg1 : nb x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) + # nb x nloc x nnei x ng2 gg1 = _apply_nlist_mask(gg1, nlist_mask) if not self.smooth: # normalized by number of neighbors, not smooth - # nf x nloc x 1 + # nb x nloc x 1 # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy invnnei = 1.0 / ( self.epsilon + torch.sum(nlist_mask.type_as(gg1), dim=-1) @@ -905,18 +905,18 @@ def _update_g1_conv( else: gg1 = _apply_switch(gg1, sw) invnnei = (1.0 / float(nnei)) * torch.ones( - (nf, nloc, 1), dtype=gg1.dtype, device=gg1.device + (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device ) - # nf x nloc x ng2 + # nb x nloc x ng2 g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei return g1_11 def _update_g2_g1g1( self, - g1: torch.Tensor, # nf x nloc x ng1 - gg1: torch.Tensor, # nf x nloc x nnei x ng1 - nlist_mask: torch.Tensor, # nf x nloc x nnei - sw: torch.Tensor, # nf x nloc x nnei + g1: torch.Tensor, # nb x nloc x ng1 + gg1: torch.Tensor, # nb x nloc x nnei x ng1 + nlist_mask: torch.Tensor, # nb x nloc x nnei + sw: torch.Tensor, # nb x nloc x nnei ) -> torch.Tensor: """ Update the g2 using element-wise dot g1_i * g1_j. @@ -924,17 +924,17 @@ def _update_g2_g1g1( Parameters ---------- g1 - Atomic invariant rep, with shape nf x nloc x ng1. + Atomic invariant rep, with shape nb x nloc x ng1. gg1 - Neighbor-wise atomic invariant rep, with shape nf x nloc x nnei x ng1. + Neighbor-wise atomic invariant rep, with shape nb x nloc x nnei x ng1. nlist_mask - Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei. + Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei. sw The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut, - and remains 0 beyond rcut, with shape nf x nloc x nnei. + and remains 0 beyond rcut, with shape nb x nloc x nnei. 
""" ret = g1.unsqueeze(-2) * gg1 - # nf x nloc x nnei x ng1 + # nb x nloc x nnei x ng1 ret = _apply_nlist_mask(ret, nlist_mask) if self.smooth: ret = _apply_switch(ret, sw) @@ -972,11 +972,11 @@ def forward( or self.update_g2_has_g1g1 ) - nf, nloc, nnei, _ = g2.shape + nb, nloc, nnei, _ = g2.shape nall = g1_ext.shape[1] g1, _ = torch.split(g1_ext, [nloc, nall - nloc], dim=1) - assert (nf, nloc) == g1.shape[:2] - assert (nf, nloc, nnei) == h2.shape[:3] + assert (nb, nloc) == g1.shape[:2] + assert (nb, nloc, nnei) == h2.shape[:3] g2_update: List[torch.Tensor] = [g2] h2_update: List[torch.Tensor] = [h2] @@ -991,7 +991,7 @@ def forward( if self.update_chnnl_2: # mlp(g2) assert self.linear2 is not None - # nf x nloc x nnei x ng2 + # nb x nloc x nnei x ng2 g2_1 = self.act(self.linear2(g2)) g2_update.append(g2_1) @@ -1006,13 +1006,13 @@ def forward( if self.update_g2_has_attn or self.update_h2: # gated_attention(g2, h2) assert self.attn2g_map is not None - # nf x nloc x nnei x nnei x nh + # nb x nloc x nnei x nnei x nh AAg = self.attn2g_map(g2, h2, nlist_mask, sw) if self.update_g2_has_attn: assert self.attn2_mh_apply is not None assert self.attn2_lm is not None - # nf x nloc x nnei x ng2 + # nb x nloc x nnei x ng2 g2_2 = self.attn2_mh_apply(AAg, g2) g2_2 = self.attn2_lm(g2_2) g2_update.append(g2_2) @@ -1052,7 +1052,7 @@ def forward( ) ) - # nf x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] + # nb x nloc x [ng1+ng2+(axisxng2)+(axisxng1)] # conv grrg drrd g1_1 = self.act(self.linear1(torch.cat(g1_mlp, dim=-1))) g1_update.append(g1_1) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 77f95f4b91..b8a24945c0 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -428,14 +428,14 @@ def forward( atype_embd = extended_atype_embd assert isinstance(atype_embd, torch.Tensor) # for jit g1 = self.act(atype_embd) - # nf x nloc x nnei x 1, nf x nloc x nnei x 3 + # nb x nloc x nnei x 1, nb x nloc x nnei x 3 if not self.direct_dist: g2, h2 = torch.split(dmatrix, [1, 3], dim=-1) else: g2, h2 = torch.linalg.norm(diff, dim=-1, keepdim=True), diff g2 = g2 / self.rcut h2 = h2 / self.rcut - # nf x nloc x nnei x ng2 + # nb x nloc x nnei x ng2 g2 = self.act(self.g2_embd(g2)) # set all padding positions to index of 0 @@ -448,8 +448,8 @@ def forward( mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim) ) for idx, ll in enumerate(self.layers): - # g1: nf x nloc x ng1 - # g1_ext: nf x nall x ng1 + # g1: nb x nloc x ng1 + # g1_ext: nb x nall x ng1 if comm_dict is None: assert mapping is not None g1_ext = torch.gather(g1, 1, mapping) @@ -485,9 +485,9 @@ def forward( sw, ) - # nf x nloc x 3 x ng2 + # nb x nloc x 3 x ng2 h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon) - # (nf x nloc) x ng2 x 3 + # (nb x nloc) x ng2 x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) return g1, g2, h2, rot_mat.view(-1, nloc, self.dim_emb, 3), sw