diff --git a/source/tests/pt/model/test_autodiff.py b/source/tests/pt/model/test_autodiff.py index 1bfc53990b..91fc3cabf6 100644 --- a/source/tests/pt/model/test_autodiff.py +++ b/source/tests/pt/model/test_autodiff.py @@ -11,6 +11,9 @@ from deepmd.pt.utils import ( env, ) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) dtype = torch.float64 @@ -81,9 +84,7 @@ def np_infer_coord( spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0), ) # detach - ret = { - key: result[key].squeeze(0).detach().cpu().numpy() for key in test_keys - } + ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} return ret def np_infer_spin( @@ -97,9 +98,7 @@ def np_infer_spin( spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0), ) # detach - ret = { - key: result[key].squeeze(0).detach().cpu().numpy() for key in test_keys - } + ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} return ret def ff_coord(_coord): @@ -152,9 +151,7 @@ def np_infer( atype, ) # detach - ret = { - key: result[key].squeeze(0).detach().cpu().numpy() for key in test_keys - } + ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} # detach return ret diff --git a/source/tests/pt/model/test_ener_spin_model.py b/source/tests/pt/model/test_ener_spin_model.py index fb84eeae6f..2bd5c22aaf 100644 --- a/source/tests/pt/model/test_ener_spin_model.py +++ b/source/tests/pt/model/test_ener_spin_model.py @@ -310,10 +310,10 @@ def test_dp_consistency(self): dp_model = DPSpinModel.deserialize(self.model.serialize()) # test call dp_ret = dp_model.call( - self.coord.detach().cpu().numpy(), - self.atype.detach().cpu().numpy(), - self.spin.detach().cpu().numpy(), - self.cell.detach().cpu().numpy(), + to_numpy_array(self.coord), + to_numpy_array(self.atype), + to_numpy_array(self.spin), + to_numpy_array(self.cell), ) result = self.model.forward_common( self.coord, @@ -352,11 +352,11 @@ def test_dp_consistency(self): self.spin, index=mapping.unsqueeze(-1).tile((1, 1, 3)), dim=1 ) dp_ret_lower = dp_model.call_lower( - extended_coord.detach().cpu().numpy(), - extended_atype.detach().cpu().numpy(), - extended_spin.detach().cpu().numpy(), - nlist.detach().cpu().numpy(), - mapping.detach().cpu().numpy(), + to_numpy_array(extended_coord), + to_numpy_array(extended_atype), + to_numpy_array(extended_spin), + to_numpy_array(nlist), + to_numpy_array(mapping), ) result_lower = self.model.forward_common_lower( extended_coord,