Bingo dump
mcbal committed Sep 20, 2021
1 parent e5a37ec commit 3fdd86f
Showing 6 changed files with 205 additions and 178 deletions.
81 changes: 25 additions & 56 deletions afem/models.py
@@ -66,9 +66,9 @@ def __init__(
        self.diff_root_finder = RootFind(
            self._grad_t_phi,
            newton,
-           solver_fwd_max_iter=10,
+           solver_fwd_max_iter=30,
            solver_fwd_tol=1e-5,
-           solver_bwd_max_iter=10,
+           solver_bwd_max_iter=30,
            solver_bwd_tol=1e-5,
        )
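
These `solver_fwd_*` / `solver_bwd_*` keyword arguments are routed to the solver by `RootFind` via the `filter_kwargs` helper used in afem/rootfind.py below. The helper itself is not shown in this diff; a minimal sketch of its presumed behavior, assuming it keeps matching entries and strips the prefix:

def filter_kwargs(kwargs, prefix):
    # e.g. prefix='solver_fwd_' turns {'solver_fwd_max_iter': 30, 'solver_bwd_tol': 1e-5}
    # into {'max_iter': 30}
    return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}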

@@ -81,10 +81,8 @@ def J(self, h):
        bsz, num_spins, dim, J = h.size(0), h.size(1), h.size(2), self._J
        J = repeat(J, 'i j -> b i j', b=bsz)
        if self.J_add_external:
-           ext = torch.tanh(torch.einsum('b i f, f g, b j g -> b i j',
-                            h, self._J_ext, h) / np.sqrt(num_spins*dim))
-           print(torch.linalg.norm(J), torch.linalg.norm(ext))
-           J = J + ext
+           J = J + torch.tanh(torch.einsum('b i f, f g, b j g -> b i j',
+                              h, self._J_ext, h) / np.sqrt(num_spins*dim))
        if self.J_symmetric:
            J = 0.5 * (J + J.permute(0, 2, 1))
        if self.J_traceless:
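
The external term builds a data-dependent coupling from the inputs: for `h` of shape (bsz, N, dim) and `self._J_ext` of shape (dim, dim), the einsum yields one (N, N) coupling matrix per batch element, squashed through tanh and scaled by 1/sqrt(N*dim). A standalone shape check of that contraction (illustrative sizes):

import numpy as np
import torch

bsz, N, dim = 2, 4, 8
h, J_ext = torch.randn(bsz, N, dim), torch.randn(dim, dim)
ext = torch.tanh(torch.einsum('b i f, f g, b j g -> b i j', h, J_ext, h) / np.sqrt(N * dim))
assert ext.shape == (bsz, N, N)  # one data-dependent coupling matrix per batch element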
@@ -97,10 +95,6 @@ def _phi_prep(self, t, J):
        assert t.ndim == 2, f'Tensor `t` should have either shape (batch, 1) or (batch, N) but found shape {t.shape}'
        t = t.repeat(1, self.num_spins) if t.size(-1) == 1 else t
        V = torch.diag_embed(t) - J
-
-       # print(t, torch.eig(V[0]).eigenvalues, torch.det(V[0]))
-       # breakpoint()
-
        V_inv = torch.linalg.solve(V, batch_eye_like(V))
        return t, V, V_inv

@@ -112,16 +106,10 @@ def _phi(self, t, h, beta=None, J=None):
"""
beta, J = default(beta, self.beta), default(J, self.J(h))
t, V, V_inv = self._phi_prep(t, J)
# print(
# t[0][0],
# torch.linalg.norm(beta * t.sum(dim=-1)[0]),
# torch.linalg.norm(0.5 * torch.logdet(V)[0]),
# torch.linalg.norm(beta / (4.0 * self.dim) * torch.einsum('b i f, b i j, b j f -> b', h, V_inv, h)[0]),)
kak = (
return (
beta * t.sum(dim=-1) - 0.5 * torch.logdet(V)
+ beta / (4.0 * self.dim) * torch.einsum('b i f, b i j, b j f -> b', h, V_inv, h)
)[:, None]
return kak
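
For reference, with V(t) = diag(t) - J as built in `_phi_prep`, this evaluates

    \phi(t) = \beta \sum_i t_i - \tfrac{1}{2} \log\det V(t) + \frac{\beta}{4d} \sum_{i,j} (V(t)^{-1})_{ij} \, h_i \cdot h_j,

where d = `self.dim` and h_i · h_j contracts the feature dimension (the `f` index of the einsum). The result is returned with shape (batch, 1).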

    def _grad_t_phi(self, t, h, beta=None, J=None):
        """Compute gradient of `phi` with respect to auxiliary variables `t`.
@@ -132,21 +120,30 @@
        """
        beta, J = default(beta, self.beta), default(J, self.J(h))
        _, _, V_inv = self._phi_prep(t, J)
-       kak = (
-           beta * self.num_spins - 0.5 * torch.diagonal(V_inv, dim1=-2, dim2=-1).sum(dim=-1)
-           - beta / (4.0 * self.dim) * torch.einsum('b i f, b j f, b i k, b k j -> b', h, h, V_inv, V_inv)
-       )[:, None]
-
-       # print('J', J, torch.norm(J))
-       # print('Vinv', V_inv)
-       # print('h term', - beta / (4.0 * self.dim) * torch.einsum('b i f, b j f, b i k, b k j -> b', h, h, V_inv, V_inv))
-
-       return kak
+       if t.size(-1) == 1:
+           return (
+               beta * self.num_spins - 0.5 * torch.diagonal(V_inv, dim1=-2, dim2=-1).sum(dim=-1)
+               - beta / (4.0 * self.dim) * torch.einsum('b i f, b j f, b i k, b k j -> b', h, h, V_inv, V_inv)
+           )[:, None]
+       else:
+           # vector t (different auxiliaries for every spin)
+           return (
+               beta * torch.ones_like(t) - 0.5 * torch.diagonal(V_inv, dim1=-2, dim2=-1)
+               - beta / (4.0 * self.dim) * torch.einsum('b k f, b l f, b i k, b l i -> b i', h, h, V_inv, V_inv)
+           )
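
Componentwise, the new branch computes d\phi/dt_i = \beta - \tfrac{1}{2}(V^{-1})_{ii} - \frac{\beta}{4d}(V^{-1} h h^\top V^{-1})_{ii}. Since `_phi` itself already accepts vector `t`, the branch can be sanity-checked against autograd. A hypothetical snippet (class name and constructor arguments assumed from this repo's models.py):

import torch
from afem.models import VectorSpinModel  # assumed import path

model = VectorSpinModel(num_spins=4, dim=8)  # illustrative constructor call
t = torch.rand(2, 4, requires_grad=True)     # vector t: one auxiliary per spin
h = torch.randn(2, 4, 8)
(grad_autograd,) = torch.autograd.grad(model._phi(t, h).sum(), t)
assert torch.allclose(grad_autograd, model._grad_t_phi(t, h), atol=1e-4)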

    def approximate_log_Z(self, t, h, beta=None):
        beta = default(beta, self.beta)
        return 0.5 * self.num_spins * torch.log(math.pi / beta) + self._phi(t, h, beta=beta)
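
Together with `approximate_free_energy` below, this implements the steepest-descent estimate of the partition function: \log Z \approx \tfrac{N}{2}\log(\pi/\beta) + \phi(t^*), valid when `t` is the stationary point t^* of \phi found by the root finder.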

+   def loss(self, t, h, beta=None):
+       beta, J = default(beta, self.beta), self.J(h)
+       t, V, V_inv = self._phi_prep(t, J)
+       kaka = (0.5 * torch.logdet(V) - beta / (4.0 * self.dim) * torch.einsum('b i f, b i j, b j f -> b', h, V_inv, h)
+       )[:, None]
+
+       return kaka

    def approximate_free_energy(self, t, h, beta=None):
        """Compute steepest-descent free energy for large `self.dim`.
@@ ... @@
        with respect to gradient-requiring parameters (careful for implicit dependencies when
        evaluating `t` away from the stationary point where phi'(t*) != 0).
        """
-
-       # with torch.no_grad():
-       # import matplotlib.pyplot as plt
-
-       # def filter_array(a, threshold=1e2):
-       # idx = np.where(np.abs(a) > threshold)
-       # a[idx] = np.nan
-       # return a
-       # # Pick a range and resolution for `t`.
-       # t_range = torch.arange(0.0, 3.0, 0.001)[:, None]
-       # # Calculate function evaluations for every point on grid and plot.
-       # out = np.array(self._phi(t_range, h[:1, :, :].repeat(t_range.numel(), 1, 1)).detach().numpy())
-       # out_grad = np.array(self._grad_t_phi(
-       # t_range, h[:1, :, :].repeat(t_range.numel(), 1, 1)).detach().numpy())
-       # f, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
-       # ax1.set_title(f"(Hopefully) Found root of phi'(t) at t = {t[0][0].detach().numpy()}")
-       # ax1.plot(t_range.numpy().squeeze(), filter_array(out), 'r-')
-       # ax1.axvline(x=t[0].detach().numpy())
-       # ax1.set_ylabel("phi(t)")
-       # ax2.plot(t_range.numpy().squeeze(), filter_array(out_grad), 'r-')
-       # ax2.axvline(x=t[0].detach().numpy())
-       # ax2.axhline(y=0.0)
-       # ax2.set_xlabel('t')
-       # ax2.set_ylabel("phi'(t)")
-       # # # plt.show()
-       # from datetime import datetime
-       # plt.savefig(f'{datetime.now()}.png')
-
        return -1.0 / beta * self.approximate_log_Z(t, h, beta=beta)
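
The docstring's caveat is the envelope theorem: for F(\theta) = -\tfrac{1}{\beta}\log Z(t^*(\theta), \theta), the total derivative dF/d\theta = \partial F/\partial\theta + (\partial F/\partial t^*)(dt^*/d\theta) reduces to the explicit term only when t^* is an exact stationary point with \phi'(t^*) = 0; evaluating at a non-stationary `t` drops the gradient contribution flowing through the implicit dependence t^*(\theta).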

    def internal_energy(self, t, h, beta, detach=False):
@@ -232,7 +201,7 @@ def forward(

        # Find t-value for which `phi` appearing in exponential in partition function is stationary.
        t_star = self.diff_root_finder(
-           t0*torch.ones(h.size(0), 1, device=h.device, dtype=h.dtype), h, beta=beta,
+           t0*torch.ones(h.size(0), self.num_spins, device=h.device, dtype=h.dtype), h, beta=beta,
        )
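
The root find is now initialized with one auxiliary variable per spin, shape (bsz, num_spins), instead of a single shared scalar, matching the vector-`t` branch added to `_grad_t_phi` above.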

        # Compute approximate free energy.
4 changes: 1 addition & 3 deletions afem/rootfind.py
@@ -45,12 +45,10 @@ def _root_find(self, z0, x, *args, **kwargs):
        fun_bwd = self.fun(z_bwd, x, *args, **remove_kwargs(kwargs, 'solver_'))

        def backward_hook(grad):
-           aa = self.solver(
+           return self.solver(
                lambda y: autograd.grad(fun_bwd, z_bwd, y, retain_graph=True, create_graph=True)[0] + grad,
                torch.zeros_like(grad), **filter_kwargs(kwargs, 'solver_bwd_')
            )['result']
-           # print('back', grad)
-           return aa

        new_z_root.register_hook(backward_hook)
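
For context, this hook implements implicit differentiation. If z^* solves f(z^*, x) = 0, the implicit function theorem gives dz^*/dx = -(\partial f/\partial z)^{-1} \partial f/\partial x, so backprop needs v = -(\partial f/\partial z)^{-\top} grad. Rather than materializing the Jacobian, the hook obtains v as the root of the linear residual y \mapsto (\partial f/\partial z)^\top y + grad, reusing the same Newton solver and letting `autograd.grad` supply the vector-Jacobian products.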

13 changes: 2 additions & 11 deletions afem/solvers.py
@@ -6,12 +6,11 @@
def _reset_singular_jacobian(x):
    """Check for singular scalars/matrices in batch; reset singular scalars/matrices to ones."""
    bad_idxs = torch.isclose(x, torch.zeros_like(
-       x)) if x.size(-1) == 1 else torch.isclose(torch.det(x), torch.zeros_like(x))
+       x)) if x.size(-1) == 1 else torch.isclose(torch.linalg.det(x), torch.zeros_like(x[:, 0, 0]))
    if bad_idxs.any():
        print(
            f'🔔 Encountered {bad_idxs.sum()} singular Jacobian(s) in current batch during root-finding. Jumping to somewhere else.'
        )
-       breakpoint()
        x[bad_idxs] = 1.0
    return x
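
Besides moving to the `torch.linalg` namespace, this fixes a shape mismatch: for a batch of (n, n) Jacobians, the determinant has shape (bsz,), so the comparison tensor must be built from `x[:, 0, 0]` rather than from `x` itself for `torch.isclose` to yield one flag per matrix.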

@@ -30,14 +29,9 @@ def jacobian(f, z)
        return analytical_jac_f(z) if analytical_jac_f is not None else batch_jacobian(f, z)

    def g(z):
-       kaka = _reset_singular_jacobian(jacobian(f, z))
-       # print(kaka)
-       bla = torch.linalg.solve(kaka, f(z))
-       # print('inside g', z, bla, f(z))
-       return z - bla
+       return z - torch.linalg.solve(_reset_singular_jacobian(jacobian(f, z)), f(z))

    z_prev, z, n_steps, trace = z_init, g(z_init), 0, []
-
    trace.append(torch.linalg.norm(f(z_init)).detach())
    trace.append(torch.linalg.norm(f(z)).detach())

@@ ... @@
        n_steps += 1
        trace.append(torch.linalg.norm(f(z)).detach())

-   # print(z_init, trace, z)
-   # print(trace)
-
    return {
        'result': z,
        'n_steps': n_steps,
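
A toy usage sketch of this Newton solver (hypothetical: the collapsed signature is assumed to be `newton(f, z_init, max_iter=..., tol=...)`, matching the prefix-stripped `solver_fwd_*` arguments passed in from models.py):

import torch
from afem.solvers import newton  # assumed import path for the solver above

def f(z):
    return z ** 2 - 4.0  # batch of scalar root-finding problems

out = newton(f, torch.full((8, 1), 3.0), max_iter=30, tol=1e-5)
print(out['result'])   # converges to 2.0 for every batch element
print(out['n_steps'])  # number of Newton iterations taken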
