-
Notifications
You must be signed in to change notification settings - Fork 5
/
inference.py
61 lines (50 loc) · 1.92 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# -*-coding:utf-8-*-
import cv2
import tensorflow as tf
import numpy as np
import os, time
from config import split_config
from config import merge_config
from model import split
from model import merge
from model.loss import gen_merge_inputs, cal_D, cal_R
def average_seconds(total_seconds, count):
    """Return the mean duration per item, or 0.0 when nothing was timed.

    Args:
        total_seconds: Accumulated elapsed time in seconds.
        count: Number of items that contributed to ``total_seconds``.

    Returns:
        ``total_seconds / count``, or ``0.0`` when ``count`` is not positive,
        so an empty or unreadable image directory does not raise
        ``ZeroDivisionError``.
    """
    if count <= 0:
        return 0.0
    return total_seconds / count


if __name__ == '__main__':
    # Build both networks and restore their saved weights.
    Split = split.Split()
    Split.load_weights(os.path.join(split_config.saved_models, 'split_499'))
    Merge = merge.Merge()
    Merge.load_weights(os.path.join(merge_config.saved_models, 'merge_1092'))

    img_dir = './test_imgs'
    # Sorted so that the processing order is the same on every run.
    imgs = sorted(os.listdir(img_dir))

    total_time = 0.0
    processed = 0
    for img_name in imgs:
        # perf_counter is monotonic, which makes it the right clock for
        # measuring elapsed intervals. Timing includes reading the image.
        t1 = time.perf_counter()
        img_path = os.path.join(img_dir, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure (sub-directory, non-image file,
            # corrupt data) by returning None instead of raising.
            print('skipping unreadable image: %s' % img_path)
            continue

        # Add a leading batch axis and convert to float32. One conversion is
        # enough: converting a float32 tensor to float32 again changes nothing.
        image_batch = tf.image.convert_image_dtype(
            tf.expand_dims(img, axis=0), tf.float32)

        inputs, grid_loc = gen_merge_inputs(image_batch, Split)
        (matrix_u2, matrix_u3, matrix_d2, matrix_d3,
         matrix_l2, matrix_l3, matrix_r2, matrix_r3) = Merge(inputs, grid_loc)

        total_time += time.perf_counter() - t1
        processed += 1

        # Optional debug visualisation (disabled): inspect the merge outputs
        # and draw the grid lines reported in grid_loc on the input image.
        # D3, R3 = cal_D(matrix_u3, matrix_d3), cal_R(matrix_l3, matrix_r3)
        # D3, R3 = D3.numpy(), R3.numpy()
        # print(np.max(D3), np.max(R3))
        #
        # grid_loc_row, grid_loc_col = grid_loc
        # print(image_batch.shape, grid_loc_row, grid_loc_col)
        #
        # image = image_batch[0].numpy()
        #
        # h, w, c = image.shape
        # for row in grid_loc_row:
        #     cv2.line(image, (0, row), (w, row), thickness=2, color=(0, 0, 255))
        # for col in grid_loc_col:
        #     cv2.line(image, (col, 0), (col, h), thickness=2, color=(0, 255, 128))
        #
        # cv2.namedWindow('ori', 0)
        # cv2.imshow('ori', image)
        # cv2.waitKey()

    if processed == 0:
        print('no readable images found in %s' % img_dir)
    # Mean seconds per image, over the images that were actually processed.
    print(average_seconds(total_time, processed))