forked from lisiyao21/Bailando
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path extract_aist_features.py
75 lines (62 loc) · 2.5 KB
/
extract_aist_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import numpy as np
import argparse
from aist_plusplus.loader import AISTDataset
from utils.features.kinetic import extract_kinetic_features
from utils.features.manual_new import extract_manual_features
from smplx import SMPL
import torch
import multiprocessing
import functools
# Command-line configuration. FLAGS is parsed at import time because main()
# reads it as a module-level global (main is also what the pool workers run).
parser = argparse.ArgumentParser(
    description='Extract per-sequence motion features from AIST++ SMPL annotations.')
parser.add_argument(
    '--anno_dir',
    type=str,
    default='aist-pp-dataset/aist_plusplus_final/',
    help='input local directory for AIST++ annotations.')
parser.add_argument(
    '--smpl_dir',
    type=str,
    default='smpl/',
    help='input local directory that stores SMPL model data.')
parser.add_argument(
    '--save_dir',
    type=str,
    default='data/aist_features_zero_start/',
    help='output local directory that stores features.')
FLAGS = parser.parse_args()
@functools.lru_cache(maxsize=None)
def _get_smpl_model(model_path):
    """Build the SMPL body model once per process and reuse it.

    main() is invoked once per sequence by multiprocessing.Pool workers, so
    without caching every sequence would reload the model files from disk.
    The cache key is the model directory, so at most one entry exists per
    worker in practice.
    """
    return SMPL(model_path=model_path, gender='MALE', batch_size=1)


def main(seq_name, motion_dir):
    """Compute and save manual features for one AIST++ sequence.

    Args:
        seq_name: sequence name passed to ``AISTDataset.load_motion``; also
            used as the stem of the output file name.
        motion_dir: motion directory passed to ``AISTDataset.load_motion``.

    Side effects:
        Writes ``<FLAGS.save_dir>/manual_features_new/<seq_name>_manual.npy``
        (the directory must already exist).
    """
    print(seq_name)
    smpl = _get_smpl_model(FLAGS.smpl_dir)
    smpl_poses, smpl_scaling, smpl_trans = AISTDataset.load_motion(
        motion_dir, seq_name)

    # Parsing SMPL 24 joints.
    # Note here we calculate `transl` as `smpl_trans/smpl_scaling` for
    # normalizing the motion in generic SMPL model scale.
    # NOTE(review): smpl_poses is presumably (frames, 24, 3) axis-angle, with
    # joint 0 the global orientation -- confirm against the AIST++ loader.
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        output = smpl.forward(
            global_orient=torch.from_numpy(smpl_poses[:, 0:1]).float(),
            body_pose=torch.from_numpy(smpl_poses[:, 1:]).float(),
            transl=torch.from_numpy(smpl_trans / smpl_scaling).float(),
        )
    keypoints3d = output.joints.detach().numpy()[:, 0:24, :]

    # Shift every frame by the same offset: the position of joint 0 in the
    # first frame. The sequence therefore starts at the origin while motion
    # relative to that start is preserved.
    keypoints3d = keypoints3d - keypoints3d[:1, :1]

    features = extract_manual_features(keypoints3d)
    out_path = os.path.join(
        FLAGS.save_dir, 'manual_features_new', seq_name + "_manual.npy")
    np.save(out_path, features)
    print(seq_name, "is done")
def load_ignore_list(anno_dir):
    """Return the set of sequence names listed in ``<anno_dir>/ignore_list.txt``.

    ``np.loadtxt`` returns a 0-d array when the file has a single entry, and
    ``.tolist()`` on a 0-d array yields a plain ``str``; a later ``in`` test
    would then silently do substring matching. ``np.atleast_1d`` keeps the
    result 1-D so every entry is matched as a whole name.
    """
    names = np.loadtxt(os.path.join(anno_dir, "ignore_list.txt"), dtype=str)
    return set(np.atleast_1d(names).tolist())


if __name__ == '__main__':
    # Output layout. save_dir itself is created implicitly by makedirs.
    for sub_dir in ('kinetic_features', 'manual_features_new'):
        os.makedirs(os.path.join(FLAGS.save_dir, sub_dir), exist_ok=True)

    # Parsing data info, skipping sequences named in ignore_list.txt.
    aist_dataset = AISTDataset(FLAGS.anno_dir)
    ignored = load_ignore_list(FLAGS.anno_dir)
    seq_names = [
        name for name in aist_dataset.mapping_seq2env.keys()
        if name not in ignored
    ]

    # Processing: one task per sequence. The context manager guarantees the
    # worker processes are cleaned up once map() has returned.
    num_workers = 12
    process = functools.partial(main, motion_dir=aist_dataset.motion_dir)
    with multiprocessing.Pool(num_workers) as pool:
        pool.map(process, seq_names)