-
Notifications
You must be signed in to change notification settings - Fork 0
/
ssd.py
44 lines (41 loc) · 1.42 KB
/
ssd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# -*- coding: utf-8 -*-
"""
Created on Wed May 22 02:44:47 2019
@author: KIIT
"""
import torch
from torch.autograd import Variable
import cv2
from data import BaseTransform,VOC_CLASSES as labelmap
from ssd import build_ssd
import imageio
def detect(frame, net, transform, threshold=0.6):
    """Draw a labelled box on ``frame`` for each detection scoring >= ``threshold``.

    Args:
        frame: Image array whose first two dimensions are (height, width).
            It is drawn on in place with OpenCV.
        net: Callable applied to the batched input tensor. The ``.data`` of
            its result is read as ``[0, cls, k, 0]`` for a score and
            ``[0, cls, k, 1:]`` for four box coordinates.
        transform: Callable whose first returned element is the preprocessed
            frame as a 3-D NumPy array; its last axis is moved to the front
            before the array is handed to ``net``.
        threshold: Minimum score for a detection to be drawn. Defaults to
            0.6, the value previously hard-coded here.

    Returns:
        The same ``frame`` object, annotated.
    """
    height, width = frame.shape[:2]
    frame_t = transform(frame)[0]
    x = torch.from_numpy(frame_t).permute(2, 0, 1)
    # Add the leading batch dimension the network is called with.
    x = Variable(x.unsqueeze(0))
    y = net(x)
    detections = y.data
    # Box coordinates are multiplied by the frame size before drawing
    # (presumably they are normalised to [0, 1] -- confirm against the model).
    scale = torch.Tensor([width, height, width, height])
    # `size` is a method: the original `detections.size[1]` raised TypeError.
    num_classes = detections.size(1)
    slots_per_class = detections.size(2)
    # Start at 1: for class 0, labelmap[i - 1] would wrap around to the last
    # label and mislabel the box.
    for i in range(1, num_classes):
        j = 0
        # The loop stops at the first score below the threshold, so it relies
        # on each class's detections being ordered by descending score
        # (NOTE(review): confirm against the detection layer). The bound on j
        # prevents an IndexError when every slot passes the threshold.
        while j < slots_per_class and detections[0, i, j, 0] >= threshold:
            pt = (detections[0, i, j, 1:] * scale).numpy()
            top_left = (int(pt[0]), int(pt[1]))
            bottom_right = (int(pt[2]), int(pt[3]))
            cv2.rectangle(frame, top_left, bottom_right, (255, 0, 0), 2)
            cv2.putText(frame, labelmap[i - 1], top_left,
                        cv2.FONT_HERSHEY_PLAIN, 2, (255, 255, 255), 2,
                        cv2.LINE_AA)
            j += 1
    return frame
# Build the SSD neural network in 'test' mode and load the pretrained weights.
net = build_ssd('test')
# The map_location callable returns each storage unchanged, so the weights
# stay where they were deserialised (no remapping to another device).
net.load_state_dict(torch.load('ssd300_mAP_77.43_v2.pth', map_location=lambda storage, loc: storage))
# Switch to evaluation mode once, instead of calling net.eval() per frame.
net.eval()
# Create the frame transformation from the network's input size and the
# per-channel values passed to BaseTransform.
transform = BaseTransform(net.size, (104 / 256.0, 117 / 256.0, 123 / 256.0))
# Run object detection on every frame of the video. The context managers
# close the reader and finalise the output file even if a frame fails.
with imageio.get_reader('funny_dog.mp4') as reader:
    fps = reader.get_meta_data()['fps']
    with imageio.get_writer('output.mp4', fps=fps) as writer:
        for i, frame in enumerate(reader):
            frame = detect(frame, net, transform)
            writer.append_data(frame)
            # Progress indicator: index of the frame just written.
            print(i)