-
Notifications
You must be signed in to change notification settings - Fork 74
/
Copy pathinfer_batch.py
87 lines (62 loc) · 2.62 KB
/
infer_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import csv
import argparse
import tensorflow as tf
import keras.backend as K
from glob import glob
from lib.io import openpose_from_file, read_segmentation, write_mesh
from model.octopus import Octopus
def _read_keypoints(pose_files):
    """Load body and face keypoints from each OpenPose JSON file in *pose_files*.

    Returns a ``(joints_2d, face_2d)`` pair of lists with one entry per file,
    in the order of *pose_files*. Returns ``None`` as soon as any file does
    not yield exactly 25 body joints and 70 face keypoints, so the caller can
    reject the subject as a whole.
    """
    joints_2d, face_2d = [], []
    for pose_file in pose_files:
        joints, face = openpose_from_file(pose_file)
        # 25 / 70 are the keypoint counts this script requires; presumably the
        # OpenPose BODY_25 and face models -- confirm against lib.io.
        if len(joints) != 25 or len(face) != 70:
            return None
        joints_2d.append(joints)
        face_2d.append(face)
    return joints_2d, face_2d


def main(weights, num, batch_file, opt_pose_steps, opt_shape_steps):
    """Run Octopus inference for every subject listed in a batch file.

    Each non-empty line of *batch_file* must contain four space-separated
    fields: ``name segm_dir pose_dir out_dir``. For every subject, the
    ``*.png`` segmentations in *segm_dir* and the ``*.json`` keypoint files in
    *pose_dir* are paired by their sorted filename order. The model is
    optionally optimized for pose and shape, and the predicted mesh is written
    to ``<out_dir>/<name>.obj``. Malformed lines and subjects with
    inconsistent input are reported and skipped instead of aborting the
    whole batch.

    Args:
        weights: Path to the model weights file, reloaded before each subject.
        num: Number of views per subject; the model is built with
            ``Octopus(num=num)`` and subjects with a different number of views
            are skipped.
        batch_file: Path to the space-separated batch file described above.
        opt_pose_steps: Number of pose-optimization steps; 0 skips the stage.
        opt_shape_steps: Number of shape-optimization steps; 0 skips the stage.
    """
    # Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
    K.set_session(tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))))

    model = Octopus(num=num)

    with open(batch_file, 'r') as batch:
        reader = csv.reader(batch, delimiter=' ')

        for row in reader:
            if not row:
                # Blank line, e.g. a trailing newline at the end of the file.
                continue
            if len(row) != 4:
                print('> Skipping malformed batch line: {}'.format(' '.join(row)))
                continue
            name, segm_dir, pose_dir, out_dir = row

            print('Processing {}...'.format(name))
            # Reload the weights for every subject. Presumably opt_pose/opt_shape
            # modify the model in place, so this keeps one subject's optimization
            # from carrying over to the next -- confirm against Octopus.
            model.load(weights)

            segm_files = sorted(glob(os.path.join(segm_dir, '*.png')))
            pose_files = sorted(glob(os.path.join(pose_dir, '*.json')))

            # Views are paired by sorted position, so both lists must be non-empty,
            # equally long, and match the view count the model was built for.
            if not segm_files or len(segm_files) != len(pose_files) or len(segm_files) != num:
                print('> Inconsistent input: {} segmentations, {} pose files, {} views expected.'.format(
                    len(segm_files), len(pose_files), num))
                continue

            keypoints = _read_keypoints(pose_files)
            if keypoints is None:
                # Reject the whole subject: dropping individual views would leave the
                # keypoint lists shorter than, and misaligned with, the segmentations.
                print('> Invalid keypoints.')
                continue
            joints_2d, face_2d = keypoints

            segmentations = [read_segmentation(segm_file) for segm_file in segm_files]

            if opt_pose_steps:
                print('> Optimizing for pose...')
                model.opt_pose(segmentations, joints_2d, opt_steps=opt_pose_steps)

            if opt_shape_steps:
                print('> Optimizing for shape...')
                model.opt_shape(segmentations, joints_2d, face_2d, opt_steps=opt_shape_steps)

            print('> Estimating shape...')
            pred = model.predict(segmentations, joints_2d)

            # Create the output directory if needed (an empty field means the cwd).
            if out_dir and not os.path.isdir(out_dir):
                os.makedirs(out_dir)
            # Only the first entry of pred['vertices'] is written.
            write_mesh(os.path.join(out_dir, '{}.obj'.format(name)), pred['vertices'][0], pred['faces'])

            print('> Done.')
if __name__ == '__main__':
    # Command-line entry point: parse and validate arguments, then hand off to main().
    parser = argparse.ArgumentParser(
        description='Run Octopus inference for every subject listed in a batch file '
                    'and write one .obj mesh per subject.')
    parser.add_argument(
        'batch_file',
        type=str,
        help="Batch file with one space-separated line per subject: "
             "name segm_dir pose_dir out_dir")
    parser.add_argument(
        'num',
        type=int,
        help="Number of views per subject")
    parser.add_argument(
        '--opt_steps_pose', '-p', default=10, type=int,
        help="Number of pose-optimization steps, 0 to skip (default: %(default)s)")
    parser.add_argument(
        '--opt_steps_shape', '-s', default=25, type=int,
        help="Number of shape-optimization steps, 0 to skip (default: %(default)s)")
    parser.add_argument(
        '--weights', '-w',
        type=str,
        default='weights/octopus_weights.hdf5',
        help='Model weights file (*.hdf5)')
    args = parser.parse_args()

    # Reject values that cannot describe a valid run before loading TensorFlow state.
    if args.num < 1:
        parser.error('num must be a positive integer')
    if args.opt_steps_pose < 0 or args.opt_steps_shape < 0:
        parser.error('optimization step counts must be non-negative')

    main(args.weights, args.num, args.batch_file, args.opt_steps_pose, args.opt_steps_shape)