-
Notifications
You must be signed in to change notification settings - Fork 135
/
model_mesh.py
101 lines (91 loc) · 4.39 KB
/
model_mesh.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from lib.utils.utils_smpl import SMPL
from lib.utils.utils_mesh import rotation_matrix_to_angle_axis, rot6d_to_rotmat
class SMPLRegressor(nn.Module):
    """Regress SMPL pose/shape parameters from per-joint sequence features.

    Pose is predicted once per frame from that frame's flattened joint
    features; shape (betas) is predicted once per sequence from the
    temporally pooled features and repeated for every frame. Both heads
    predict a residual that is added to the SMPL mean parameters loaded from
    ``self.smpl.smpl_mean_params``.

    Args:
        args: config object; only ``args.data_root`` is read (passed to SMPL).
        dim_rep: per-joint feature size C of the input.
        num_joints: number of joints J of the input.
        hidden_dim: width of the hidden layer in both branches.
        dropout_ratio: dropout probability applied before each hidden layer.
    """

    def __init__(self, args, dim_rep=512, num_joints=17, hidden_dim=2048, dropout_ratio=0.):
        super(SMPLRegressor, self).__init__()
        param_pose_dim = 24 * 6  # 24 rotations, 6D representation each
        # NOTE: submodules are created in the original order so parameter
        # order, state_dict keys and random initialisation are unchanged.
        self.dropout = nn.Dropout(p=dropout_ratio)
        self.fc1 = nn.Linear(num_joints * dim_rep, hidden_dim)  # per-frame pose branch
        self.pool2 = nn.AdaptiveAvgPool2d((None, 1))  # averages the shape branch over time
        self.fc2 = nn.Linear(num_joints * dim_rep, hidden_dim)  # per-sequence shape branch
        self.bn1 = nn.BatchNorm1d(hidden_dim, momentum=0.1)
        self.bn2 = nn.BatchNorm1d(hidden_dim, momentum=0.1)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.head_pose = nn.Linear(hidden_dim, param_pose_dim)
        self.head_shape = nn.Linear(hidden_dim, 10)
        # Tiny gain keeps initial predictions close to the mean parameters.
        nn.init.xavier_uniform_(self.head_pose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.head_shape.weight, gain=0.01)
        self.smpl = SMPL(
            args.data_root,
            batch_size=64,
            create_transl=False,
        )
        # The mean-params archive is indexed by key, i.e. an .npz; close it
        # deterministically instead of relying on garbage collection.
        with np.load(self.smpl.smpl_mean_params) as mean_params:
            # Cast both to float32 so the residual additions in forward()
            # never promote to float64 if the archive stores doubles.
            init_pose = torch.from_numpy(mean_params['pose'][:].astype('float32')).unsqueeze(0)
            init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.J_regressor = self.smpl.J_regressor_h36m

    def forward(self, feat, init_pose=None, init_shape=None):
        """Predict SMPL parameters, vertices and H36M-regressed joints.

        Args:
            feat: tensor unpacked as (N, T, J, C); J * C must equal
                ``num_joints * dim_rep`` (the input size of fc1/fc2).
            init_pose: accepted for interface compatibility but ignored; the
                registered ``init_pose`` buffer is always used.
            init_shape: accepted for interface compatibility but ignored; the
                registered ``init_shape`` buffer is always used.

        Returns:
            A one-element list holding a dict of per-frame outputs:
            'theta' (N*T, 72 + 10): axis-angle pose concatenated with betas;
            'verts' (N*T, 6890, 3): SMPL vertices multiplied by 1000;
            'kp_3d' (N*T, 17, 3): ``J_regressor_h36m`` applied to 'verts'.

        Raises:
            RuntimeError: if the SMPL model provides no H36M joint regressor.
        """
        N, T, J, C = feat.shape
        NT = N * T
        feat = feat.reshape(N, T, -1)  # (N, T, J*C)

        # Pose branch: one prediction per frame.
        feat_pose = feat.reshape(NT, -1)  # (N*T, J*C)
        feat_pose = self.dropout(feat_pose)
        feat_pose = self.fc1(feat_pose)
        feat_pose = self.bn1(feat_pose)
        feat_pose = self.relu1(feat_pose)  # (N*T, hidden_dim)

        # Shape branch: pool over time, one prediction per sequence.
        feat_shape = feat.permute(0, 2, 1)  # (N, J*C, T)
        feat_shape = self.pool2(feat_shape).reshape(N, -1)  # (N, J*C)
        feat_shape = self.dropout(feat_shape)
        feat_shape = self.fc2(feat_shape)
        feat_shape = self.bn2(feat_shape)
        feat_shape = self.relu2(feat_shape)  # (N, hidden_dim)

        # Residual prediction on top of the SMPL mean parameters.
        pred_pose = self.head_pose(feat_pose) + self.init_pose.expand(NT, -1)  # (N*T, 144)
        pred_shape = self.head_shape(feat_shape) + self.init_shape.expand(N, -1)  # (N, 10)
        # Repeat each sequence's betas for its T frames: row n*T + t -> sequence n.
        pred_shape = pred_shape.unsqueeze(1).expand(N, T, -1).reshape(NT, -1)  # (N*T, 10)

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(-1, 24, 3, 3)
        pred_output = self.smpl(
            betas=pred_shape,
            body_pose=pred_rotmat[:, 1:],
            global_orient=pred_rotmat[:, 0].unsqueeze(1),
            pose2rot=False,
        )
        pred_vertices = pred_output.vertices * 1000.0

        # Explicit check: unlike ``assert`` it is not stripped under ``python -O``.
        if self.J_regressor is None:
            raise RuntimeError('SMPL model has no J_regressor_h36m; cannot compute kp_3d.')
        J_regressor = self.J_regressor.to(pred_vertices.device)
        # (K, V) @ (N*T, V, 3) broadcasts over the leading batch dim -> (N*T, K, 3).
        pred_joints = torch.matmul(J_regressor, pred_vertices)

        pose = rotation_matrix_to_angle_axis(pred_rotmat.reshape(-1, 3, 3)).reshape(-1, 72)
        output = [{
            'theta': torch.cat([pose, pred_shape], dim=1),  # (N*T, 72+10)
            'verts': pred_vertices,  # (N*T, 6890, 3)
            'kp_3d': pred_joints,  # (N*T, 17, 3)
        }]
        return output
class MeshRegressor(nn.Module):
    """Backbone feature extractor followed by an ``SMPLRegressor`` head.

    Args:
        args: config object forwarded to ``SMPLRegressor``.
        backbone: object exposing ``get_representation(x)``.
        dim_rep: forwarded to ``SMPLRegressor``.
        num_joints: forwarded to ``SMPLRegressor``; also used to reshape the
            backbone output into per-joint features.
        hidden_dim: forwarded to ``SMPLRegressor``.
        dropout_ratio: forwarded to ``SMPLRegressor``.
    """

    def __init__(self, args, backbone, dim_rep=512, num_joints=17, hidden_dim=2048, dropout_ratio=0.5):
        super().__init__()
        self.backbone = backbone
        self.feat_J = num_joints
        self.head = SMPLRegressor(args, dim_rep, num_joints, hidden_dim, dropout_ratio)

    def forward(self, x, init_pose=None, init_shape=None, n_iter=3):
        """Run backbone and head, then restore the (N, T) leading dims.

        Input: (N x T x 17 x 3), per the original documentation; any 4-D
        tensor unpacks, and the first two dims are taken as N and T.

        ``init_pose``, ``init_shape`` and ``n_iter`` are accepted for
        interface compatibility but are not used.

        Returns:
            The head's list of dicts, modified in place so that 'theta' is
            (N, T, -1) and 'verts' / 'kp_3d' are (N, T, -1, 3).
        """
        batch, frames, _, _ = x.shape
        features = self.backbone.get_representation(x)
        features = features.reshape(batch, frames, self.feat_J, -1)  # (N, T, J, C)
        smpl_output = self.head(features)

        # Trailing shape to restore for each output key, applied in this order.
        trailing_dims = {'theta': (-1,), 'verts': (-1, 3), 'kp_3d': (-1, 3)}
        for record in smpl_output:
            for key, tail in trailing_dims.items():
                record[key] = record[key].reshape(batch, frames, *tail)
        return smpl_output