infer_wild.py
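"""Run MotionBERT 3D pose inference on an in-the-wild video.

Takes AlphaPose 2D detections (JSON) plus the source video, lifts them to 3D
with a pretrained MotionBERT backbone, and writes a rendered X3D.mp4 and a raw
X3D.npy to the output directory.

Example invocation (paths are illustrative; flags are defined in parse_args below):
    python infer_wild.py \
        --vid_path sample_video.mp4 \
        --json_path alphapose-results.json \
        --out_path output/
"""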
import os
import numpy as np
import argparse
from tqdm import tqdm
import imageio
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from lib.utils.tools import *
from lib.utils.learning import *
from lib.utils.utils_data import flip_data
from lib.data.dataset_wild import WildDetDataset
from lib.utils.vismo import render_and_save

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/pose3d/MB_ft_h36m_global_lite.yaml", help="Path to the config file.")
    parser.add_argument('-e', '--evaluate', default='checkpoint/pose3d/FT_MB_lite_MB_ft_h36m_global_lite/best_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
    parser.add_argument('-j', '--json_path', type=str, help='AlphaPose detection result JSON path')
    parser.add_argument('-v', '--vid_path', type=str, help='video path')
    parser.add_argument('-o', '--out_path', type=str, help='output path')
    parser.add_argument('--pixel', action='store_true', help='align with pixel coordinates')
    parser.add_argument('--focus', type=int, default=None, help='target person id')
    parser.add_argument('--clip_len', type=int, default=243, help='clip length for network input')
    opts = parser.parse_args()
    return opts

opts = parse_args()
args = get_config(opts.config)

model_backbone = load_backbone(args)
if torch.cuda.is_available():
    model_backbone = nn.DataParallel(model_backbone)
    model_backbone = model_backbone.cuda()

print('Loading checkpoint', opts.evaluate)
checkpoint = torch.load(opts.evaluate, map_location=lambda storage, loc: storage)
model_backbone.load_state_dict(checkpoint['model_pos'], strict=True)
model_pos = model_backbone
model_pos.eval()
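
# DataLoader settings: one clip per batch (batch_size=1); multiple workers
# prefetch clips in the background while the GPU runs inference.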
testloader_params = {
    'batch_size': 1,
    'shuffle': False,
    'num_workers': 8,
    'pin_memory': True,
    'prefetch_factor': 4,
    'persistent_workers': True,
    'drop_last': False
}

vid = imageio.get_reader(opts.vid_path, 'ffmpeg')
fps_in = vid.get_meta_data()['fps']
vid_size = vid.get_meta_data()['size']
os.makedirs(opts.out_path, exist_ok=True)

if opts.pixel:
    # Keep relative scale with pixel coordinates
    wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, vid_size=vid_size, scale_range=None, focus=opts.focus)
else:
    # Scale to [-1,1]
    wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, scale_range=[1,1], focus=opts.focus)

test_loader = DataLoader(wild_dataset, **testloader_params)

results_all = []
with torch.no_grad():
    for batch_input in tqdm(test_loader):
        N, T = batch_input.shape[:2]
        if torch.cuda.is_available():
            batch_input = batch_input.cuda()
        if args.no_conf:
            # Drop the detection confidence channel, keeping only (x, y)
            batch_input = batch_input[:, :, :, :2]
        if args.flip:
            # Test-time augmentation: average the predictions for the input
            # and its horizontally flipped copy
            batch_input_flip = flip_data(batch_input)
            predicted_3d_pos_1 = model_pos(batch_input)
            predicted_3d_pos_flip = model_pos(batch_input_flip)
            predicted_3d_pos_2 = flip_data(predicted_3d_pos_flip)  # Flip back
            predicted_3d_pos = (predicted_3d_pos_1 + predicted_3d_pos_2) / 2.0
        else:
            predicted_3d_pos = model_pos(batch_input)
        if args.rootrel:
            predicted_3d_pos[:, :, 0, :] = 0  # [N,T,17,3], zero the root joint in every frame
        else:
            predicted_3d_pos[:, 0, 0, 2] = 0  # pin the first frame's root depth to 0
        if args.gt_2d:
            # Overwrite predicted (x, y) with the input 2D keypoints
            predicted_3d_pos[..., :2] = batch_input[..., :2]
        results_all.append(predicted_3d_pos.cpu().numpy())

# Each element of results_all is [1, T, 17, 3]; stack clips along the time
# axis, then drop the batch axis to get [total_frames, 17, 3]
results_all = np.hstack(results_all)
results_all = np.concatenate(results_all)
render_and_save(results_all, '%s/X3D.mp4' % (opts.out_path), keep_imgs=False, fps=fps_in)
if opts.pixel:
    # Convert normalized coordinates back to pixel coordinates
    results_all = results_all * (min(vid_size) / 2.0)
    results_all[:, :, :2] = results_all[:, :, :2] + np.array(vid_size) / 2.0
np.save('%s/X3D.npy' % (opts.out_path), results_all)
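
# The saved array can be loaded for downstream use (path is illustrative):
#     poses = np.load('output/X3D.npy')  # shape: [num_frames, 17, 3]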