Skip to content

Commit

Permalink
pt: apply global user set precision to pt (#3220)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Feb 5, 2024
1 parent 22197f5 commit 7f5d67c
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions deepmd/pt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@
import torch

from deepmd.env import (
GLOBAL_ENER_FLOAT_PRECISION,
GLOBAL_NP_FLOAT_PRECISION,
get_default_nthreads,
set_default_nthreads,
)

PRECISION = os.environ.get("PRECISION", "float64")
GLOBAL_NP_FLOAT_PRECISION = getattr(np, PRECISION)
GLOBAL_PT_FLOAT_PRECISION = getattr(torch, PRECISION)
GLOBAL_ENER_FLOAT_PRECISION = getattr(np, PRECISION)
SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False)
try:
# only linux
Expand Down Expand Up @@ -43,6 +41,7 @@
"int32": torch.int32,
"int64": torch.int64,
}
GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name]
DEFAULT_PRECISION = "float64"

# throw warnings if threads not set
Expand All @@ -52,3 +51,18 @@
torch.set_num_interop_threads(inter_nthreads)
if intra_nthreads > 0:
torch.set_num_threads(intra_nthreads)

__all__ = [
"GLOBAL_ENER_FLOAT_PRECISION",
"GLOBAL_NP_FLOAT_PRECISION",
"GLOBAL_PT_FLOAT_PRECISION",
"DEFAULT_PRECISION",
"PRECISION_DICT",
"SAMPLER_RECORD",
"NUM_WORKERS",
"DEVICE",
"JIT",
"CACHE_PER_SYS",
"ENERGY_BIAS_TRAINABLE",
"LOCAL_RANK",
]

0 comments on commit 7f5d67c

Please sign in to comment.