diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 2b8383806b..582abf4d69 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -32,6 +32,7 @@ ) from deepmd.pt.utils.env import ( DEFAULT_PRECISION, + DP_DTYPE_PROMOTION_STRICT, PRECISION_DICT, ) from deepmd.pt.utils.utils import ( @@ -200,6 +201,8 @@ def forward( The output. """ ori_prec = xx.dtype + if not DP_DTYPE_PROMOTION_STRICT: + xx = xx.to(self.prec) yy = ( torch.matmul(xx, self.matrix) + self.bias if self.bias is not None @@ -214,6 +217,8 @@ def forward( yy += torch.concat([xx, xx], dim=-1) else: yy = yy + if not DP_DTYPE_PROMOTION_STRICT: + yy = yy.to(ori_prec) return yy def serialize(self) -> dict: diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 3ee0b7b54d..81dce669ff 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -15,6 +15,7 @@ ) SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False) +DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1" try: # only linux ncpus = len(os.sched_getaffinity(0))