From 7afe438b77df6e63c88ab2a50fe83dcf7f82385f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 2 Feb 2024 17:29:43 -0500 Subject: [PATCH 1/2] pt: apply global user-set precision to pt Signed-off-by: Jinzhe Zeng --- deepmd/pt/utils/env.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index a679ccf1fa..d6d17f5f79 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -5,14 +5,19 @@ import torch from deepmd.env import ( + GLOBAL_ENER_FLOAT_PRECISION, + GLOBAL_NP_FLOAT_PRECISION, get_default_nthreads, set_default_nthreads, ) -PRECISION = os.environ.get("PRECISION", "float64") -GLOBAL_NP_FLOAT_PRECISION = getattr(np, PRECISION) -GLOBAL_PT_FLOAT_PRECISION = getattr(torch, PRECISION) -GLOBAL_ENER_FLOAT_PRECISION = getattr(np, PRECISION) +numpy_to_torch_dtype_dict = { + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, +} + +GLOBAL_PT_FLOAT_PRECISION = numpy_to_torch_dtype_dict[GLOBAL_NP_FLOAT_PRECISION] SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False) try: # only linux @@ -52,3 +57,18 @@ torch.set_num_interop_threads(inter_nthreads) if intra_nthreads > 0: torch.set_num_threads(intra_nthreads) + +__all__ = [ + "GLOBAL_ENER_FLOAT_PRECISION", + "GLOBAL_NP_FLOAT_PRECISION", + "GLOBAL_PT_FLOAT_PRECISION", + "DEFAULT_PRECISION", + "PRECISION_DICT", + "SAMPLER_RECORD", + "NUM_WORKERS", + "DEVICE", + "JIT", + "CACHE_PER_SYS", + "ENERGY_BIAS_TRAINABLE", + "LOCAL_RANK", +] From 098f02b8ed0d967ebab4ee64adeb9ca627ae8c5c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 4 Feb 2024 01:15:46 -0500 Subject: [PATCH 2/2] reuse PRECISION_DICT --- deepmd/pt/utils/env.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index d6d17f5f79..81499b5063 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -11,13 +11,6 @@ set_default_nthreads, ) -numpy_to_torch_dtype_dict = { - np.float16: torch.float16, - np.float32: torch.float32, - np.float64: torch.float64, -} - -GLOBAL_PT_FLOAT_PRECISION = numpy_to_torch_dtype_dict[GLOBAL_NP_FLOAT_PRECISION] SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False) try: # only linux @@ -48,6 +41,7 @@ "int32": torch.int32, "int64": torch.int64, } +GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name] DEFAULT_PRECISION = "float64" # throw warnings if threads not set