-
Notifications
You must be signed in to change notification settings - Fork 55
/
demo_modify.py
128 lines (100 loc) · 4.13 KB
/
demo_modify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import time
import cv2
import numpy as np
from skimage import segmentation
import torch
import torch.nn as nn
class Args:
    """Static configuration for the segmentation demo (read as class attributes)."""

    # --- input ---
    # Other sample inputs: image/coral.jpg, image/tiger.jpg
    input_image_path = 'image/woof.jpg'

    # --- hardware ---
    gpu_id = 0  # index written to CUDA_VISIBLE_DEVICES

    # --- network / training ---
    train_epoch = 64  # number of optimisation steps (2 ** 6)
    mod_dim1 = 64  # channel width of the 3x3 convolution stages
    mod_dim2 = 32  # channel width of the 1x1 stages; also the number of output classes

    # --- label-count thresholds ---
    min_label_num = 4  # stop training once fewer labels than this remain
    max_label_num = 256  # only refresh the preview image while fewer labels than this remain
class MyNet(nn.Module):
    """Small fully-convolutional network producing per-pixel class scores.

    The stack alternates 3x3 and 1x1 convolutions (stride 1, padding chosen so
    the spatial size is preserved), each followed by batch normalisation.
    Every stage except the last is also followed by an in-place ReLU.

    Args:
        inp_dim: number of input channels.
        mod_dim1: output channels of the 3x3 convolution stages.
        mod_dim2: output channels of the 1x1 convolution stages, and therefore
            of the network.
    """

    def __init__(self, inp_dim, mod_dim1, mod_dim2):
        super().__init__()

        # (in_channels, out_channels, kernel_size) of each stage, in order.
        stages = (
            (inp_dim, mod_dim1, 3),
            (mod_dim1, mod_dim2, 1),
            (mod_dim2, mod_dim1, 3),
            (mod_dim1, mod_dim2, 1),
        )

        layers = []
        for in_ch, out_ch, ksize in stages:
            # ksize // 2 gives padding 1 for 3x3 and 0 for 1x1 kernels.
            layers.append(
                nn.Conv2d(in_ch, out_ch, kernel_size=ksize, stride=1, padding=ksize // 2)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        # The final stage ends on its batch norm: drop the trailing ReLU.
        layers.pop()

        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.seq(x)
def run():
    """Segment the configured image without supervision and save the result.

    Steps:
      1. Read ``Args.input_image_path`` with OpenCV and pre-segment it with
         Felzenszwalb's algorithm.
      2. Train ``MyNet`` for at most ``Args.train_epoch`` steps. In each step
         the per-pixel argmax of the network output is replaced, inside every
         pre-segment, by that segment's most frequent label, and the result is
         used as the cross-entropy target.
      3. While fewer than ``Args.max_label_num`` labels remain, display a
         preview in which every label is painted with a mean image colour.
      4. Stop early once fewer than ``Args.min_label_num`` labels remain, then
         write the last preview to ``seg_<image name>_<seconds>s.jpg`` in the
         current working directory.

    Raises:
        FileNotFoundError: if OpenCV cannot read the input image.
    """
    start_time0 = time.time()

    args = Args()
    # Select the GPU first so the setting precedes every torch.cuda call.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    # Seed the default torch generator as well as the CUDA ones.
    torch.manual_seed(1943)
    torch.cuda.manual_seed_all(1943)
    np.random.seed(1943)

    image = cv2.imread(args.input_image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(
            'cannot read input image: %r' % (args.input_image_path,))

    # --- pre-segmentation ---
    seg_map = segmentation.felzenszwalb(image, scale=32, sigma=0.5, min_size=64)
    # seg_map = segmentation.slic(image, n_segments=10000, compactness=100)
    seg_map = seg_map.flatten()
    # One array of flat pixel indices per pre-segment.
    seg_lab = [np.where(seg_map == u_label)[0]
               for u_label in np.unique(seg_map)]

    # --- training setup ---
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

    # HWC uint8 image -> 1 x C x H x W float32 tensor scaled by 1/255.
    tensor = image.transpose((2, 0, 1))
    tensor = tensor.astype(np.float32) / 255.0
    tensor = tensor[np.newaxis, :, :, :]
    tensor = torch.from_numpy(tensor).to(device)

    model = MyNet(inp_dim=3, mod_dim1=args.mod_dim1, mod_dim2=args.mod_dim2).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=5e-2, momentum=0.9)
    # optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-1, momentum=0.0)

    image_flatten = image.reshape((-1, 3))
    # Placeholder palette; replaced by mean colours on the first preview update.
    color_avg = np.random.randint(255, size=(args.max_label_num, 3))
    show = image

    # --- training loop ---
    start_time1 = time.time()
    model.train()
    for batch_idx in range(args.train_epoch):
        # forward: one row of class scores per pixel
        optimizer.zero_grad()
        output = model(tensor)[0]
        output = output.permute(1, 2, 0).view(-1, args.mod_dim2)
        target = torch.argmax(output, 1)
        im_target = target.detach().cpu().numpy()

        # refine: every pre-segment takes its most frequent predicted label
        for inds in seg_lab:
            u_labels, hist = np.unique(im_target[inds], return_counts=True)
            im_target[inds] = u_labels[np.argmax(hist)]

        # backward: train towards the refined labels
        target = torch.from_numpy(im_target).to(device)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # preview
        un_label, lab_inverse = np.unique(im_target, return_inverse=True)
        if un_label.shape[0] < args.max_label_num:
            img_flatten = image_flatten.copy()
            # Colours are recomputed only when the number of labels changes.
            if len(color_avg) != un_label.shape[0]:
                # Builtin int replaces np.int, which NumPy >= 1.24 no longer has.
                color_avg = [np.mean(img_flatten[im_target == label], axis=0, dtype=int)
                             for label in un_label]
            for lab_id, color in enumerate(color_avg):
                img_flatten[lab_inverse == lab_id] = color
            show = img_flatten.reshape(image.shape)
        cv2.imshow("seg_pt", show)
        cv2.waitKey(1)

        print('Loss:', batch_idx, loss.item())
        if len(un_label) < args.min_label_num:
            break

    # --- save ---
    time0 = time.time() - start_time0
    time1 = time.time() - start_time1
    print('PyTorchInit: %.2f\nTimeUsed: %.2f' % (time0, time1))
    # Derive the name from the path instead of slicing at fixed offsets, so
    # any directory depth or extension length works.
    image_name = os.path.splitext(os.path.basename(args.input_image_path))[0]
    cv2.imwrite("seg_%s_%ds.jpg" % (image_name, time1), show)
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    run()