Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mcbal committed Sep 12, 2021
1 parent 69a4479 commit c41d460
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion tests/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,28 @@ def test_vector_spin_model_forward_afe(self):
)
)

def test_vector_spin_model_forward_afe_asym(self):
num_spins, dim = 11, 17

model = VectorSpinModel(
num_spins=num_spins,
dim=dim,
beta=1.0,
J_symmetric=False,
).double()

x = torch.randn(3, num_spins, dim).double()

self.assertTrue(
gradcheck(
lambda x: model(x)[0],
x.requires_grad_(),
eps=1e-5,
atol=1e-4,
check_undefined_grad=False,
)
)

def test_vector_spin_model_forward_responses(self):
num_spins, dim = 11, 17

Expand All @@ -82,7 +104,29 @@ def test_vector_spin_model_forward_responses(self):

self.assertTrue(
gradcheck(
lambda x: model(x, return_responses=True)[2],
lambda x: model(x, return_magnetizations=True)[2],
x.requires_grad_(),
eps=1e-5,
atol=1e-4,
check_undefined_grad=False,
)
)

def test_vector_spin_model_forward_responses_asym(self):
num_spins, dim = 11, 17

model = VectorSpinModel(
num_spins=num_spins,
dim=dim,
beta=1.0,
J_symmetric=False,
).double()

x = torch.randn(1, num_spins, dim).double()

self.assertTrue(
gradcheck(
lambda x: model(x, return_magnetizations=True)[2],
x.requires_grad_(),
eps=1e-5,
atol=1e-4,
Expand Down

0 comments on commit c41d460

Please sign in to comment.