forked from andrewssdd/pulse
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.py
57 lines (51 loc) · 2.13 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from bicubic import BicubicDownSample
class LossBuilder(torch.nn.Module):
    """Weighted sum of image-reconstruction and latent-regularisation losses.

    The loss is configured by a string of ``weight*NAME`` terms joined by
    ``+``, e.g. ``'100*L2+0.05*GEOCROSS'``. Supported names are ``L2``,
    ``L1`` and ``GEOCROSS``. Whitespace around terms is ignored, and a term
    without an explicit weight (``'L2'``) gets weight 1.
    """

    # Side length the generated image is expected to have before self.D maps
    # it to the reference resolution (hard-coded to 1024 in this module).
    _HR_SIZE = 1024

    # Loss name (as written in the loss string) -> name of the method that
    # computes it. Unknown names raise KeyError in forward().
    _LOSS_FNS = {
        'L2': '_loss_l2',
        'L1': '_loss_l1',
        'GEOCROSS': '_loss_geocross',
    }

    def __init__(self, ref_im, loss_str, eps):
        """Build the loss.

        Args:
            ref_im: 4-D reference tensor whose dims 2 and 3 must be equal
                (presumably (N, C, H, W) -- confirm against callers); H must
                divide 1024 exactly.
            loss_str: loss specification such as ``'100*L2+0.05*GEOCROSS'``.
            eps: lower bound applied to every per-image L1/L2 loss.

        Raises:
            ValueError: if ``ref_im`` is not square, its side does not divide
                1024, or ``loss_str`` contains a malformed term.
        """
        super().__init__()
        if ref_im.shape[2] != ref_im.shape[3]:
            raise ValueError(
                f'ref_im must be square, got spatial size '
                f'{ref_im.shape[2]}x{ref_im.shape[3]}')
        im_size = ref_im.shape[2]
        # Explicit checks instead of `assert` so they survive `python -O`;
        # the zero check also avoids a ZeroDivisionError.
        if im_size == 0 or self._HR_SIZE % im_size != 0:
            raise ValueError(
                f'ref_im size {im_size} must evenly divide {self._HR_SIZE}')
        factor = self._HR_SIZE // im_size
        self.D = BicubicDownSample(factor=factor)
        self.ref_im = ref_im
        self.parsed_loss = self._parse_loss_str(loss_str)
        self.eps = eps

    @staticmethod
    def _parse_loss_str(loss_str):
        """Split ``'w1*NAME1+w2*NAME2'`` into ``[['w1', 'NAME1'], ['w2', 'NAME2']]``.

        Raises:
            ValueError: on an empty term, more than one ``*`` in a term, an
                empty name, or a weight that is not a float literal.
        """
        parsed = []
        for term in loss_str.split('+'):
            parts = [part.strip() for part in term.split('*')]
            if len(parts) == 1:
                # A bare name such as 'L2' gets an implicit weight of 1.
                parts = ['1', parts[0]]
            if len(parts) != 2 or not parts[0] or not parts[1]:
                raise ValueError(
                    f"Malformed loss term {term!r}: expected 'weight*NAME'")
            try:
                float(parts[0])
            except ValueError as err:
                raise ValueError(
                    f'Invalid weight {parts[0]!r} in loss term {term!r}') from err
            parsed.append(parts)
        return parsed

    def flatcat(self, l):
        """Flatten a tensor, or a list/tuple of tensors, into one 1-D tensor.

        Intended for computing Euclidean distances between lists of tensors
        (not called within this class).
        """
        tensors = l if isinstance(l, (list, tuple)) else [l]
        return torch.cat([x.flatten() for x in tensors], dim=0)

    def _loss_l2(self, gen_im_lr, ref_im, **kwargs):
        """Sum over the batch of the per-image mean squared error.

        Each per-image value is clamped below at ``self.eps``; once an image
        reaches it, that image contributes no further gradient.
        """
        return ((gen_im_lr - ref_im).pow(2).mean((1, 2, 3)).clamp(min=self.eps).sum())

    def _loss_l1(self, gen_im_lr, ref_im, **kwargs):
        """10x the batch sum of the per-image mean absolute error (clamped at ``self.eps``)."""
        return 10 * ((gen_im_lr - ref_im).abs().mean((1, 2, 3)).clamp(min=self.eps).sum())

    def _loss_geocross(self, latent, **kwargs):
        """Penalise pairwise angular distances between the rows of each latent.

        ``latent`` is expected to be 3-D (B, N, D), e.g. (B, 18, 512). When
        N == 1 there are no pairs and a 0-dim zero tensor is returned.

        Returns:
            Batch sum of the scaled mean squared pairwise angle.
        """
        if latent.shape[1] == 1:
            return latent.new_zeros(())
        n_rows, dim = latent.shape[1], latent.shape[-1]
        # reshape (not view) also accepts non-contiguous latents.
        X = latent.reshape(-1, 1, n_rows, dim)
        Y = latent.reshape(-1, n_rows, 1, dim)
        # For equal-norm x, y, 2*atan2(|x-y|, |x+y|) is the angle between
        # them. The 1e-9 avoids an infinite sqrt gradient when rows coincide.
        A = ((X - Y).pow(2).sum(-1) + 1e-9).sqrt()
        B = ((X + Y).pow(2).sum(-1) + 1e-9).sqrt()
        D = 2 * torch.atan2(A, B)
        # The scale constants 512 and 8 are kept from the original code.
        return ((D.pow(2) * 512).mean((1, 2)) / 8.).sum()

    def forward(self, latent, gen_im):
        """Evaluate every configured loss term.

        Args:
            latent: latent tensor, used by GEOCROSS.
            gen_im: generated image; it is passed through ``self.D`` (the
                BicubicDownSample built in ``__init__``) before being compared
                with ``self.ref_im``.

        Returns:
            ``(loss, losses)``: the weighted sum of all terms, and a dict
            mapping each loss name to its unweighted value (a repeated name
            keeps only its last value).

        Raises:
            KeyError: if the loss string names an unsupported loss.
        """
        var_dict = {
            'latent': latent,
            'gen_im_lr': self.D(gen_im),
            'ref_im': self.ref_im,
        }
        loss = 0
        losses = {}
        for weight, loss_type in self.parsed_loss:
            loss_fn = getattr(self, self._LOSS_FNS[loss_type])
            term = loss_fn(**var_dict)
            losses[loss_type] = term
            loss += float(weight) * term
        return loss, losses