Skip to content

Commit

Permalink
Fix some minor display issues; Add prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
DTheLegend committed Feb 10, 2023
1 parent e3a1d6c commit c5c1e1b
Show file tree
Hide file tree
Showing 7 changed files with 530 additions and 144 deletions.
10 changes: 9 additions & 1 deletion cli/fcos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,23 @@
train_parser.add_argument('--save_file', type=pathlib.Path, required=False)

predict_parser = sub_parsers.add_parser("predict")
predict_parser.add_argument('model', type=pathlib.Path)
predict_parser.add_argument('weights', type=pathlib.Path)
predict_parser.add_argument('classfile', type=pathlib.Path)
predict_parser.add_argument('image', type=pathlib.Path)

args = parser.parse_args()

if args.command == "train":
try:
from fcos.train import train
del args.command

train(**vars(args))
except ImportError:
print("Train Module not included.")
elif args.command == "predict":
pass
from fcos.cli import predict
del args.command

predict.main(**vars(args))
169 changes: 34 additions & 135 deletions cli/fcos/cli/predict.py
Original file line number Diff line number Diff line change
@@ -1,145 +1,44 @@
from importlib.resources import read_text, open_binary
from fcos.core.models import FCOS
from fcos.core.loaders import ClassLoader
from fcos.core.data_augmentation import preprocessing
from fcos.core.mAP.functions import fcos_to_boxes
import numpy as np
import torch
import xml.dom.minidom
import cv2
import fcos.map_function as mf
from fcos.DataLoader import FolderData
import torch.utils.data as Data
import fcos.get_image as get_image
import fcos.module
import fcos.net
from fcos.net import FCOS

def prediction(confs, locs, centers, row, col):
# Find Classes.
try:
f = read_text(__package__, 'classes.txt')
classes = f.splitlines()
except FileNotFoundError:
print("classes.txt file was not found...")
exit(0)

iou_lime = 0.5 # threshold for iou
cls_lime = 0.2 # threshold for confidence

# obtain the size of all the feature maps
map_sizes = []
for map_num in range(len(confs)):
# obtain the size of the feature map
H = confs[map_num].size(2)
W = confs[map_num].size(3)
map_sizes.append([H, W])
# initialize a manager for feature maps
map_master = mf.Map_master(map_sizes)

# initialize a list for storing predicted bounding boxes of different classes
GTmaster = []
for i in classes:
GTmaster.append([])

# traverse all feature maps
for feature_num in range(len(confs)):
conf = confs[feature_num].detach().cpu()
loc = locs[feature_num].detach().cpu()
center = centers[feature_num].detach().cpu()
# suppress confidence
conf = conf * center
# obtain non-background area
indexes = torch.max(conf, 1)[1]
indexes = indexes.numpy().tolist()[0]
# search for pixels on the feature map whose confidence are over threshold
for i in range(len(indexes)):
for j in range(len(indexes[i])):
# the pixel is considered as positive sample if its confidence is larger than the threshold
if conf[0, indexes[i][j], i, j] >= cls_lime:
box = [feature_num, i, j, indexes[i][j], conf[0, indexes[i][j], i, j], loc[0, 0, i, j],
loc[0, 1, i, j], loc[0, 2, i, j], loc[0, 3, i, j]]
box = map_master.decode_coordinate(box, row, col)
GTmaster[indexes[i][j]].append(box)
# initialize a empty list for returning the final detected bounding boxes after NMS
boxes = []
# non maximum suppression (NMS)
for GT in GTmaster:
while len(GT) > 0:
max_obj = []
for obj in GT[:]:
# obtain the bounding box with the highest confidence within the same category
if max_obj == []:
max_obj = obj
continue
if max_obj[1] < obj[1]:
max_obj = obj
GT.remove(max_obj)
# select the bounding box of the highest confidence as a final predicted box
boxes.append(max_obj)
if len(GT) > 0:
# remove other boxes of the same category whose iou between it and the selected box is larger than the threshold
for obj in GT[:]:
# calculate the iou between it and the selected bounding box
iou = mf.compute_iou([obj[2], obj[3], obj[4], obj[5]],
[max_obj[2], max_obj[3], max_obj[4], max_obj[5]])
if iou > iou_lime:
# delete it when the iou breaks the threshold
GT.remove(obj)
return boxes

def main():
def main(model, weights, classfile, image):
# load class list
print("balls")

f = read_text('fcos', 'classes.txt')

classes = f.splitlines()
print("balls")
classes = ClassLoader(classfile)

# load the model
with open_binary(fcos.module, 'net0.unpkl') as f:
net = FCOS()
net.load_state_dict(torch.load(f))
net.eval()
# load test set
test_set = FolderData("./src/Drake/src/fcos/DataSet/labels/test/")
loader = Data.DataLoader(
dataset=test_set, # torch TensorDataset format
batch_size=1, # mini batch size
shuffle=True, # shuffle the daatset
num_workers=2, # read data by multi threads
)

# detect
for step, label_paths in enumerate(loader):
# read one image
xml_path = label_paths[0]
# read annotation file
dom = xml.dom.minidom.parse(xml_path)
# obtain root of the xml file
root = dom.documentElement
objects = root.getElementsByTagName("object")
path = root.getElementsByTagName('path')[0]
# obtain the path of the image
pathname = "./src/Drake/src/fcos/" + path.childNodes[0].data
print(pathname)
# read the image
frame = cv2.imread(pathname)
model = FCOS(torch.load(model))
model.load_state_dict(torch.load(weights))
train_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(train_device)
model.eval()

row = frame.shape[0]
col = frame.shape[1]
torch_images, labels = get_image.get_label(label_paths)
# predict
confs, locs, centers = net(torch_images)
boxes = prediction(confs, locs, centers, row, col)
for box in boxes:
xmin = box[2]
ymin = box[3]
xmax = box[4]
ymax = box[5]
# draw rectangle
frame = cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), (0, 40, 255), 2)
frame = cv2.putText(frame, classes[box[0]] + ":" + str(round(box[1].item(), 2)), (xmin, ymin - 5), cv2.FONT_HERSHEY_COMPLEX, 0.8,
(0, 40, 255), 1)

cv2.imwrite(f'detections/detections_{step}.png', frame)
# obtain the path of the image
frame = cv2.imread(str(image.resolve()))

# torch digestive
mcvities = preprocessing(torch.from_numpy(np.transpose(frame, (2, 0, 1)))).unsqueeze(0)
mcvities = mcvities.to(train_device)

if __name__ == '__main__':
main()
row = frame.shape[0]
col = frame.shape[1]
# predict
confs, locs, centers = model(mcvities)
boxes = fcos_to_boxes(classes, confs, locs, centers, row, col)
for box in boxes:
xmin = box[2] * col // 480
ymin = box[3] * row // 360
xmax = box[4] * col // 480
ymax = box[5] * row // 360
# draw rectangle
frame = cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), (0, 40, 255), 2)
frame = cv2.putText(frame, classes[box[0]] + ":" + str(round(box[1].item(), 2)), (xmin, ymin - 5), cv2.FONT_HERSHEY_COMPLEX, 0.8,
(0, 40, 255), 1)

cv2.imshow(f'WOW!', frame)
cv2.waitKey(0) & 0xFF == ord('q')
2 changes: 1 addition & 1 deletion cli/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.extras]
train = ["fcos-train"]
train = ["fcos-train"]
File renamed without changes.
Loading

0 comments on commit c5c1e1b

Please sign in to comment.