diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index a679ccf1fa..81499b5063 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -5,14 +5,12 @@ import torch from deepmd.env import ( + GLOBAL_ENER_FLOAT_PRECISION, + GLOBAL_NP_FLOAT_PRECISION, get_default_nthreads, set_default_nthreads, ) -PRECISION = os.environ.get("PRECISION", "float64") -GLOBAL_NP_FLOAT_PRECISION = getattr(np, PRECISION) -GLOBAL_PT_FLOAT_PRECISION = getattr(torch, PRECISION) -GLOBAL_ENER_FLOAT_PRECISION = getattr(np, PRECISION) SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False) try: # only linux @@ -43,6 +41,7 @@ "int32": torch.int32, "int64": torch.int64, } +GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name] DEFAULT_PRECISION = "float64" # throw warnings if threads not set @@ -52,3 +51,18 @@ torch.set_num_interop_threads(inter_nthreads) if intra_nthreads > 0: torch.set_num_threads(intra_nthreads) + +__all__ = [ + "GLOBAL_ENER_FLOAT_PRECISION", + "GLOBAL_NP_FLOAT_PRECISION", + "GLOBAL_PT_FLOAT_PRECISION", + "DEFAULT_PRECISION", + "PRECISION_DICT", + "SAMPLER_RECORD", + "NUM_WORKERS", + "DEVICE", + "JIT", + "CACHE_PER_SYS", + "ENERGY_BIAS_TRAINABLE", + "LOCAL_RANK", +]