diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index d6d17f5f79..81499b5063 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -11,13 +11,6 @@ set_default_nthreads, ) -numpy_to_torch_dtype_dict = { - np.float16: torch.float16, - np.float32: torch.float32, - np.float64: torch.float64, -} - -GLOBAL_PT_FLOAT_PRECISION = numpy_to_torch_dtype_dict[GLOBAL_NP_FLOAT_PRECISION] SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False) try: # only linux @@ -48,6 +41,7 @@ "int32": torch.int32, "int64": torch.int64, } +GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name] DEFAULT_PRECISION = "float64" # throw warnings if threads not set