From 73b588d12c8660d7c056276896b2bfd275374f09 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 30 Jan 2024 10:04:18 +0800 Subject: [PATCH] reverse map for dtypes. add uts --- deepmd/pt/utils/utils.py | 36 +++++++++++++++++++++-------------- source/tests/pt/test_utils.py | 31 ++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 14 deletions(-) create mode 100644 source/tests/pt/test_utils.py diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 516cbbdba6..e83e12f608 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -54,22 +54,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def to_numpy_array( xx: torch.Tensor, ) -> np.ndarray: - if xx is not None: - prec = [key for key, value in PT_PRECISION_DICT.items() if value == xx.dtype] - if len(prec) == 0: - raise ValueError(f"unknown precision {xx.dtype}") - else: - prec = NP_PRECISION_DICT[prec[0]] - return xx.detach().cpu().numpy().astype(prec) if xx is not None else None + if xx is None: + return None + assert xx is not None + # Create a reverse mapping of PT_PRECISION_DICT + reverse_precision_dict = {v: k for k, v in PT_PRECISION_DICT.items()} + # Use the reverse mapping to find keys with the desired value + prec = reverse_precision_dict.get(xx.dtype, None) + prec = NP_PRECISION_DICT.get(prec, None) + if prec is None: + raise ValueError(f"unknown precision {xx.dtype}") + return xx.detach().cpu().numpy().astype(prec) def to_torch_tensor( xx: np.ndarray, ) -> torch.Tensor: - if xx is not None: - prec = [key for key, value in NP_PRECISION_DICT.items() if value == xx.dtype] - if len(prec) == 0: - raise ValueError(f"unknown precision {xx.dtype}") - else: - prec = PT_PRECISION_DICT[prec[0]] - return torch.tensor(xx, dtype=prec, device=DEVICE) if xx is not None else None + if xx is None: + return None + assert xx is not None + # Create a reverse mapping of NP_PRECISION_DICT + reverse_precision_dict = {v: k for k, v in NP_PRECISION_DICT.items()} + # Use the reverse mapping to find keys with the desired value + prec = reverse_precision_dict.get(type(xx.flat[0]), None) + prec = PT_PRECISION_DICT.get(prec, None) + if prec is None: + raise ValueError(f"unknown precision {xx.dtype}") + return torch.tensor(xx, dtype=prec, device=DEVICE) diff --git a/source/tests/pt/test_utils.py b/source/tests/pt/test_utils.py new file mode 100644 index 0000000000..9c9a9479ad --- /dev/null +++ b/source/tests/pt/test_utils.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + + +class TestCvt(unittest.TestCase): + def test_to_numpy(self): + rng = np.random.default_rng() + foo = rng.normal([3, 4]) + for ptp, npp in zip( + [torch.float16, torch.float32, torch.float64], + [np.float16, np.float32, np.float64], + ): + foo = foo.astype(npp) + bar = to_torch_tensor(foo) + self.assertEqual(bar.dtype, ptp) + onk = to_numpy_array(bar) + self.assertEqual(onk.dtype, npp) + with self.assertRaises(ValueError) as ee: + foo = foo.astype(np.int32) + bar = to_torch_tensor(foo) + with self.assertRaises(ValueError) as ee: + bar = to_torch_tensor(foo) + bar = to_numpy_array(bar.int())