Skip to content

Commit

Permalink
Update repformer_layer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed May 9, 2024
1 parent 515c534 commit bd25aa6
Showing 1 changed file with 33 additions and 33 deletions.
66 changes: 33 additions & 33 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,14 +798,14 @@ def _cal_hg(
Parameters
----------
g
Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng.
Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng.
h
Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3.
Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3.
nlist_mask
Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei.
Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.
sw
The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
and remains 0 beyond rcut, with shape nf x nloc x nnei.
and remains 0 beyond rcut, with shape nb x nloc x nnei.
smooth
Whether to use smoothness in processes such as attention weights calculation.
epsilon
Expand All @@ -814,27 +814,27 @@ def _cal_hg(
Returns
-------
hg
The transposed rotation matrix, with shape nf x nloc x 3 x ng.
The transposed rotation matrix, with shape nb x nloc x 3 x ng.
"""
# g: nf x nloc x nnei x ng
# h: nf x nloc x nnei x 3
# msk: nf x nloc x nnei
nf, nloc, nnei, _ = g.shape
# g: nb x nloc x nnei x ng
# h: nb x nloc x nnei x 3
# msk: nb x nloc x nnei
nb, nloc, nnei, _ = g.shape
ng = g.shape[-1]
# nf x nloc x nnei x ng
# nb x nloc x nnei x ng
g = _apply_nlist_mask(g, nlist_mask)
if not smooth:
# nf x nloc
# nb x nloc
# must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy
invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g), dim=-1))
# nf x nloc x 1 x 1
# nb x nloc x 1 x 1
invnnei = invnnei.unsqueeze(-1).unsqueeze(-1)
else:
g = _apply_switch(g, sw)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nf, nloc, 1, 1), dtype=g.dtype, device=g.device
(nb, nloc, 1, 1), dtype=g.dtype, device=g.device
)
# nf x nloc x 3 x ng
# nb x nloc x 3 x ng
hg = torch.matmul(torch.transpose(h, -1, -2), g) * invnnei
return hg

Expand All @@ -846,23 +846,23 @@ def _cal_grrg(hg: torch.Tensor, axis_neuron: int) -> torch.Tensor:
Parameters
----------
hg
The transposed rotation matrix, with shape nf x nloc x 3 x ng.
The transposed rotation matrix, with shape nb x nloc x 3 x ng.
axis_neuron
Size of the submatrix.
Returns
-------
grrg
Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng)
Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng)
"""
# nf x nloc x 3 x ng
nf, nloc, _, ng = hg.shape
# nf x nloc x 3 x axis
# nb x nloc x 3 x ng
nb, nloc, _, ng = hg.shape
# nb x nloc x 3 x axis
hgm = torch.split(hg, axis_neuron, dim=-1)[0]
# nf x nloc x axis_neuron x ng
# nb x nloc x axis_neuron x ng
grrg = torch.matmul(torch.transpose(hgm, -1, -2), hg) / (3.0**1)
# nf x nloc x (axis_neuron x ng)
grrg = grrg.view(nf, nloc, axis_neuron * ng)
# nb x nloc x (axis_neuron x ng)
grrg = grrg.view(nb, nloc, axis_neuron * ng)
return grrg

def symmetrization_op(
Expand All @@ -881,14 +881,14 @@ def symmetrization_op(
Parameters
----------
g
Neighbor-wise/Pair-wise invariant rep tensors, with shape nf x nloc x nnei x ng.
Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng.
h
Neighbor-wise/Pair-wise equivariant rep tensors, with shape nf x nloc x nnei x 3.
Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3.
nlist_mask
Neighbor list mask, where zero means no neighbor, with shape nf x nloc x nnei.
Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.
sw
The switch function, which equals 1 within the rcut_smth range, smoothly decays from 1 to 0 between rcut_smth and rcut,
and remains 0 beyond rcut, with shape nf x nloc x nnei.
and remains 0 beyond rcut, with shape nb x nloc x nnei.
axis_neuron
Size of the submatrix.
smooth
Expand All @@ -899,15 +899,15 @@ def symmetrization_op(
Returns
-------
grrg
Atomic invariant rep, with shape nf x nloc x (axis_neuron x ng)
Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng)
"""
# g: nf x nloc x nnei x ng
# h: nf x nloc x nnei x 3
# msk: nf x nloc x nnei
nf, nloc, nnei, _ = g.shape
# nf x nloc x 3 x ng
# g: nb x nloc x nnei x ng
# h: nb x nloc x nnei x 3
# msk: nb x nloc x nnei
nb, nloc, nnei, _ = g.shape
# nb x nloc x 3 x ng
hg = self._cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon)
# nf x nloc x (axis_neuron x ng)
# nb x nloc x (axis_neuron x ng)
grrg = self._cal_grrg(hg, axis_neuron)
return grrg

Expand Down

0 comments on commit bd25aa6

Please sign in to comment.