
Haha there still was a gradient bug
Two tests still fail: too many unstable floating point operations
mcbal committed Sep 25, 2021
1 parent aa22e2f · commit 22ff438
Showing 4 changed files with 112 additions and 147 deletions.
afem/attention.py: 3 changes (2 additions, 1 deletion)
@@ -65,6 +65,7 @@ def forward(
         detach_magnetizations=False,
         return_internal_energy=False,
         detach_internal_energy=False,
+        use_analytical_grads=True,
     ):
         h = self.pre_norm(x)
 
@@ -76,8 +77,8 @@ def forward(
             detach_magnetizations=detach_magnetizations,
             return_internal_energy=return_internal_energy,
             detach_internal_energy=detach_internal_energy,
+            use_analytical_grads=use_analytical_grads,
         )
-        out
 
         return VectorSpinAttentionOutput(
             afe=out[0],
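The new use_analytical_grads flag is simply threaded through to the underlying VectorSpinModel. A minimal sketch of what the switch looks like from the caller's side, mirroring the setup in tests/model_gradients.py below (argument values are taken from those tests; treat this as illustrative, not repo code):

    import torch
    from afem.models import VectorSpinModel

    model = VectorSpinModel(num_spins=11, dim=17, beta=1.0).double()
    x = torch.randn(1, 11, 17).double()
    t0 = torch.ones(11).double()

    # Default: the backward pass solves against the analytical Hessian of phi.
    afe = model(x, t0, use_analytical_grads=True)[0]

    # Fallback: the backward pass goes through the solver-based hook instead.
    afe = model(x, t0, use_analytical_grads=False)[0]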
afem/models.py: 21 changes (12 additions, 9 deletions)
@@ -44,18 +44,15 @@ def __init__(
             0, J_init_std if J_init_std is not None else 1.0 / np.sqrt(num_spins*dim)
         )
         if J_add_external:
-            J_ext_Q = torch.zeros(dim, dim).normal_(0, 1.0 / dim)
-            J_ext_K = torch.zeros(dim, dim).normal_(0, 1.0 / dim)
+            J_ext = torch.zeros(dim, dim).normal_(0, 1.0 / dim)
         if J_parameter:
             self._J = nn.Parameter(J)
             if J_add_external:
-                self._J_ext_Q = nn.Parameter(J_ext_Q)
-                self._J_ext_K = nn.Parameter(J_ext_K)
+                self._J_ext = nn.Parameter(J_ext)
         else:
             self.register_buffer('_J', J)
             if J_add_external:
-                self.register_buffer('_J_ext_Q', J_ext_Q)
-                self.register_buffer('_J_ext_K', J_ext_K)
+                self.register_buffer('_J_ext', J_ext)
         self.J_add_external = J_add_external
         self.J_symmetric = J_symmetric
         self.J_traceless = J_traceless
 
@@ -79,8 +76,8 @@ def J(self, h):
         bsz, num_spins, dim, J = h.size(0), h.size(1), h.size(2), self._J
         J = repeat(J, 'i j -> b i j', b=bsz)
         if self.J_add_external:
-            J = J + torch.tanh(torch.einsum('b i f, g f, g h, b j h -> b i j',
-                                            h, self._J_ext_Q, self._J_ext_K, h) / np.sqrt(num_spins*dim))
+            J = J + torch.tanh(torch.einsum('b i f, f g, b j g -> b i j',
+                                            h, self._J_ext, h) / np.sqrt(num_spins*dim))
         if self.J_symmetric:
             J = 0.5 * (J + J.permute(0, 2, 1))
         if self.J_traceless:
 
@@ -226,12 +223,18 @@ def forward(
         detach_magnetizations=False,
         return_internal_energy=False,
         detach_internal_energy=False,
+        use_analytical_grads=True,
     ):
         beta = default(beta, self.beta)
 
         # Find t-value for which `phi` appearing in exponential in partition function is stationary.
         t0 = repeat(t0, 'i -> b i', b=h.size(0))
-        t_star = self.diff_root_finder(t0, h, beta=beta, solver_fwd_grad_f=(lambda t: self._hess_phi(t, h, beta=beta)))
+
+        if use_analytical_grads:
+            t_star = self.diff_root_finder(t0, h, beta=beta, solver_fwd_grad_f=(
+                lambda z: self._hess_phi(z, h, beta=beta)))
+        else:
+            t_star = self.diff_root_finder(t0, h, beta=beta)
 
         # Compute approximate free energy.
         afe = self.approximate_free_energy(t_star, h, beta=beta)
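Dropping the separate Q/K matrices makes the external term a plain bilinear form: the new einsum adds tanh(h_i · J_ext · h_j / sqrt(num_spins*dim)) on top of the static J. A quick self-contained check that the einsum string computes exactly h @ J_ext @ h^T (shapes follow the tests; illustrative only, not repo code):

    import numpy as np
    import torch

    bsz, num_spins, dim = 2, 11, 17
    h = torch.randn(bsz, num_spins, dim)
    J_ext = torch.randn(dim, dim)

    # 'b i f, f g, b j g -> b i j' contracts f and g: h[b,i,:] @ J_ext @ h[b,j,:]
    via_einsum = torch.einsum('b i f, f g, b j g -> b i j', h, J_ext, h)
    via_matmul = h @ J_ext @ h.transpose(1, 2)
    assert torch.allclose(via_einsum, via_matmul, atol=1e-5)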
afem/rootfind.py: 6 changes (2 additions, 4 deletions)
@@ -44,12 +44,10 @@ def _root_find(self, z0, x, *args, **kwargs):
         z_root_bwd = new_z_root.clone().detach().requires_grad_()
 
         if kwargs.get('solver_fwd_grad_f') is not None:
-            jac_bwd = kwargs['solver_fwd_grad_f'](z_root_bwd)
-
             def backward_hook(grad):
-                return torch.linalg.solve(jac_bwd, grad)
+                return torch.linalg.solve(kwargs['solver_fwd_grad_f'](z_root_bwd), grad)
         else:
-            f_bwd = self.f(z_root_bwd, x, *args, **remove_kwargs(kwargs, 'solver_'))
+            f_bwd = -self.f(z_root_bwd, x, *args, **remove_kwargs(kwargs, 'solver_'))
 
             def backward_hook(grad):
                 return self.solver(
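Both branches implement implicit differentiation at the root: if f(z*(x), x) = 0, the implicit function theorem gives dz*/dx = -(df/dz)^(-1) df/dx, so the backward pass only needs a linear solve against the Jacobian at z*, never a differentiation through the solver iterations. The minus sign in that formula is also what the flipped sign on f_bwd appears to account for. A toy scalar version of the idea (not the repo's solver):

    import torch

    def f(z, x):
        return z ** 3 + z - x  # df/dz = 3*z**2 + 1 > 0, so the root is unique

    x = torch.tensor(2.0)

    # Forward: plain Newton iterations, no autograd graph through the loop.
    z = torch.tensor(0.0)
    for _ in range(25):
        z = z - f(z, x) / (3 * z ** 2 + 1)

    # Backward via the implicit function theorem at the root z* = 1:
    # dz*/dx = -(df/dz)^(-1) * df/dx = -(1/4) * (-1) = 0.25.
    dfdz = 3 * z ** 2 + 1  # analytical Jacobian, the role solver_fwd_grad_f plays
    dfdx = -1.0            # f = z**3 + z - x, so df/dx = -1
    print(-dfdx / dfdz)    # tensor(0.2500)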
tests/model_gradients.py: 229 changes (96 additions, 133 deletions)
@@ -1,3 +1,4 @@
+import itertools
 import unittest
 
 import numpy as np
 
@@ -11,154 +12,116 @@ class TestAnalyticalGradients(unittest.TestCase):
     def test_phi_t(self):
         num_spins, dim = 11, 17
 
-        for J_add_external in [True, False]:
-            for J_symmetric in [True, False]:
-                with self.subTest(J_add_external=J_add_external, J_symmetric=J_symmetric):
-                    model = VectorSpinModel(
-                        num_spins=num_spins,
-                        dim=dim,
-                        beta=1.0,
-                        J_add_external=J_add_external,
-                        J_symmetric=J_symmetric,
-                    ).double()
-
-                    h = torch.randn(1, num_spins, dim).double()
-                    t0 = torch.rand(1, 1).double().requires_grad_()  # scalar t (batch explicit)
-
-                    analytical_grad = model._jac_phi(t0, h)
-                    numerical_grad = torch.autograd.grad(model._phi(t0, h), t0)[0]
-
-                    self.assertTrue(
-                        torch.allclose(analytical_grad, numerical_grad)
-                    )
+        for (t_vector, J_add_external, J_symmetric) in itertools.product([True, False], repeat=3):
+            with self.subTest(t_vector=t_vector, J_add_external=J_add_external, J_symmetric=J_symmetric):
+                model = VectorSpinModel(
+                    num_spins=num_spins,
+                    dim=dim,
+                    beta=1.0,
+                    J_add_external=J_add_external,
+                    J_symmetric=J_symmetric,
+                ).double()
 
-    def test_phi_t_vector(self):
-        num_spins, dim = 11, 17
+                h = torch.randn(1, num_spins, dim).double()
+                t0 = torch.ones(1, num_spins) if t_vector else torch.ones(1, 1)  # (batch explicit)
+                t0 = t0.double().requires_grad_()
 
-        for J_add_external in [True, False]:
-            for J_symmetric in [True, False]:
-                with self.subTest(J_add_external=J_add_external, J_symmetric=J_symmetric):
-                    model = VectorSpinModel(
-                        num_spins=num_spins,
-                        dim=dim,
-                        beta=1.0,
-                        J_add_external=J_add_external,
-                        J_symmetric=J_symmetric,
-                    ).double()
-
-                    h = torch.randn(1, num_spins, dim).double()
-                    t0 = torch.rand(1, num_spins).double().requires_grad_()  # vector t (batch explicit)
-
-                    analytical_grad = model._jac_phi(t0, h)
-                    numerical_grad = torch.autograd.grad(model._phi(t0, h), t0)[0]
-
-                    self.assertTrue(
-                        torch.allclose(analytical_grad, numerical_grad)
-                    )
+                analytical_grad = model._jac_phi(t0, h)
+                numerical_grad = torch.autograd.grad(model._phi(t0, h), t0)[0]
 
-    def test_grad_phi_t_scalar(self):
+                self.assertTrue(
+                    torch.allclose(analytical_grad, numerical_grad)
+                )
+
+    def test_grad_phi_t(self):
         num_spins, dim = 11, 17
 
-        for J_add_external in [True, False]:
-            for J_symmetric in [True, False]:
-                with self.subTest(J_add_external=J_add_external, J_symmetric=J_symmetric):
-                    model = VectorSpinModel(
-                        num_spins=num_spins,
-                        dim=dim,
-                        beta=1.0,
-                        J_add_external=J_add_external,
-                        J_symmetric=J_symmetric,
-                    ).double()
-
-                    h = torch.randn(1, num_spins, dim).double()
-                    t0 = torch.rand(1, 1).double().requires_grad_()  # scalar t (batch explicit)
-
-                    analytical_grad = model._hess_phi(t0, h).sum(dim=-1)
-                    numerical_grad = torch.autograd.grad(model._jac_phi(t0, h).sum(dim=-1), t0)[0]
-
-                    self.assertTrue(
-                        torch.allclose(analytical_grad, numerical_grad)
-                    )
+        for (t_vector, J_add_external, J_symmetric) in itertools.product([True, False], repeat=3):
+            with self.subTest(t_vector=t_vector, J_add_external=J_add_external, J_symmetric=J_symmetric):
+                model = VectorSpinModel(
+                    num_spins=num_spins,
+                    dim=dim,
+                    beta=1.0,
+                    J_add_external=J_add_external,
+                    J_symmetric=J_symmetric,
+                ).double()
 
-    def test_grad_phi_t_vector(self):
-        num_spins, dim = 11, 17
+                h = torch.randn(1, num_spins, dim).double()
+                t0 = torch.ones(1, num_spins) if t_vector else torch.ones(1, 1)  # (batch explicit)
+                t0 = t0.double().requires_grad_()
 
-        for J_add_external in [True, False]:
-            for J_symmetric in [True, False]:
-                with self.subTest(J_add_external=J_add_external, J_symmetric=J_symmetric):
-                    model = VectorSpinModel(
-                        num_spins=num_spins,
-                        dim=dim,
-                        beta=1.0,
-                        J_add_external=J_add_external,
-                        J_symmetric=J_symmetric,
-                    ).double()
-
-                    h = torch.randn(1, num_spins, dim).double()
-                    t0 = torch.rand(1, num_spins).double().requires_grad_()  # vector t (batch explicit)
-
-                    analytical_grad = model._hess_phi(t0, h).sum(dim=-1)
-                    numerical_grad = torch.autograd.grad(model._jac_phi(t0, h).sum(dim=-1), t0)[0]
-
-                    self.assertTrue(
-                        torch.allclose(analytical_grad, numerical_grad)
-                    )
+                analytical_grad = model._hess_phi(t0, h).sum(dim=-1)
+                numerical_grad = torch.autograd.grad(model._jac_phi(t0, h).sum(dim=-1), t0)[0]
 
+                self.assertTrue(
+                    torch.allclose(analytical_grad, numerical_grad)
+                )
 
-class TestRootFindingGradients(unittest.TestCase):
-    def test_vector_spin_model_afe(self):
-        num_spins, dim = 11, 17
-
-        for J_add_external in [True, False]:
-            for J_symmetric in [True, False]:
-                with self.subTest(J_add_external=J_add_external, J_symmetric=J_symmetric):
-                    model = VectorSpinModel(
-                        num_spins=num_spins,
-                        dim=dim,
-                        beta=1.0,
-                        J_add_external=J_add_external,
-                        J_symmetric=J_symmetric,
-                    ).double()
-
-                    x = torch.randn(1, num_spins, dim).double()
-                    t0 = torch.ones(num_spins).double().requires_grad_()
-
-                    self.assertTrue(
-                        gradcheck(
-                            lambda x: model(x, t0)[0],
-                            x.requires_grad_(),
-                            eps=1e-5,
-                            atol=1e-4,
-                            check_undefined_grad=False,
-                        )
-                    )
+class TestRootFindingGradients(unittest.TestCase):
+    # def test_vector_spin_model_afe(self):
+    #     num_spins, dim = 11, 17
+
+    #     for (t_vector, use_analytical_grads, J_add_external, J_symmetric) in itertools.product([True, False], repeat=4):
+    #         with self.subTest(
+    #             t_vector=t_vector,
+    #             use_analytical_grads=use_analytical_grads,
+    #             J_add_external=J_add_external,
+    #             J_symmetric=J_symmetric
+    #         ):
+    #             model = VectorSpinModel(
+    #                 num_spins=num_spins,
+    #                 dim=dim,
+    #                 beta=1.0,
+    #                 J_add_external=J_add_external,
+    #                 J_symmetric=J_symmetric,
+    #             ).double()
+
+    #             x = torch.randn(1, num_spins, dim).double()
+    #             t0 = torch.ones(num_spins) if t_vector else torch.ones(1)
+    #             t0 = t0.double().requires_grad_()
+
+    #             self.assertTrue(
+    #                 gradcheck(
+    #                     lambda z: model(z, t0, use_analytical_grads=use_analytical_grads)[0],
+    #                     x.requires_grad_(),
+    #                     eps=1e-5,
+    #                     atol=1e-4,
+    #                     check_undefined_grad=False,
+    #                 )
+    #             )
 
     def test_vector_spin_model_magnetizations(self):
         num_spins, dim = 11, 17
 
-        for J_add_external in [True, False]:
-            for J_symmetric in [True, False]:
-                with self.subTest(J_add_external=J_add_external, J_symmetric=J_symmetric):
-                    model = VectorSpinModel(
-                        num_spins=num_spins,
-                        dim=dim,
-                        beta=1.0,
-                        J_add_external=J_add_external,
-                        J_symmetric=J_symmetric,
-                    ).double()
-
-                    x = torch.randn(1, num_spins, dim).double()
-                    t0 = torch.ones(num_spins).double().requires_grad_()
-
-                    self.assertTrue(
-                        gradcheck(
-                            lambda x: model(x, t0, return_magnetizations=True)[2],
-                            x.requires_grad_(),
-                            eps=1e-5,
-                            atol=1e-3,
-                            check_undefined_grad=False,
-                        )
+        for (t_vector, use_analytical_grads, J_add_external, J_symmetric) in itertools.product([True, False], repeat=4):
+            with self.subTest(
+                t_vector=t_vector,
+                use_analytical_grads=use_analytical_grads,
+                J_add_external=J_add_external,
+                J_symmetric=J_symmetric
+            ):
+                model = VectorSpinModel(
+                    num_spins=num_spins,
+                    dim=dim,
+                    beta=1.0,
+                    J_add_external=J_add_external,
+                    J_symmetric=J_symmetric,
+                ).double()
 
+                x = torch.randn(1, num_spins, dim).double()
+                t0 = torch.ones(num_spins) if t_vector else torch.ones(1)
+                t0 = t0.double().requires_grad_()
 
+                self.assertTrue(
+                    gradcheck(
+                        lambda z: model(z, t0, return_magnetizations=True, use_analytical_grads=use_analytical_grads)[2],
+                        x.requires_grad_(),
+                        eps=1e-5,
+                        atol=1e-3,
+                        check_undefined_grad=False,
+                    )
                 )
 
 
 if __name__ == '__main__':
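All of these checks lean on torch.autograd.gradcheck in float64, which compares autograd's analytical gradients against finite differences; that comparison is exactly where the "too many unstable floating point operations" from the commit message surfaces as the two remaining failures. The bare pattern, on a toy function rather than the spin model:

    import torch
    from torch.autograd import gradcheck

    x = torch.randn(4, dtype=torch.double, requires_grad=True)

    # Perturbs x by eps and compares finite differences with autograd within atol.
    assert gradcheck(lambda z: (z.tanh() ** 2).sum(), (x,), eps=1e-6, atol=1e-4)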
