Skip to content

Commit

Permalink
fix device inconsistency
Browse files Browse the repository at this point in the history
  • Loading branch information
cherryWangY committed Nov 1, 2024
1 parent 9c7534e commit 5794248
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions source/tests/pt/test_tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import numpy as np
import torch

from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.tabulate import (
unaggregated_dy2_dx,
unaggregated_dy2_dx_s,
Expand Down Expand Up @@ -92,15 +95,15 @@ def test_ops(self):
)

dz_pt = unaggregated_dy_dx(
torch.from_numpy(self.y),
torch.from_numpy(self.y).to(env.DEVICE),
self.w,
dy_pt,
torch.from_numpy(self.xbar),
torch.from_numpy(self.xbar).to(env.DEVICE),
1,
)

dz_tf_numpy = dz_tf.numpy()
dz_pt_numpy = dz_pt.detach().numpy()
dz_pt_numpy = dz_pt.detach().cpu().numpy()

np.testing.assert_almost_equal(dz_tf_numpy, dz_pt_numpy, decimal=10)

Expand All @@ -114,16 +117,16 @@ def test_ops(self):
)

dy2_pt = unaggregated_dy2_dx(
torch.from_numpy(self.y),
torch.from_numpy(self.y).to(env.DEVICE),
self.w,
dy_pt,
dy2_pt,
torch.from_numpy(self.xbar),
torch.from_numpy(self.xbar).to(env.DEVICE),
1,
)

dy2_tf_numpy = dy2_tf.numpy()
dy2_pt_numpy = dy2_pt.detach().numpy()
dy2_pt_numpy = dy2_pt.detach().cpu().numpy()

np.testing.assert_almost_equal(dy2_tf_numpy, dy2_pt_numpy, decimal=10)

Expand Down

0 comments on commit 5794248

Please sign in to comment.