From 3fdd86fb829f263040937479a63a8c1ee065d664 Mon Sep 17 00:00:00 2001 From: mcbal Date: Mon, 20 Sep 2021 21:49:04 +0200 Subject: [PATCH] Bingo dump --- afem/models.py | 81 +++++---------- afem/rootfind.py | 4 +- afem/solvers.py | 13 +-- examples/mnist_afe.py | 207 +++++++++++++++++++++++--------------- examples/model_fwd_bwd.py | 50 ++++----- tests/model_gradients.py | 28 +++++- 6 files changed, 205 insertions(+), 178 deletions(-) diff --git a/afem/models.py b/afem/models.py index 161ae95..daec50b 100644 --- a/afem/models.py +++ b/afem/models.py @@ -66,9 +66,9 @@ def __init__( self.diff_root_finder = RootFind( self._grad_t_phi, newton, - solver_fwd_max_iter=10, + solver_fwd_max_iter=30, solver_fwd_tol=1e-5, - solver_bwd_max_iter=10, + solver_bwd_max_iter=30, solver_bwd_tol=1e-5, ) @@ -81,10 +81,8 @@ def J(self, h): bsz, num_spins, dim, J = h.size(0), h.size(1), h.size(2), self._J J = repeat(J, 'i j -> b i j', b=bsz) if self.J_add_external: - ext = torch.tanh(torch.einsum('b i f, f g, b j g -> b i j', - h, self._J_ext, h) / np.sqrt(num_spins*dim)) - print(torch.linalg.norm(J), torch.linalg.norm(ext)) - J = J + ext + J = J + torch.tanh(torch.einsum('b i f, f g, b j g -> b i j', + h, self._J_ext, h) / np.sqrt(num_spins*dim)) if self.J_symmetric: J = 0.5 * (J + J.permute(0, 2, 1)) if self.J_traceless: @@ -97,10 +95,6 @@ def _phi_prep(self, t, J): assert t.ndim == 2, f'Tensor `t` should have either shape (batch, 1) or (batch, N) but found shape {t.shape}' t = t.repeat(1, self.num_spins) if t.size(-1) == 1 else t V = torch.diag_embed(t) - J - - # print(t, torch.eig(V[0]).eigenvalues, torch.det(V[0])) - # breakpoint() - V_inv = torch.linalg.solve(V, batch_eye_like(V)) return t, V, V_inv @@ -112,16 +106,10 @@ def _phi(self, t, h, beta=None, J=None): """ beta, J = default(beta, self.beta), default(J, self.J(h)) t, V, V_inv = self._phi_prep(t, J) - # print( - # t[0][0], - # torch.linalg.norm(beta * t.sum(dim=-1)[0]), - # torch.linalg.norm(0.5 * torch.logdet(V)[0]), - # torch.linalg.norm(beta / (4.0 * self.dim) * torch.einsum('b i f, b i j, b j f -> b', h, V_inv, h)[0]),) - kak = ( + return ( beta * t.sum(dim=-1) - 0.5 * torch.logdet(V) + beta / (4.0 * self.dim) * torch.einsum('b i f, b i j, b j f -> b', h, V_inv, h) )[:, None] - return kak def _grad_t_phi(self, t, h, beta=None, J=None): """Compute gradient of `phi` with respect to auxiliary variables `t`. 
@@ -132,21 +120,30 @@ def _grad_t_phi(self, t, h, beta=None, J=None):
         """
         beta, J = default(beta, self.beta), default(J, self.J(h))
         _, _, V_inv = self._phi_prep(t, J)
-        kak = (
-            beta * self.num_spins - 0.5 * torch.diagonal(V_inv, dim1=-2, dim2=-1).sum(dim=-1)
-            - beta / (4.0 * self.dim) * torch.einsum('b i f, b j f, b i k, b k j -> b', h, h, V_inv, V_inv)
-        )[:, None]
-
-        # print('J', J, torch.norm(J))
-        # print('Vinv', V_inv)
-        # print('h term', - beta / (4.0 * self.dim) * torch.einsum('b i f, b j f, b i k, b k j -> b', h, h, V_inv, V_inv))
-
-        return kak
+        if t.size(-1) == 1:
+            return (
+                beta * self.num_spins - 0.5 * torch.diagonal(V_inv, dim1=-2, dim2=-1).sum(dim=-1)
+                - beta / (4.0 * self.dim) * torch.einsum('b i f, b j f, b i k, b k j -> b', h, h, V_inv, V_inv)
+            )[:, None]
+        else:
+            # vector t (different auxiliaries for every spin)
+            return (
+                beta * torch.ones_like(t) - 0.5 * torch.diagonal(V_inv, dim1=-2, dim2=-1)
+                - beta / (4.0 * self.dim) * torch.einsum('b k f, b l f, b i k, b l i -> b i', h, h, V_inv, V_inv)
+            )
 
     def approximate_log_Z(self, t, h, beta=None):
         beta = default(beta, self.beta)
         return 0.5 * self.num_spins * torch.log(math.pi / beta) + self._phi(t, h, beta=beta)
 
+    def loss(self, t, h, beta=None):
+        """Training objective used in `examples/mnist_afe.py`: the log-det and quadratic terms of `phi` with flipped signs."""
+        beta, J = default(beta, self.beta), self.J(h)
+        _, V, V_inv = self._phi_prep(t, J)
+        return (
+            0.5 * torch.logdet(V) - beta / (4.0 * self.dim) * torch.einsum('b i f, b i j, b j f -> b', h, V_inv, h)
+        )[:, None]
+
     def approximate_free_energy(self, t, h, beta=None):
         """Compute steepest-descent free energy for large `self.dim`.
 
@@ -155,34 +152,6 @@ def approximate_free_energy(self, t, h, beta=None):
         with respect to gradient-requiring parameters (careful for implicit dependencies
         when evaluating `t` away from the stationary point where phi'(t*) != 0).
         """
-
-        # with torch.no_grad():
-        #     import matplotlib.pyplot as plt
-
-        #     def filter_array(a, threshold=1e2):
-        #         idx = np.where(np.abs(a) > threshold)
-        #         a[idx] = np.nan
-        #         return a
-        #     # Pick a range and resolution for `t`.
-        #     t_range = torch.arange(0.0, 3.0, 0.001)[:, None]
-        #     # Calculate function evaluations for every point on grid and plot.
-        #     out = np.array(self._phi(t_range, h[:1, :, :].repeat(t_range.numel(), 1, 1)).detach().numpy())
-        #     out_grad = np.array(self._grad_t_phi(
-        #         t_range, h[:1, :, :].repeat(t_range.numel(), 1, 1)).detach().numpy())
-        #     f, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
-        #     ax1.set_title(f"(Hopefully) Found root of phi'(t) at t = {t[0][0].detach().numpy()}")
-        #     ax1.plot(t_range.numpy().squeeze(), filter_array(out), 'r-')
-        #     ax1.axvline(x=t[0].detach().numpy())
-        #     ax1.set_ylabel("phi(t)")
-        #     ax2.plot(t_range.numpy().squeeze(), filter_array(out_grad), 'r-')
-        #     ax2.axvline(x=t[0].detach().numpy())
-        #     ax2.axhline(y=0.0)
-        #     ax2.set_xlabel('t')
-        #     ax2.set_ylabel("phi'(t)")
-        #     # # plt.show()
-        #     from datetime import datetime
-        #     plt.savefig(f'{datetime.now()}.png')
-
         return -1.0 / beta * self.approximate_log_Z(t, h, beta=beta)
 
     def internal_energy(self, t, h, beta, detach=False):
@@ -232,7 +201,7 @@ def forward(
 
         # Find t-value for which `phi` appearing in exponential in partition function is stationary.
         t_star = self.diff_root_finder(
-            t0*torch.ones(h.size(0), 1, device=h.device, dtype=h.dtype), h, beta=beta,
+            t0*torch.ones(h.size(0), self.num_spins, device=h.device, dtype=h.dtype), h, beta=beta,
         )
 
         # Compute approximate free energy.
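Note on the vector-`t` branch added to `_grad_t_phi` above: as a sketch restated from the definitions of `_phi` and `_phi_prep` in this file (not part of the patch itself), with $V(t) = \operatorname{diag}(t) - J$ and $d$ denoting `self.dim`,

$$\phi(t) = \beta \sum_i t_i \;-\; \tfrac{1}{2}\log\det V(t) \;+\; \frac{\beta}{4d}\sum_{k,l,f} h_{kf}\,\bigl(V^{-1}\bigr)_{kl}\,h_{lf},$$

so the per-spin stationarity condition reads

$$\frac{\partial \phi}{\partial t_i} = \beta \;-\; \tfrac{1}{2}\,\bigl(V^{-1}\bigr)_{ii} \;-\; \frac{\beta}{4d}\sum_{k,l,f} h_{kf}\,h_{lf}\,\bigl(V^{-1}\bigr)_{ik}\,\bigl(V^{-1}\bigr)_{li} = 0,$$

which is the `torch.einsum('b k f, b l f, b i k, b l i -> b i', ...)` term in the new branch; summing over `i` recovers the scalar-`t` branch.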
diff --git a/afem/rootfind.py b/afem/rootfind.py index 474172c..46d6288 100644 --- a/afem/rootfind.py +++ b/afem/rootfind.py @@ -45,12 +45,10 @@ def _root_find(self, z0, x, *args, **kwargs): fun_bwd = self.fun(z_bwd, x, *args, **remove_kwargs(kwargs, 'solver_')) def backward_hook(grad): - aa = self.solver( + return self.solver( lambda y: autograd.grad(fun_bwd, z_bwd, y, retain_graph=True, create_graph=True)[0] + grad, torch.zeros_like(grad), **filter_kwargs(kwargs, 'solver_bwd_') )['result'] - # print('back', grad) - return aa new_z_root.register_hook(backward_hook) diff --git a/afem/solvers.py b/afem/solvers.py index a591a70..9f2df13 100644 --- a/afem/solvers.py +++ b/afem/solvers.py @@ -6,12 +6,11 @@ def _reset_singular_jacobian(x): """Check for singular scalars/matrices in batch; reset singular scalars/matrices to ones.""" bad_idxs = torch.isclose(x, torch.zeros_like( - x)) if x.size(-1) == 1 else torch.isclose(torch.det(x), torch.zeros_like(x)) + x)) if x.size(-1) == 1 else torch.isclose(torch.linalg.det(x), torch.zeros_like(x[:, 0, 0])) if bad_idxs.any(): print( f'🔔 Encountered {bad_idxs.sum()} singular Jacobian(s) in current batch during root-finding. Jumping to somewhere else.' ) - breakpoint() x[bad_idxs] = 1.0 return x @@ -30,14 +29,9 @@ def jacobian(f, z): return analytical_jac_f(z) if analytical_jac_f is not None else batch_jacobian(f, z) def g(z): - kaka = _reset_singular_jacobian(jacobian(f, z)) - # print(kaka) - bla = torch.linalg.solve(kaka, f(z)) - # print('inside g', z, bla, f(z)) - return z - bla + return z - torch.linalg.solve(_reset_singular_jacobian(jacobian(f, z)), f(z)) z_prev, z, n_steps, trace = z_init, g(z_init), 0, [] - trace.append(torch.linalg.norm(f(z_init)).detach()) trace.append(torch.linalg.norm(f(z)).detach()) @@ -46,9 +40,6 @@ def g(z): n_steps += 1 trace.append(torch.linalg.norm(f(z)).detach()) - # print(z_init, trace, z) - # print(trace) - return { 'result': z, 'n_steps': n_steps, diff --git a/examples/mnist_afe.py b/examples/mnist_afe.py index 4719c6c..42b4107 100644 --- a/examples/mnist_afe.py +++ b/examples/mnist_afe.py @@ -2,6 +2,8 @@ # approximate free energy loss and then tested using inference # on classification site starting from random vector (or zeros). 
+import matplotlib.pyplot as plt + import torch import numpy as np from torch import nn, optim @@ -14,38 +16,84 @@ from afem.attention import VectorSpinAttention +def sample(model, device): + x = ((torch.zeros(1, 1, 28, 28)-0.1307)/0.3081).requires_grad_() + # x = torch.rand(1, 1, 28, 28).requires_grad_() + # print(x) + eps = 0.01 + langevin_noise = torch.distributions.Normal( + torch.zeros(x.shape), + torch.ones(x.shape) * eps + ) + for name, param in model.named_parameters(): + param.requires_grad = False + + for i in range(300): + fun = model(x)[0] + print(i, fun) + grad_ea = torch.autograd.grad(fun, x)[0] + # print(grad_ea) + # grad_ea = torch.clamp(grad_ea, -self.opt.grad_clip_sampling, self.opt.grad_clip_sampling) + x = x - 1.0 / 2.0 * grad_ea + langevin_noise.sample() + x = x.clamp(-0.1307/0.3081, (1.0-0.1307)/0.3081) + # if i % 10 == 0: + # image = ((0.1307+x*0.3081).clone().detach()).squeeze(0).squeeze(0).numpy() + # # if i % 20 == 0: + # plt.imshow(image, cmap='gray') + # plt.colorbar() + # plt.show() + + # x = torch.clamp(x, -0.1307/0.3081, (1.0-0.1307)/0.3081) + image = ((0.1307+x*0.3081).clone().detach()).squeeze(0).squeeze(0).numpy() + # if i % 20 == 0: + plt.imshow(image, cmap='gray') + plt.colorbar() + plt.show() + for name, param in model.named_parameters(): + param.requires_grad = True + + def train(model, device, train_loader, optimizer, epoch): model.train() - for name, param in model.named_parameters(): - if 'attention' in name: - param.requires_grad = False - else: - param.requires_grad = True + # for name, param in model.named_parameters(): + # param.requires_grad = True with tqdm(train_loader, unit='it') as tqdm_loader: - for _, (data, target) in enumerate(tqdm_loader): + for idx, (data, target) in enumerate(tqdm_loader): + + # plt.imshow((0.1307+0.3081*data[0].clone().detach().squeeze(0).squeeze(0)).numpy(), cmap='gray') + # plt.colorbar() + # plt.show() + tqdm_loader.set_description(f'Epoch {epoch}') optimizer.zero_grad() data, target = data.to(device), target.to(device) - loss = model(data, target) + loss_afem, logits, loss_xe = model(data, target) + loss = loss_afem + loss_xe if torch.isnan(loss): optimizer.zero_grad() continue # loss = F.cross_entropy(output, target) - # preds = output.argmax(dim=1, keepdim=True) - # correct = preds.eq(target.view_as(preds)).sum().item() - # accuracy = correct / target.shape[0] + preds = logits.argmax(dim=1, keepdim=True) + correct = preds.eq(target.view_as(preds)).sum().item() + accuracy = correct / target.shape[0] loss.backward() - # print(torch.linalg.norm(model.attention.spin_model._J)) + # print(model.attention.spin_model._J.grad) # breakpoint() - # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01) + + # print(model.attention.spin_model._J.grad) optimizer.step() - tqdm_loader.set_postfix(afe_loss=f'{loss.item():.4f}') + tqdm_loader.set_postfix(loss_afem=f'{loss_afem.item():.4f}', + loss_xe=f'{loss_xe.item():.4f}', accuracy=f'{accuracy:.4f}') + + if idx % 1000 == 0: + sample(model, device) def test(model, device, test_loader): @@ -53,40 +101,19 @@ def test(model, device, test_loader): test_loss = 0 correct = 0 - # with torch.no_grad(): - for param in model.parameters(): - param.requires_grad = False - counter = 0 - - for data, target in test_loader: - - bsz = data.size(0) - - y = torch.randn(bsz, 1, 64).requires_grad_() - optimizer = optim.Adam([y], lr=1e-3) - - for i in range(10): - data, target = data.clone().to(device), target.clone().to(device) - 
optimizer.zero_grad() - - loss, preds = model(data, y) - # print(preds.indices.shape) - - loss.backward() - optimizer.step() - - # preds = output.argmax(dim=1, keepdim=True) - bla = preds.indices.eq(target.view_as(preds.indices)).sum().item() - - correct += bla - counter += bsz - print(i, loss, correct / counter) - del optimizer + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + lossa, logits, lossb = model(data, target) + loss = lossa + lossb + test_loss += loss.item() + preds = logits.argmax(dim=1, keepdim=True) + correct += preds.eq(target.view_as(preds)).sum().item() test_loss /= len(test_loader.dataset) print( '\n✨ Test set: Average loss: {:.4f}, Accuracy: {}/{})\n'.format( - 0.0, + test_loss, correct, len(test_loader.dataset), ) @@ -94,7 +121,7 @@ def test(model, device, test_loader): class MNISTNet(nn.Module): - def __init__(self, dim=64, dim_conv=32, num_spins=16+1): + def __init__(self, dim=32, dim_conv=24, num_spins=16+1): super(MNISTNet, self).__init__() self.to_patch_embedding = nn.Sequential( @@ -107,54 +134,68 @@ def __init__(self, dim=64, dim_conv=32, num_spins=16+1): Rearrange('b c h w -> b (h w) c'), # -> (4 x 4) x dim_conv nn.Linear(dim_conv, dim), # -> (4 x 4) x dim ) - # self.cls_token = nn.Parameter(torch.randn(1, 32, dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) # self.cls_emb = torch.randn(10, dim) - self.cls_emb = nn.Embedding(10, dim) + # self.cls_emb = nn.Embedding(10, dim) self.attention = VectorSpinAttention( - num_spins=num_spins, dim=dim, pre_norm=True, beta=1.0, use_scalenorm=True, J_symmetric=True, - J_traceless=True, J_add_external=False) - self.final = nn.Linear(10, dim) + num_spins=num_spins, dim=dim, pre_norm=True, post_norm=True, beta=1.0, use_scalenorm=True, + J_symmetric=True, J_traceless=True, J_add_external=True) + self.final = nn.Linear(dim, 10) self.t0 = 1.0 # print(self.t0) + from afem.modules import ScaleNorm + self.prenorm = ScaleNorm(dim) # breakpoint() self.prev_J = None - def forward(self, x, y): + def forward(self, x, y=None): + # print(x[0]) + # breakpoint() + # print(x[0]) + # plt.imshow((0.1307+0.3081*x[0].clone().detach().squeeze(0).squeeze(0)).numpy(), cmap='gray') + # plt.colorbar() + # plt.show() + x = self.to_patch_embedding(x) # x = torch.cat((x, torch.zeros((x.shape[0], 32, x.shape[-1]))), dim=1) - # cls_tokens = self.cls_token.repeat(x.shape[0], 1, 1) - # x = torch.cat((x, cls_tokens), dim=1) - + cls_tokens = self.cls_token.repeat(x.shape[0], 1, 1) + x = torch.cat((x, cls_tokens), dim=1) + # x = torch.cat((x, torch.zeros((x.shape[0], 1, x.shape[-1]))), dim=1) # print(self.training) - if self.training: - # print(y) - # y = torch.nn.functional.embedding(y, self.cls_emb).unsqueeze(1) - y = torch.nn.functional.one_hot(y, num_classes=10).float() - y = self.final(y).unsqueeze(1) - # print(y) - x = torch.cat((x, y), dim=1) - - afe, t_star = self.attention(x, t0=self.t0, return_magnetizations=False) - - # self.t0 = t_star[0][0].detach().clone() + 0.1 - - # if self.prev_J is not None: - # print(self.prev_J, torch.eig(self.prev_J).eigenvalues, torch.det(self.prev_J)) - # breakpoint() - # self.prev_J = self.attention.spin_model.J(x)[0].detach().clone() - - return afe.mean() / x.size(1) # because this returns - else: - # print(y.shape) - x = torch.cat((x, y), dim=1) - afe, _ = self.attention(x, t0=self.t0, return_magnetizations=False) - # print(afe) - dist = torch.norm(y.repeat(1, 10, 1) - self.cls_emb.weight.repeat(x.size(0), 1, 1), dim=-1) - # print(dist.shape) - 
preds = dist.topk(1, largest=False) - return afe.mean() / x.size(1), preds + # if self.training: + # print(y) + # y = torch.nn.functional.embedding(y, self.cls_emb).unsqueeze(1) + # y = self.cls_emb(y).unsqueeze(1) + # print(y) + # x = torch.cat((x, y), dim=1) + + resps, afe, t_star = self.attention(x, t0=self.t0, return_magnetizations=True) + # print(t_star) + # if (t_star > 1.4).any(): + # self.t0 = 2.0 + + logits = self.final(resps[:, -1, :]) + + loss_afem = self.attention.spin_model.loss(t_star, self.prenorm(x)).mean() # / (33) + # loss_afem = torch.randn(1, 1) + out = (loss_afem, logits) + + if y is not None: + loss_xe = F.cross_entropy(logits, y) + out += (loss_xe,) + + return out + # else: + # # print(y.shape) + # x = torch.cat((x, y), dim=1) + # afe, t_star_ = self.attention(x, t0=self.t0, return_magnetizations=False) + # # print(afe) + # dist = torch.norm(y.repeat(1, 10, 1) - self.cls_emb.weight.repeat(x.size(0), 1, 1), dim=-1) + # # print(dist.shape) + # preds = dist.topk(1, largest=False) + # return self.attention.spin_model.loss(t_star_, x).mean(), preds def main(): @@ -181,11 +222,15 @@ def main(): f'\n✨ Initialized {model.__class__.__name__} ({sum(p.nelement() for p in model.parameters())} params) on {device}.' ) # number of model parameters may be an overestimate if weights are made symmetric/traceless inside model - train_optimizer = optim.Adam(model.parameters(), lr=1e-3) + train_optimizer = optim.Adam(model.parameters(), lr=5e-4) # test(model, device, test_loader) for epoch in range(1, 30 + 1): + + # sample(model, device) + train(model, device, train_loader, train_optimizer, epoch) + test(model, device, test_loader) diff --git a/examples/model_fwd_bwd.py b/examples/model_fwd_bwd.py index 7707d08..daf4433 100644 --- a/examples/model_fwd_bwd.py +++ b/examples/model_fwd_bwd.py @@ -21,9 +21,9 @@ x = torch.randn(1, num_spins, dim).requires_grad_() # Run forward pass and get responses, approximate free energy, and stationary t_star value. -responses, afe, t_star = attention(x, t0=1.0/np.sqrt(num_spins*dim)) +responses, afe, t_star = attention(x, t0=1.0) -print(f'✨ afe / N: {afe.item()/num_spins:.4f}, \n✨ t_star {t_star.item():.4f}, \n✨ responses: {responses}') +print(f'✨ afe / N: {afe.item()/num_spins:.4f}, \n✨ t_star {t_star}, \n✨ responses: {responses}') # Run backward on sum of free energies across batch dimension. afe.sum().backward() @@ -33,26 +33,26 @@ ######################################################################################### -def filter_array(a, threshold=1e2): - idx = np.where(np.abs(a) > threshold) - a[idx] = np.nan - return a - - -# Pick a range and resolution for `t`. -t_range = torch.arange(0.0, 3.0, 0.0001)[:, None] -# Calculate function evaluations for every point on grid and plot. 
-out = np.array(attention.spin_model._phi(t_range, x[:1, :, :].repeat(t_range.numel(), 1, 1)).detach().numpy()) -out_grad = np.array(attention.spin_model._grad_t_phi( - t_range, x[:1, :, :].repeat(t_range.numel(), 1, 1)).detach().numpy()) -f, (ax1, ax2) = plt.subplots(2, 1, sharex=True) -ax1.set_title(f"(Hopefully) Found root of phi'(t) at t = {t_star[0][0].detach().numpy()}") -ax1.plot(t_range.numpy().squeeze(), filter_array(out), 'r-') -ax1.axvline(x=t_star[0].detach().numpy()) -ax1.set_ylabel("phi(t)") -ax2.plot(t_range.numpy().squeeze(), filter_array(out_grad), 'r-') -ax2.axvline(x=t_star[0].detach().numpy()) -ax2.axhline(y=0.0) -ax2.set_xlabel('t') -ax2.set_ylabel("phi'(t)") -plt.show() +# def filter_array(a, threshold=1e2): +# idx = np.where(np.abs(a) > threshold) +# a[idx] = np.nan +# return a + + +# # Pick a range and resolution for `t`. +# t_range = torch.arange(0.0, 3.0, 0.0001)[:, None] +# # Calculate function evaluations for every point on grid and plot. +# out = np.array(attention.spin_model._phi(t_range, x[:1, :, :].repeat(t_range.numel(), 1, 1)).detach().numpy()) +# out_grad = np.array(attention.spin_model._grad_t_phi( +# t_range, x[:1, :, :].repeat(t_range.numel(), 1, 1)).detach().numpy()) +# f, (ax1, ax2) = plt.subplots(2, 1, sharex=True) +# ax1.set_title(f"(Hopefully) Found root of phi'(t) at t = {t_star[0][0].detach().numpy()}") +# ax1.plot(t_range.numpy().squeeze(), filter_array(out), 'r-') +# ax1.axvline(x=t_star[0].detach().numpy()) +# ax1.set_ylabel("phi(t)") +# ax2.plot(t_range.numpy().squeeze(), filter_array(out_grad), 'r-') +# ax2.axvline(x=t_star[0].detach().numpy()) +# ax2.axhline(y=0.0) +# ax2.set_xlabel('t') +# ax2.set_ylabel("phi'(t)") +# plt.show() diff --git a/tests/model_gradients.py b/tests/model_gradients.py index a093ac4..5ee0cba 100644 --- a/tests/model_gradients.py +++ b/tests/model_gradients.py @@ -8,7 +8,31 @@ class TestAnalyticalGradients(unittest.TestCase): - def test_phi_t(self): + # def test_phi_t(self): + # num_spins, dim = 11, 17 + + # for J_add_external in [True, False]: + # for J_symmetric in [True, False]: + # with self.subTest(J_add_external=J_add_external, J_symmetric=J_symmetric): + # model = VectorSpinModel( + # num_spins=num_spins, + # dim=dim, + # beta=1.0, + # J_add_external=J_add_external, + # J_symmetric=J_symmetric, + # ).double() + + # h = torch.randn(1, num_spins, dim).double() + # t0 = torch.rand(1, 1).double().requires_grad_() + + # analytical_grad = model._grad_t_phi(t0, h) + # numerical_grad = torch.autograd.grad(model._phi(t0, h), t0)[0] + + # self.assertTrue( + # torch.allclose(analytical_grad, numerical_grad) + # ) + + def test_phi_t_vector(self): num_spins, dim = 11, 17 for J_add_external in [True, False]: @@ -23,7 +47,7 @@ def test_phi_t(self): ).double() h = torch.randn(1, num_spins, dim).double() - t0 = torch.rand(1, 1).double().requires_grad_() + t0 = torch.rand(1, num_spins).double().requires_grad_() analytical_grad = model._grad_t_phi(t0, h) numerical_grad = torch.autograd.grad(model._phi(t0, h), t0)[0]
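As a complement to the updated test, here is a minimal sketch (not part of the patch) of exercising the new per-spin auxiliary variables end to end. It assumes the `VectorSpinModel` constructor arguments used in `tests/model_gradients.py` and the `diff_root_finder` call seen in `VectorSpinModel.forward` above; other options keep their defaults.

import torch
from afem.models import VectorSpinModel

torch.manual_seed(0)
num_spins, dim = 11, 17

# Constructor arguments as in tests/model_gradients.py (assumption); double precision for stable root-finding.
model = VectorSpinModel(num_spins=num_spins, dim=dim, beta=1.0).double()
h = torch.randn(1, num_spins, dim).double().requires_grad_()

# One auxiliary variable per spin (vector t), matching the updated forward() above.
t0 = torch.ones(1, num_spins, dtype=torch.double)
t_star = model.diff_root_finder(t0, h)

# At the root found by the Newton solver, the analytical gradient of phi should be
# close to zero, up to the solver tolerances configured in VectorSpinModel.__init__.
print(torch.linalg.norm(model._grad_t_phi(t_star, h)))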