-
Notifications
You must be signed in to change notification settings - Fork 374
/
lovasz_softmax.py
243 lines (207 loc) · 7.79 KB
/
lovasz_softmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.amp as amp
# grads = {}
##
# version 1: use torch.autograd
class LovaszSoftmaxV1(nn.Module):
    '''
    Lovasz-Softmax loss for the multi-category case, written with plain
    PyTorch ops so that gradients come from torch.autograd.

    Args:
        reduction: 'mean' or 'sum' reduce the per-class losses to a scalar;
            any other value (e.g. 'none') returns them unreduced, one entry
            per class.
        ignore_index: label value whose positions are left out of the loss.
    '''

    def __init__(self, reduction='mean', ignore_index=-100):
        super(LovaszSoftmaxV1, self).__init__()
        self.reduction = reduction
        self.lb_ignore = ignore_index

    def forward(self, logits, label):
        '''
        Called like nn.CrossEntropyLoss, but returns a pair:
            >>> criteria = LovaszSoftmaxV1()
            >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half
            >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t
            >>> loss, errs = criteria(logits, lbs)

        Returns:
            (losses, errs) where `losses` is the per-class loss after the
            configured reduction and `errs` is the (c, n_valid) tensor of
            absolute differences between one-hot labels and softmax
            probabilities at the non-ignored positions.
        '''
        n_classes = logits.size(1)
        # (n, c, ...) -> (c, n * ...); use fp32 to avoid nan
        logits = logits.transpose(0, 1).reshape(n_classes, -1).float()
        # reshape (not view) so non-contiguous label tensors are accepted
        label = label.reshape(-1)

        # A boolean mask keeps the sample axis even when exactly one position
        # is valid; nonzero().squeeze() would collapse it to a 0-dim index.
        keep = label.ne(self.lb_ignore)
        probs = logits.softmax(dim=0)[:, keep]
        label = label[keep]

        lb_one_hot = torch.zeros_like(probs).scatter_(
            0, label.unsqueeze(0), 1).detach()
        errs = (lb_one_hot - probs).abs()
        errs_sort, errs_order = torch.sort(errs, dim=1, descending=True)
        n_samples = errs.size(1)

        # lovasz extension grad
        with torch.no_grad():
            # Reorder each class row of the one-hot labels by that row's
            # descending-error permutation.
            lb_one_hot_sort = lb_one_hot.gather(1, errs_order)
            n_pos = lb_one_hot_sort.sum(dim=1, keepdim=True)
            inter = n_pos - lb_one_hot_sort.cumsum(dim=1)
            union = n_pos + (1. - lb_one_hot_sort).cumsum(dim=1)
            jacc = 1. - inter / union
            if n_samples > 1:
                # Turn the cumulative Jaccard values into increments.
                jacc[:, 1:] = jacc[:, 1:] - jacc[:, :-1]

        losses = torch.einsum('ab,ab->a', errs_sort, jacc)

        if self.reduction == 'sum':
            losses = losses.sum()
        elif self.reduction == 'mean':
            losses = losses.mean()
        return losses, errs
##
# version 3: use cuda
import lovasz_softmax_cpp
class LovaszSoftmaxFunctionV3(torch.autograd.Function):
    '''
    Custom autograd function that delegates the forward and backward passes
    to the compiled `lovasz_softmax_cpp` extension.
    '''

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.float32, device_type='cuda')
    def forward(ctx, logits, labels, ignore_index):
        '''
        Run the extension's forward kernel and return its losses.

        Floating-point CUDA inputs are cast to float32 by `custom_fwd` when
        autocast is active.
        '''
        losses, jacc = lovasz_softmax_cpp.lovasz_softmax_forward(
            logits, labels, ignore_index)
        # Tensors go through save_for_backward so autograd manages their
        # lifetime and detects in-place modification before backward.
        # NOTE(review): assumes `jacc` returned by the extension is a tensor.
        ctx.save_for_backward(logits, labels, jacc)
        # Non-tensor state is stored directly on ctx.
        ctx.ignore_index = ignore_index
        return losses

    @staticmethod
    @amp.custom_bwd(device_type='cuda')
    def backward(ctx, grad_output):
        '''
        Run the extension's backward kernel.

        Returns the gradient for `logits`; `labels` and `ignore_index`
        receive no gradient.
        '''
        logits, labels, jacc = ctx.saved_tensors
        grad = lovasz_softmax_cpp.lovasz_softmax_backward(
            grad_output, logits, labels, jacc, ctx.ignore_index)
        return grad, None, None
class LovaszSoftmaxV3(nn.Module):
    '''
    nn.Module front end for LovaszSoftmaxFunctionV3.

    Args:
        reduction: 'mean' or 'sum' reduce the function's output; any other
            value returns it unreduced.
        ignore_index: label value forwarded to the function as the ignored
            label.
    '''

    def __init__(self, reduction='mean', ignore_index=-100):
        super().__init__()
        self.lb_ignore = ignore_index
        self.reduction = reduction

    def forward(self, logits, label):
        '''
        Same usage method as nn.CrossEntropyLoss:
            >>> criteria = LovaszSoftmaxV3()
            >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half
            >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t
            >>> loss = criteria(logits, lbs)
        '''
        # The ignored label is handled inside the custom function.
        raw = LovaszSoftmaxFunctionV3.apply(logits, label, self.lb_ignore)
        mode = self.reduction
        if mode == 'sum':
            return raw.sum()
        if mode == 'mean':
            return raw.mean()
        return raw
if __name__ == '__main__':
    # Consistency check between the two implementations: two networks with
    # identical initial weights are trained on the same random batches, one
    # with LovaszSoftmaxV1 and one with LovaszSoftmaxV3, and the divergence of
    # their weights and losses is printed every 50 iterations.
    # The script moves everything to CUDA and imports torchvision.
    import random

    import numpy as np
    import torchvision

    # torch.manual_seed seeds the generators on all devices.
    torch.manual_seed(15)
    random.seed(15)
    np.random.seed(15)
    torch.backends.cudnn.deterministic = True

    class Model(nn.Module):
        '''ResNet-18 trunk followed by a 3x3 conv head, upsampled to input size.'''

        def __init__(self, n_classes):
            super(Model, self).__init__()
            # weights=None replaces the deprecated pretrained=False.
            net = torchvision.models.resnet18(weights=None)
            self.conv1 = net.conv1
            self.bn1 = net.bn1
            self.maxpool = net.maxpool
            self.relu = net.relu
            self.layer1 = net.layer1
            self.layer2 = net.layer2
            self.layer3 = net.layer3
            self.layer4 = net.layer4
            self.fc = nn.Conv2d(512, n_classes, 3, 1, 1)

        def forward(self, x):
            feat = self.conv1(x)
            feat = self.bn1(feat)
            feat = self.relu(feat)
            feat = self.maxpool(feat)
            feat = self.layer1(feat)
            feat = self.layer2(feat)
            feat = self.layer3(feat)
            feat = self.layer4(feat)
            feat = self.fc(feat)
            # Bring the class scores back to the spatial size of the input.
            out = F.interpolate(feat, x.size()[2:], mode='bilinear', align_corners=True)
            return out

    c = 227
    net1 = Model(c)
    net2 = Model(c)
    # Both networks start from the same weights.
    net2.load_state_dict(net1.state_dict())

    # Reduction shared by both criteria. With 'sum' each criterion returns a
    # scalar, so the `.mul(weight).sum()` below only scales it by sum(weight);
    # switch to 'none' to weight the per-class losses individually.
    red = 'sum'
    criteria1 = LovaszSoftmaxV1(reduction=red, ignore_index=255)
    criteria2 = LovaszSoftmaxV3(reduction=red, ignore_index=255)

    net1.cuda()
    net2.cuda()
    net1.train()
    net2.train()
    criteria1.cuda()
    criteria2.cuda()

    optim1 = torch.optim.SGD(net1.parameters(), lr=1e-2)
    optim2 = torch.optim.SGD(net2.parameters(), lr=1e-2)

    # Random per-class weights that sum to 1.
    weight = torch.randn(c).softmax(dim=0).cuda().detach()

    bs, h, w = 2, 400, 400
    for it in range(1000):
        inten = torch.randn(bs, 3, h, w).cuda()
        lbs = torch.randint(0, c, (bs, h, w)).cuda()

        # Step the autograd (V1) network; V1 also returns its error tensor,
        # which is not needed here.
        optim1.zero_grad()
        loss1, _ = criteria1(net1(inten), lbs)
        loss1 = loss1.mul(weight).sum()
        loss1.backward()
        optim1.step()

        # Step the extension (V3) network on the same batch.
        optim2.zero_grad()
        loss2 = criteria2(net2(inten), lbs).mul(weight).sum()
        loss2.backward()
        optim2.step()

        if (it + 1) % 50 == 0:
            with torch.no_grad():
                print('iter: {}, ================='.format(it + 1))
                print('fc weight: ', torch.max(torch.abs(net1.fc.weight - net2.fc.weight)).item())
                print('conv1 weight: ', torch.max(torch.abs(net1.conv1.weight - net2.conv1.weight)).item())
                print('loss: ', loss1.item() - loss2.item())