Skip to content

Commit

Permalink
add optional cast
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Nov 13, 2024
1 parent f5f7370 commit b1faf56
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
5 changes: 5 additions & 0 deletions deepmd/pt/model/network/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
DP_DTYPE_PROMOTION_STRICT,
PRECISION_DICT,
)
from deepmd.pt.utils.utils import (
Expand Down Expand Up @@ -200,6 +201,8 @@ def forward(
The output.
"""
ori_prec = xx.dtype
if not DP_DTYPE_PROMOTION_STRICT:
xx = xx.to(self.prec)
yy = (
torch.matmul(xx, self.matrix) + self.bias
if self.bias is not None
Expand All @@ -214,6 +217,8 @@ def forward(
yy += torch.concat([xx, xx], dim=-1)
else:
yy = yy
if not DP_DTYPE_PROMOTION_STRICT:
yy = yy.to(ori_prec)
return yy

def serialize(self) -> dict:
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)

SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False)
DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1"

Check notice

Code scanning / CodeQL

Unused global variable Note

The global variable 'DP_DTYPE_PROMOTION_STRICT' is not used.
try:
# only linux
ncpus = len(os.sched_getaffinity(0))
Expand Down

0 comments on commit b1faf56

Please sign in to comment.