Skip to content

Commit

Permalink
Update mlp.py
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Nov 13, 2024
1 parent b1faf56 commit abfddd8
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions deepmd/pt/model/network/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
)
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
DP_DTYPE_PROMOTION_STRICT,
PRECISION_DICT,
)
from deepmd.pt.utils.utils import (
Expand Down Expand Up @@ -201,7 +200,7 @@ def forward(
The output.
"""
ori_prec = xx.dtype
if not DP_DTYPE_PROMOTION_STRICT:
if not env.DP_DTYPE_PROMOTION_STRICT:
xx = xx.to(self.prec)
yy = (
torch.matmul(xx, self.matrix) + self.bias
Expand All @@ -217,7 +216,7 @@ def forward(
yy += torch.concat([xx, xx], dim=-1)
else:
yy = yy
if not DP_DTYPE_PROMOTION_STRICT:
if not env.DP_DTYPE_PROMOTION_STRICT:
yy = yy.to(ori_prec)
return yy

Expand Down

0 comments on commit abfddd8

Please sign in to comment.