'''
modified based on the HR-CNN
https://github.com/radimspetlik/hr-cnn
'''
import math
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import pdb
import torch.nn as nn
def normal_sampling(mean, label_k, std):
    """Evaluate a Gaussian pdf with the given ``mean`` and ``std`` at ``label_k``.

    Used to turn a scalar heart-rate label into a soft target distribution
    over integer bins (callers typically pass std = 2).
    """
    coeff = 1.0 / (math.sqrt(2.0 * math.pi) * std)
    exponent = -((label_k - mean) ** 2) / (2.0 * std ** 2)
    return coeff * math.exp(exponent)
def kl_loss(inputs, labels):
    """KL divergence KL(labels || inputs), summed over all elements.

    Parameters
    ----------
    inputs : torch.Tensor
        Predicted probabilities (must be strictly positive; the log is
        taken here before handing off to KLDivLoss).
    labels : torch.Tensor
        Target probability distribution, same shape as ``inputs``.

    Returns
    -------
    torch.Tensor
        Scalar: sum of elementwise labels * (log(labels) - log(inputs)).
    """
    # 'reduce=False' is deprecated in modern PyTorch; reduction='none' keeps
    # the elementwise values so the final reduction (a plain sum) stays ours.
    criterion = nn.KLDivLoss(reduction='none')
    # KLDivLoss expects log-probabilities as its first argument.
    outputs = torch.log(inputs)
    loss = criterion(outputs, labels)
    #loss = loss.sum()/loss.shape[0]
    loss = loss.sum()
    return loss
class TorchLossComputer(object):
    """Frequency-domain (PSD-based) heart-rate losses, adapted from HR-CNN.

    Heart rate is estimated by evaluating the power spectrum of a predicted
    rPPG signal at candidate BPM frequencies and treating the result as
    class logits over BPM bins.

    NOTE(review): every method moves tensors to the GPU with hard-coded
    ``.cuda()`` calls, matching the original implementation — CUDA must be
    available to call any of them.
    """

    @staticmethod
    def compute_complex_absolute_given_k(output, k, N):
        """Squared DFT magnitude of ``output`` at (possibly fractional) bins ``k``.

        output: signal of length N (flattened internally).
        k:      frequency-bin indices at which to evaluate the spectrum.
        N:      signal length.
        Returns a (1, len(k)) tensor of |X(k)|^2 values.
        """
        # torch.autograd.Variable is deprecated and a no-op wrapper on modern
        # PyTorch; plain tensors carry autograd information themselves.
        two_pi_n_over_N = (2 * math.pi * torch.arange(0, N, dtype=torch.float)) / N
        # Hann window reduces spectral leakage before the DFT evaluation.
        hanning = torch.from_numpy(np.hanning(N)).type(torch.FloatTensor).view(1, -1)

        k = k.type(torch.FloatTensor).cuda()
        two_pi_n_over_N = two_pi_n_over_N.cuda()
        hanning = hanning.cuda()

        output = output.view(1, -1) * hanning
        output = output.view(1, 1, -1).type(torch.cuda.FloatTensor)
        k = k.view(1, -1, 1)
        two_pi_n_over_N = two_pi_n_over_N.view(1, 1, -1)
        # |X(k)|^2 = (sum x[n] sin(2*pi*k*n/N))^2 + (sum x[n] cos(2*pi*k*n/N))^2
        complex_absolute = torch.sum(output * torch.sin(k * two_pi_n_over_N), dim=-1) ** 2 \
                           + torch.sum(output * torch.cos(k * two_pi_n_over_N), dim=-1) ** 2
        return complex_absolute

    @staticmethod
    def complex_absolute(output, Fs, bpm_range=None):
        """Normalized power spectrum of ``output`` over ``bpm_range``.

        Fs is the sampling frequency (Hz); bpm_range lists the candidate
        heart rates in beats per minute. The spectrum is normalized to sum
        to 1 (an analogous softmax operator).
        """
        output = output.view(1, -1)
        N = output.size()[1]

        unit_per_hz = Fs / N
        feasible_bpm = bpm_range / 60.0  # BPM -> Hz
        k = feasible_bpm / unit_per_hz   # Hz -> (fractional) DFT bin index

        # only calculate feasible PSD range [0.7,4]Hz
        complex_absolute = TorchLossComputer.compute_complex_absolute_given_k(output, k, N)

        return (1.0 / complex_absolute.sum()) * complex_absolute  # Analogous Softmax operator

    @staticmethod
    def cross_entropy_power_spectrum_loss(inputs, target, Fs):
        """Cross-entropy over BPM bins plus absolute bin error.

        target is the ground-truth BPM bin index (0 corresponds to 40 BPM).
        Returns (cross_entropy_loss, |target_bin - argmax_bin|).
        """
        inputs = inputs.view(1, -1)
        target = target.view(1, -1)
        bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
        #bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)

        #pdb.set_trace()

        #return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)).view(1),  (target.item() - whole_max_idx.item()) ** 2
        return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)

    @staticmethod
    def cross_entropy_power_spectrum_focal_loss(inputs, target, Fs, gamma):
        """Same as cross_entropy_power_spectrum_loss but with focal loss.

        Relies on a ``FocalLoss`` class defined elsewhere in the project.
        """
        inputs = inputs.view(1, -1)
        target = target.view(1, -1)
        bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
        #bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)

        #pdb.set_trace()
        criterion = FocalLoss(gamma=gamma)

        #return F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)).view(1),  (target.item() - whole_max_idx.item()) ** 2
        return criterion(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)

    @staticmethod
    def cross_entropy_power_spectrum_forward_pred(inputs, Fs):
        """Inference-time HR prediction: index of the strongest BPM bin.

        NOTE(review): uses a 40-190 BPM range, unlike the 40-180 range used
        by the training losses — confirm this asymmetry is intentional.
        """
        inputs = inputs.view(1, -1)
        bpm_range = torch.arange(40, 190, dtype=torch.float).cuda()
        #bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()
        #bpm_range = torch.arange(40, 260, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)

        return whole_max_idx

    @staticmethod
    def cross_entropy_power_spectrum_DLDL_softmax2(inputs, target, Fs, std):
        """DLDL-style loss: KL between spectral and Gaussian label distributions.

        target is the BPM bin index (0 == 40 BPM); a Gaussian with the given
        std, centered on it, is the soft label over the 140 bins.
        Returns (kl_loss, cross_entropy_loss, |target_bin - argmax_bin|).
        """
        # Soft label: Gaussian over the 140 BPM bins, floored at 1e-15 so the
        # KL term never sees an exact zero.
        target_distribution = [normal_sampling(int(target), i, std) for i in range(140)]
        target_distribution = [i if i > 1e-15 else 1e-15 for i in target_distribution]
        target_distribution = torch.Tensor(target_distribution).cuda()

        #pdb.set_trace()
        rank = torch.Tensor([i for i in range(140)]).cuda()

        inputs = inputs.view(1, -1)
        target = target.view(1, -1)

        bpm_range = torch.arange(40, 180, dtype=torch.float).cuda()

        complex_absolute = TorchLossComputer.complex_absolute(inputs, Fs, bpm_range)
        # Implicit-dim softmax is deprecated; dim=0 on the flattened 1-D
        # tensor is exactly what the old default computed.
        fre_distribution = F.softmax(complex_absolute.view(-1), dim=0)
        loss_distribution_kl = kl_loss(fre_distribution, target_distribution)

        #HR_pre = torch.sum(fre_distribution*rank)

        whole_max_val, whole_max_idx = complex_absolute.view(-1).max(0)
        whole_max_idx = whole_max_idx.type(torch.float)

        return loss_distribution_kl, F.cross_entropy(complex_absolute, target.view((1)).type(torch.long)), torch.abs(target[0] - whole_max_idx)