Skip to content

Commit

Permalink
use to_numpy_array
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Mar 6, 2024
1 parent 9dccb1c commit 3cb3316
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 18 deletions.
15 changes: 6 additions & 9 deletions source/tests/pt/model/test_autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)

dtype = torch.float64

Expand Down Expand Up @@ -81,9 +84,7 @@ def np_infer_coord(
spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
)
# detach
ret = {
key: result[key].squeeze(0).detach().cpu().numpy() for key in test_keys
}
ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys}
return ret

def np_infer_spin(
Expand All @@ -97,9 +98,7 @@ def np_infer_spin(
spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
)
# detach
ret = {
key: result[key].squeeze(0).detach().cpu().numpy() for key in test_keys
}
ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys}
return ret

def ff_coord(_coord):
Expand Down Expand Up @@ -152,9 +151,7 @@ def np_infer(
atype,
)
# detach
ret = {
key: result[key].squeeze(0).detach().cpu().numpy() for key in test_keys
}
ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys}
# detach
return ret

Expand Down
18 changes: 9 additions & 9 deletions source/tests/pt/model/test_ener_spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,10 @@ def test_dp_consistency(self):
dp_model = DPSpinModel.deserialize(self.model.serialize())
# test call
dp_ret = dp_model.call(
self.coord.detach().cpu().numpy(),
self.atype.detach().cpu().numpy(),
self.spin.detach().cpu().numpy(),
self.cell.detach().cpu().numpy(),
to_numpy_array(self.coord),
to_numpy_array(self.atype),
to_numpy_array(self.spin),
to_numpy_array(self.cell),
)
result = self.model.forward_common(
self.coord,
Expand Down Expand Up @@ -352,11 +352,11 @@ def test_dp_consistency(self):
self.spin, index=mapping.unsqueeze(-1).tile((1, 1, 3)), dim=1
)
dp_ret_lower = dp_model.call_lower(
extended_coord.detach().cpu().numpy(),
extended_atype.detach().cpu().numpy(),
extended_spin.detach().cpu().numpy(),
nlist.detach().cpu().numpy(),
mapping.detach().cpu().numpy(),
to_numpy_array(extended_coord),
to_numpy_array(extended_atype),
to_numpy_array(extended_spin),
to_numpy_array(nlist),
to_numpy_array(mapping),
)
result_lower = self.model.forward_common_lower(
extended_coord,
Expand Down

0 comments on commit 3cb3316

Please sign in to comment.