-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluate.py
125 lines (106 loc) · 4.72 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import sys
sys.path.insert(0, 'src')
import transform, numpy as np, vgg, pdb, os
import scipy.misc
import tensorflow as tf
from utils import save_img, get_img, exists, list_files
from argparse import ArgumentParser
import time
# Default value for the --batch-size command-line option.
BATCH_SIZE = 4
# Default value for the --device command-line option (TensorFlow device string).
DEVICE = '/gpu:0'
# get img_shape
def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
    """Feed images through the transform network and save the outputs.

    Args:
        data_in: Either an iterable of input image paths (strings) or an
            indexable collection of already-loaded images whose elements
            expose ``.shape`` (e.g. a NumPy array indexed along axis 0).
        paths_out: Iterable of output paths, one per input image.
        checkpoint_dir: Directory containing a checkpoint state, or a path
            that is passed directly to ``saver.restore``.
        device_t: TensorFlow device string used for graph placement.
        batch_size: Upper bound on images per ``sess.run`` call; clamped to
            ``len(paths_out)``.

    Raises:
        AssertionError: If ``paths_out`` is empty, the input and output
            counts differ, or an image loaded from disk does not have the
            same shape as the first image.
        Exception: If ``checkpoint_dir`` is a directory that holds no
            usable checkpoint.
    """
    # Iterators (e.g. Python 3 ``map`` objects) support neither len() nor
    # slicing, both of which are required below, so materialize them once.
    if not hasattr(paths_out, '__getitem__'):
        paths_out = list(paths_out)
    if not hasattr(data_in, '__getitem__'):
        data_in = list(data_in)

    assert len(paths_out) > 0
    assert len(data_in) == len(paths_out)

    is_paths = isinstance(data_in[0], str)
    if is_paths:
        # The first image on disk fixes the shape of the whole batch.
        img_shape = get_img(data_in[0]).shape
    else:
        # Images are already in memory; take the shape from the first one.
        img_shape = data_in[0].shape

    g = tf.Graph()
    batch_size = min(len(paths_out), batch_size)
    # Soft placement lets TensorFlow fall back to another device when an op
    # cannot run on ``device_t``; allow_growth allocates GPU memory on demand.
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True
    with g.as_default(), g.device(device_t), \
            tf.Session(config=soft_config) as sess:
        batch_shape = (batch_size,) + img_shape
        img_placeholder = tf.placeholder(tf.float32, shape=batch_shape,
                                         name='img_placeholder')

        preds = transform.net(img_placeholder)
        saver = tf.train.Saver()
        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception("No checkpoint found...")
        else:
            saver.restore(sess, checkpoint_dir)

        # Only full batches fit the fixed-size placeholder; any leftover
        # images are handled by the recursive call after this session.
        num_iters = len(paths_out) // batch_size
        for i in range(num_iters):
            pos = i * batch_size
            curr_batch_out = paths_out[pos:pos + batch_size]
            if is_paths:
                curr_batch_in = data_in[pos:pos + batch_size]
                X = np.zeros(batch_shape, dtype=np.float32)
                for j, path_in in enumerate(curr_batch_in):
                    img = get_img(path_in)
                    assert img.shape == img_shape, \
                        'Images have different dimensions.'
                    X[j] = img
            else:
                X = data_in[pos:pos + batch_size]

            _preds = sess.run(preds, feed_dict={img_placeholder: X})
            for j, path_out in enumerate(curr_batch_out):
                save_img(path_out, _preds[j])

    remaining_in = data_in[num_iters * batch_size:]
    remaining_out = paths_out[num_iters * batch_size:]
    # len() rather than truthiness: a NumPy array has no unambiguous bool().
    if len(remaining_in) > 0:
        ffwd(remaining_in, remaining_out, checkpoint_dir,
             device_t=device_t, batch_size=1)
def ffwd_to_img(in_path, out_path, checkpoint_dir, device='/cpu:0'):
    """Transform a single image file by delegating to ``ffwd``.

    Args:
        in_path: Path of the image to transform.
        out_path: Path the transformed image is written to.
        checkpoint_dir: Checkpoint location, forwarded unchanged to ``ffwd``.
        device: Device string, forwarded to ``ffwd`` as ``device_t``.
    """
    # ``ffwd`` works on collections, so wrap the single pair in
    # one-element lists and force a batch of one.
    ffwd([in_path], [out_path], checkpoint_dir,
         device_t=device, batch_size=1)
def build_parser():
    """Create the argument parser for the command-line interface.

    Registers ``--checkpoint``, ``--in-path`` and ``--out-path`` as required
    options, plus ``--device`` and ``--batch-size`` which default to the
    module-level ``DEVICE`` and ``BATCH_SIZE``.

    Returns:
        The configured ``ArgumentParser``.
    """
    help_out = 'destination (dir or file) of transformed file or files'

    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = (
        ('--checkpoint', dict(
            type=str, dest='checkpoint_dir',
            help='dir or .ckpt file to load checkpoint from',
            metavar='CHECKPOINT', required=True)),
        ('--in-path', dict(
            type=str, dest='in_path',
            help='dir or file to transform',
            metavar='IN_PATH', required=True)),
        ('--out-path', dict(
            type=str, dest='out_path',
            help=help_out,
            metavar='OUT_PATH', required=True)),
        ('--device', dict(
            type=str, dest='device',
            help='device to perform compute on',
            metavar='DEVICE', default=DEVICE)),
        ('--batch-size', dict(
            type=int, dest='batch_size',
            help='batch size for feedforwarding',
            metavar='BATCH_SIZE', default=BATCH_SIZE)),
    )

    cli = ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli
def check_opts(opts):
    """Sanity-check the parsed command-line options.

    Passes the checkpoint and input paths to the ``exists`` helper with
    their error messages, does the same for the output path when it is a
    directory, and asserts that the batch size is positive.

    Args:
        opts: Parsed options carrying ``checkpoint_dir``, ``in_path``,
            ``out_path`` and ``batch_size``.
    """
    # Checked in this order so the first failing path is the one reported.
    path_checks = (
        ('checkpoint_dir', 'Checkpoint not found!'),
        ('in_path', 'In path not found!'),
    )
    for attr_name, message in path_checks:
        exists(getattr(opts, attr_name), message)

    if os.path.isdir(opts.out_path):
        exists(opts.out_path, 'out dir not found!')

    assert opts.batch_size > 0
def main():
    """Parse command-line options and transform a file or a directory.

    When ``--in-path`` is a single file it is transformed with
    ``ffwd_to_img``; if ``--out-path`` is an existing directory the result
    keeps the input's base name inside it. When ``--in-path`` is a
    directory, every entry returned by ``list_files`` is transformed with
    ``ffwd`` and written under ``--out-path`` with the same file name.
    """
    parser = build_parser()
    opts = parser.parse_args()
    check_opts(opts)

    if not os.path.isdir(opts.in_path):
        # isdir() already implies the path exists, so no separate check.
        if os.path.isdir(opts.out_path):
            out_path = os.path.join(opts.out_path,
                                    os.path.basename(opts.in_path))
        else:
            out_path = opts.out_path

        ffwd_to_img(opts.in_path, out_path, opts.checkpoint_dir,
                    device=opts.device)
    else:
        # Build real lists: ``ffwd`` calls len() on and slices its inputs,
        # which a lazy ``map`` object (Python 3) does not support.
        files = list(list_files(opts.in_path))
        full_in = [os.path.join(opts.in_path, name) for name in files]
        full_out = [os.path.join(opts.out_path, name) for name in files]
        ffwd(full_in, full_out, opts.checkpoint_dir, device_t=opts.device,
             batch_size=opts.batch_size)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()