loss.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Sobel(nn.Module):
    # Fixed (non-trainable) Sobel filter: takes a single-channel map and
    # returns its horizontal and vertical gradients as two output channels.
    def __init__(self):
        super(Sobel, self).__init__()
        self.edge_conv = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1, bias=False)
        edge_kx = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])  # d/dx kernel
        edge_ky = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])  # d/dy kernel
        edge_k = np.stack((edge_kx, edge_ky))
        edge_k = torch.from_numpy(edge_k).float().view(2, 1, 3, 3)
        self.edge_conv.weight = nn.Parameter(edge_k)
        # The kernels are constants, so exclude them from optimization.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        out = self.edge_conv(x)
        out = out.contiguous().view(-1, 2, x.size(2), x.size(3))
        return out
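
# Shape sketch (illustrative, not from the original file): padding=1 keeps the
# spatial size, and the two output channels hold d/dx and d/dy.
#
#   sobel = Sobel()
#   x = torch.randn(4, 1, 64, 64)   # hypothetical batch of single-channel maps
#   grads = sobel(x)                # -> shape (4, 2, 64, 64)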

# binary_cross_entropy_with_logits applies the sigmoid internally, so the
# inputs should be raw logits; no sigmoid/softmax mapping to [0, 1] is needed.
def cal_temporal_loss(pre_cls, gt_cls):
    return F.binary_cross_entropy_with_logits(pre_cls, gt_cls)
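
# Usage sketch (assumed shapes, for illustration): pre_cls holds raw scores
# straight from the classifier head, gt_cls holds 0/1 float targets.
#
#   pre_cls = torch.randn(8, 1)                 # logits, no sigmoid applied
#   gt_cls = torch.randint(0, 2, (8, 1)).float()
#   loss = cal_temporal_loss(pre_cls, gt_cls)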

def cal_spatial_loss(output, depth_gt):
    # Depth loss combining a log-L1 depth term, Sobel-gradient terms, and a
    # surface-normal cosine term; `output` is a sequence of predictions
    # (e.g. one per scale), each shaped like depth_gt.
    cos = nn.CosineSimilarity(dim=1, eps=0)  # per-pixel norm >= 1 thanks to the ones channel
    get_gradient = Sobel().to(depth_gt.device)
    ones = torch.ones(depth_gt.size(0), 1, depth_gt.size(2), depth_gt.size(3),
                      device=depth_gt.device)
    # Ground-truth gradients and normals only need computing once.
    depth_grad = get_gradient(depth_gt)
    depth_grad_dx = depth_grad[:, 0, :, :].contiguous().view_as(depth_gt)
    depth_grad_dy = depth_grad[:, 1, :, :].contiguous().view_as(depth_gt)
    depth_normal = torch.cat((-depth_grad_dx, -depth_grad_dy, ones), 1)
    cof = 0.5  # offset inside the log keeps it finite at zero error
    losses = []
    for pred in output:
        output_grad = get_gradient(pred)
        output_grad_dx = output_grad[:, 0, :, :].contiguous().view_as(depth_gt)
        output_grad_dy = output_grad[:, 1, :, :].contiguous().view_as(depth_gt)
        output_normal = torch.cat((-output_grad_dx, -output_grad_dy, ones), 1)
        loss_depth = torch.log(torch.abs(pred - depth_gt) + cof).mean()
        loss_dx = torch.log(torch.abs(output_grad_dx - depth_grad_dx) + cof).mean()
        loss_dy = torch.log(torch.abs(output_grad_dy - depth_grad_dy) + cof).mean()
        loss_normal = torch.abs(1 - cos(output_normal, depth_normal)).mean()
        losses.append(loss_depth + loss_normal + (loss_dx + loss_dy))
    return sum(losses)
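
# Minimal CPU smoke test (a sketch under the assumption that `output` is a
# list of full-resolution predictions; shapes here are made up).
if __name__ == "__main__":
    depth_gt = torch.rand(2, 1, 32, 32)
    preds = [torch.rand(2, 1, 32, 32, requires_grad=True) for _ in range(3)]
    spatial = cal_spatial_loss(preds, depth_gt)
    temporal = cal_temporal_loss(torch.randn(2, 4), torch.randint(0, 2, (2, 4)).float())
    (spatial + temporal).backward()
    print("spatial:", spatial.item(), "temporal:", temporal.item())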