fix(pt): fix zero inputs for LayerNorm
Fix deepmodeling#4064.

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Sep 17, 2024
1 parent 96ed5df · commit e15b97b
Showing 1 changed file with 5 additions and 2 deletions.
deepmd/pt/model/network/layernorm.py: 5 additions & 2 deletions
@@ -96,8 +96,11 @@ def forward(
         # variance = xx.var(dim=-1, unbiased=False, keepdim=True)
         # The following operation is the same as above, but will not raise an error when using a JIT model for inference.
         # See https://github.com/pytorch/pytorch/issues/85792
-        variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
-        yy = (xx - mean) / torch.sqrt(variance + self.eps)
+        if xx.numel() > 0:
+            variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
+            yy = (xx - mean) / torch.sqrt(variance + self.eps)
+        else:
+            yy = xx
         if self.matrix is not None and self.bias is not None:
             yy = yy * self.matrix + self.bias
         return yy
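
Why the guard matters: reducing with torch.var_mean over empty slices can produce NaN statistics, which is presumably the failure mode behind deepmodeling#4064; skipping the reduction for zero-element input sidesteps it entirely. Below is a minimal standalone sketch of the guarded logic, assuming a toy input shape; the helper name layer_norm_guarded and the eps default of 1e-5 are illustrative, not taken from the repository.

import torch

# Minimal sketch of the guarded normalization (illustrative; the helper name
# and eps default are assumptions, not from the deepmd source).
def layer_norm_guarded(xx: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    if xx.numel() > 0:
        # Same trick as the patched file: var_mean instead of .var()/.mean(),
        # which avoids a TorchScript inference error (pytorch/pytorch#85792).
        variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
        return (xx - mean) / torch.sqrt(variance + eps)
    # Zero-element input: nothing to normalize, return it unchanged.
    return xx

x = torch.zeros(0, 4, requires_grad=True)  # e.g. a frame with zero atoms
y = layer_norm_guarded(x)
print(y.shape)         # torch.Size([0, 4])
y.sum().backward()     # backward stays finite: no statistics were computed
print(x.grad.shape)    # torch.Size([0, 4])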
