fix(pt): fix zero inputs for LayerNorm (#4134)
Fix #4064.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Bug Fixes**
  - Improved robustness of layer normalization by handling empty input
    tensors, ensuring consistent output without errors.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Sep 18, 2024
1 parent 6976fb7 commit ba9f02f
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions deepmd/pt/model/network/layernorm.py
@@ -96,8 +96,11 @@ def forward(
         # variance = xx.var(dim=-1, unbiased=False, keepdim=True)
         # The following operation is the same as above, but will not raise error when using jit model to inference.
         # See https://github.com/pytorch/pytorch/issues/85792
-        variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
-        yy = (xx - mean) / torch.sqrt(variance + self.eps)
+        if xx.numel() > 0:
+            variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
+            yy = (xx - mean) / torch.sqrt(variance + self.eps)
+        else:
+            yy = xx
         if self.matrix is not None and self.bias is not None:
             yy = yy * self.matrix + self.bias
         return yy
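
For readers who want to try the guarded branch outside the repository, here is a minimal standalone sketch of the same logic. The function name `guarded_layer_norm` and its `weight`, `bias`, and `eps` parameters are hypothetical stand-ins for the module's `self.matrix`, `self.bias`, and `self.eps`; this is not the actual `LayerNorm` class from `deepmd/pt/model/network/layernorm.py`, only an illustration of the control flow in this commit.

```python
import torch


def guarded_layer_norm(xx, weight=None, bias=None, eps=1e-5):
    """Normalize over the last dimension; pass empty tensors through unchanged.

    Hypothetical free-function version of the patched forward() body.
    """
    if xx.numel() > 0:
        # Biased variance and mean over the last dimension, as in the commit.
        variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
        yy = (xx - mean) / torch.sqrt(variance + eps)
    else:
        # Nothing to normalize: skip the statistics entirely.
        yy = xx
    if weight is not None and bias is not None:
        yy = yy * weight + bias
    return yy


# An empty batch (zero rows, four features) keeps its shape.
empty = torch.zeros(0, 4)
out_empty = guarded_layer_norm(empty, torch.ones(4), torch.zeros(4))

# A non-empty input is normalized as usual.
dense = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
out_dense = guarded_layer_norm(dense)
```

The affine step still runs on the empty tensor, which is harmless because broadcasting against a zero-sized dimension yields another zero-sized tensor of the same shape.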