rpn.py

from torch import nn
import torch
import torch.nn.functional as F
import numpy as np
from nms import nms
from tools import bbox_iou, loc2bbox, clip_bbox
from mmcv.cnn import xavier_init


def generate_anchor_in_cell(anchor_scale=[8, 16, 32], anchor_ratio=[0.5, 1., 2.], feature_stride=16):
    """Generate the base anchors for a single feature-map cell.

    `ratio` is interpreted as width / height: the anchor area stays
    (base_size * scale) ** 2 while the aspect ratio varies.
    """
    base_size = feature_stride
    anchor = list()
    for scale in anchor_scale:
        for ratio in anchor_ratio:
            w = (base_size * scale) * np.sqrt(ratio)
            h = w / ratio
            # anchors are centred on the cell centre (base_size / 2, base_size / 2)
            xmin, ymin = base_size / 2 - w / 2, base_size / 2 - h / 2
            xmax, ymax = xmin + w, ymin + h
            anchor.append([xmin, ymin, xmax, ymax])
    return np.array(anchor)
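
# A quick worked example with the defaults (base_size = 16):
#   scale = 8, ratio = 0.5  ->  w = 128 * sqrt(0.5) ≈ 90.5, h = w / 0.5 ≈ 181.0
#   scale = 8, ratio = 1.0  ->  w = h = 128
# Every anchor keeps the area (base_size * scale) ** 2 and is centred on
# (base_size / 2, base_size / 2) = (8, 8), so generate_anchor_in_cell()
# returns a (9, 4) array of [xmin, ymin, xmax, ymax] rows.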


def generate_anchor_in_origin_image(feature_height, feature_width, base_anchor=None, feature_stride=16):
    """Tile the base anchors over every feature-map cell, in image coordinates.

    Returns a (feature_height * feature_width * A, 4) array of
    [xmin, ymin, xmax, ymax] boxes.
    """
    if base_anchor is None:
        base_anchor = generate_anchor_in_cell(feature_stride=feature_stride)
    # top-left corner of every cell, projected back onto the original image
    grid_x = np.arange(0, feature_stride * feature_width, feature_stride)
    grid_y = np.arange(0, feature_stride * feature_height, feature_stride)
    grid_x, grid_y = np.meshgrid(grid_x, grid_y)
    grid = np.stack((grid_x.flatten(), grid_y.flatten(), grid_x.flatten(), grid_y.flatten()), axis=1)
    A = base_anchor.shape[0]   # anchors per cell (9 by default)
    K = grid.shape[0]          # number of cells (feature_height * feature_width)
    # broadcast (1, A, 4) + (K, 1, 4) -> (K, A, 4), then flatten to (K * A, 4)
    anchor = base_anchor.reshape((1, A, 4)) + grid.reshape((1, K, 4)).transpose((1, 0, 2))
    anchor = anchor.reshape(K * A, 4)
    return anchor
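
# For instance, a 40x40 feature map at stride 16 gives K = 1600 cell offsets;
# adding the (1, 9, 4) base anchors to the (1600, 1, 4) offsets broadcasts to
# (1600, 9, 4), so the function returns 1600 * 9 = 14400 anchors covering a
# 640x640 pixel image.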


class RPN(nn.Module):
    """Region Proposal Network head: a 3x3 conv followed by two sibling 1x1 convs
    that predict, for every anchor, a bg/fg score and a box regression offset."""

    def __init__(self, anchor_scale=[8, 16, 32], anchor_ratio=[0.5, 1., 2.]):
        super(RPN, self).__init__()
        self.anchor_scale = anchor_scale
        self.anchor_ratio = anchor_ratio
        self.num_anchor = len(self.anchor_scale) * len(self.anchor_ratio)  # 9
        # 3x3 conv keeps the spatial size of the feature map unchanged
        self.conv = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=1)
        # 1x1 heads: 2 scores (bg/fg) and 4 box offsets per anchor
        self.score_layer = nn.Conv2d(in_channels=512, out_channels=self.num_anchor * 2, kernel_size=1, stride=1, padding=0)
        self.loc_layer = nn.Conv2d(in_channels=512, out_channels=self.num_anchor * 4, kernel_size=1, stride=1, padding=0)
        # normal_init(self.conv, 0, 0.01)
        # normal_init(self.score_layer, 0, 0.01)
        # normal_init(self.loc_layer, 0, 0.01)
        xavier_init(self.conv)
        xavier_init(self.score_layer)
        xavier_init(self.loc_layer)
        self.min_size = 16  # proposals smaller than this (in pixels) are discarded

    def forward(self, feature_map, feature_stride):
        batch_size, channels, height, width = feature_map.shape
        anchors = generate_anchor_in_origin_image(feature_height=height, feature_width=width, feature_stride=feature_stride)
        mid_layer = F.relu(self.conv(feature_map))
        rpn_score = self.score_layer(mid_layer)  # (1, 2*9, h, w)
        rpn_loc = self.loc_layer(mid_layer)      # (1, 4*9, h, w)
        rpn_loc = rpn_loc.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 4)
        rpn_score = rpn_score.permute(0, 2, 3, 1).contiguous()
        # softmax over the (bg, fg) pair of each anchor; keep the fg probability
        rpn_softmax_score = F.softmax(rpn_score.view(batch_size, height, width, self.num_anchor, 2), dim=4)
        rpn_fg_score = rpn_softmax_score[:, :, :, :, 1].contiguous()
        rpn_fg_score = rpn_fg_score.view(batch_size, -1)
        # the proposal generation below assumes batch_size == 1
        rpn_fg_score = rpn_fg_score[0].cpu().detach().numpy()
        rpn_score = rpn_score.view(batch_size, -1, 2)
        # decode the predicted offsets into boxes and clip them to the image
        rois = loc2bbox(anchors, rpn_loc[0].cpu().detach().numpy())
        rois = clip_bbox(rois, height * feature_stride, width * feature_stride)
        # discard boxes that are too small
        ws = rois[:, 2] - rois[:, 0]
        hs = rois[:, 3] - rois[:, 1]
        keep = np.where((ws >= self.min_size) & (hs >= self.min_size))[0]
        rois = rois[keep, :]
        rpn_fg_score = rpn_fg_score[keep]
        # sort the rois by foreground score, highest first
        order = rpn_fg_score.ravel().argsort()[::-1]
        rois = rois[order, :]
        rpn_fg_score = rpn_fg_score[order]
        # non-maximum suppression on the remaining proposals
        keep = nms(rois, rpn_fg_score.reshape(-1), thresh=0.5)
        rois = rois[keep]
        return rois, anchors, rpn_loc, rpn_score
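
# `loc2bbox` and `clip_bbox` live in the local `tools` module, so their exact
# behaviour is an assumption here; the standard Faster R-CNN decoding that
# `loc2bbox(anchors, loc)` is expected to perform for each anchor is
#     ctr_x = anchor_ctr_x + dx * anchor_w        # dx, dy shift the centre
#     ctr_y = anchor_ctr_y + dy * anchor_h
#     w     = anchor_w * exp(dw)                  # dw, dh scale width/height
#     h     = anchor_h * exp(dh)
# with the resulting boxes converted back to [xmin, ymin, xmax, ymax] before
# `clip_bbox` clamps them to the image bounds.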


def normal_init(m, mean, stddev, truncated=False):
    """Weight initializer: truncated normal or plain random normal."""
    if truncated:
        # not a perfect approximation of a truncated normal
        m.weight.data.normal_().fmod_(2).mul_(stddev).add_(mean)
    else:
        m.weight.data.normal_(mean, stddev)
    m.bias.data.zero_()


if __name__ == '__main__':
    # quick smoke test on a random feature map
    gt_boxes = torch.Tensor([[0., 0., 150., 150.]])
    label = torch.Tensor([1])
    # self.conv expects 1024 input channels, so the dummy feature map must match
    feature_map = torch.randn(1, 1024, 40, 40)
    model = RPN()
    rois, anchors, rpn_loc, rpn_score = model(feature_map, 16)
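    # A minimal inspection sketch. With a 40x40 feature map at stride 16 there
    # are 40 * 40 * 9 = 14400 anchors; how many rois survive depends on the
    # local `nms`/`clip_bbox` helpers, so the first shape is only indicative.
    print(rois.shape)       # roughly (num_rois_after_nms, 4), a numpy array
    print(anchors.shape)    # (14400, 4)
    print(rpn_loc.shape)    # torch.Size([1, 14400, 4])
    print(rpn_score.shape)  # torch.Size([1, 14400, 2])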