-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathgeometry.py
49 lines (44 loc) · 1.79 KB
/
geometry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
def relaxed_distortion_measure(func, z, eta=0.2, metric='identity', create_graph=True):
if metric == 'identity':
bs = len(z)
z_perm = z[torch.randperm(bs)]
if eta is not None:
alpha = (torch.rand(bs) * (1 + 2*eta) - eta).unsqueeze(1).to(z)
z_augmented = alpha*z + (1-alpha)*z_perm
else:
z_augmented = z
v = torch.randn(z.size()).to(z)
Jv = torch.autograd.functional.jvp(func, z_augmented, v=v, create_graph=create_graph)[1]
TrG = torch.sum(Jv.view(bs, -1)**2, dim=1).mean()
JTJv = (torch.autograd.functional.vjp(func, z_augmented, v=Jv, create_graph=create_graph)[1]).view(bs, -1)
TrG2 = torch.sum(JTJv**2, dim=1).mean()
return TrG2/TrG**2
else:
raise NotImplementedError
def get_flattening_scores(G, mode='condition_number'):
if mode == 'condition_number':
S = torch.svd(G).S
scores = S.max(1).values/S.min(1).values
elif mode == 'variance':
G_mean = torch.mean(G, dim=0, keepdim=True)
A = torch.inverse(G_mean)@G
scores = torch.sum(torch.log(torch.svd(A).S)**2, dim=1)
else:
pass
return scores
def jacobian_decoder_jvp_parallel(func, inputs, v=None, create_graph=True):
batch_size, z_dim = inputs.size()
if v is None:
v = torch.eye(z_dim).unsqueeze(0).repeat(batch_size, 1, 1).view(-1, z_dim).to(inputs)
inputs = inputs.repeat(1, z_dim).view(-1, z_dim)
jac = (
torch.autograd.functional.jvp(
func, inputs, v=v, create_graph=create_graph
)[1].view(batch_size, z_dim, -1).permute(0, 2, 1)
)
return jac
def get_pullbacked_Riemannian_metric(func, z):
J = jacobian_decoder_jvp_parallel(func, z, v=None)
G = torch.einsum('nij,nik->njk', J, J)
return G