diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 582abf4d69..a5ac086770 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -32,7 +32,6 @@ ) from deepmd.pt.utils.env import ( DEFAULT_PRECISION, - DP_DTYPE_PROMOTION_STRICT, PRECISION_DICT, ) from deepmd.pt.utils.utils import ( @@ -201,7 +200,7 @@ def forward( The output. """ ori_prec = xx.dtype - if not DP_DTYPE_PROMOTION_STRICT: + if not env.DP_DTYPE_PROMOTION_STRICT: xx = xx.to(self.prec) yy = ( torch.matmul(xx, self.matrix) + self.bias @@ -217,7 +216,7 @@ def forward( yy += torch.concat([xx, xx], dim=-1) else: yy = yy - if not DP_DTYPE_PROMOTION_STRICT: + if not env.DP_DTYPE_PROMOTION_STRICT: yy = yy.to(ori_prec) return yy