-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathfit.py
136 lines (119 loc) · 4.06 KB
/
fit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Optimization-based SMPL fitting, borrowed from https://github.com/wangsen1312/joints2smpl
import os
import argparse
import pickle
import h5py
import natsort
import smplx
import torch
from mld.transforms.joints2rots import config
from mld.transforms.joints2rots.smplify import SMPLify3D
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` makes argparse call ``bool()`` on the raw string, so every
    non-empty value (including "False" and "0") is parsed as True. This
    converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays False, as it was under ``bool("")``.
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface: either a single motion file (--pkl) or a folder of
# motion files (--dir) is expected; the remaining options tune the fitting.
parser = argparse.ArgumentParser()
parser.add_argument("--pkl", type=str, default=None, help="pkl motion file")
parser.add_argument("--dir", type=str, default=None, help="pkl motion folder")
parser.add_argument("--num_smplify_iters", type=int, default=150, help="num of smplify iters")
# Parsed with _str2bool so that "--cuda False" actually disables CUDA.
parser.add_argument("--cuda", type=_str2bool, default=True, help="enables cuda")
parser.add_argument("--gpu_ids", type=int, default=0, help="choose gpu ids")
parser.add_argument("--num_joints", type=int, default=22, help="joint number")
parser.add_argument("--joint_category", type=str, default="AMASS", help="use correspondence")
# Kept as a string option for backward compatibility with existing invocations.
parser.add_argument("--fix_foot", type=str, default="False", help="fix foot or not")
opt = parser.parse_args()
print(opt)
# Build the list of motion files to process. --pkl takes precedence over
# --dir; at least one of the two has to be supplied.
if not opt.pkl and not opt.dir:
    raise ValueError(f'{opt.pkl} and {opt.dir} are both None!')

if opt.pkl:
    paths = [opt.pkl]
else:
    # Naturally-sorted folder listing, keeping only motion pickles and leaving
    # out the "*_mesh.pkl" files that this script itself writes.
    paths = [
        os.path.join(opt.dir, name)
        for name in natsort.natsorted(os.listdir(opt.dir))
        if name.endswith('.pkl') and not name.endswith("_mesh.pkl")
    ]
# --fix_foot is declared as a *string* option, so a plain truthiness test
# would treat its default "False" as enabled. Interpret the text explicitly;
# str() keeps this working if the option is ever supplied as a real bool.
fix_foot = str(opt.fix_foot).strip().lower() in ("1", "true", "t", "yes", "y", "on")

# Fit an SMPL body to every motion file and store the resulting mesh vertices
# next to the input as "<name>_mesh.pkl".
for path in paths:
    # Derive the output name from the extension only; str.replace would also
    # rewrite any ".pkl" occurring in a directory name.
    save_file = os.path.splitext(path)[0] + '_mesh.pkl'
    if os.path.exists(save_file):
        print(f"{path} is rendered! skip!")
        continue

    # Confidence weights are only defined for AMASS. Fail with a clear error
    # before the expensive model setup rather than with a NameError later.
    if opt.joint_category != "AMASS":
        raise ValueError(
            f"Such category not settle down! Unsupported joint_category: {opt.joint_category!r}"
        )

    # load joints
    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file -- only run this script on trusted motion files.
    with open(path, 'rb') as f:
        data = pickle.load(f)
    joints = data['joints']
    # The leading axis of `joints` is used as the batch size everywhere below
    # (presumably one entry per frame -- TODO confirm against the producer).
    batch_size = joints.shape[0]

    # load predefined something
    device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu")
    print(config.SMPL_MODEL_DIR)
    smplxmodel = smplx.create(
        config.SMPL_MODEL_DIR,
        model_type="smpl",
        gender="neutral",
        ext="pkl",
        batch_size=batch_size,
    ).to(device)

    # load the mean pose as original
    # Read both arrays inside a context manager so the HDF5 handle is closed
    # instead of being leaked on every iteration.
    with h5py.File(config.SMPL_MEAN_FILE, "r") as mean_file:
        mean_pose = mean_file["pose"][:]
        mean_shape = mean_file["shape"][:]
    init_mean_pose = (
        torch.from_numpy(mean_pose)
        .unsqueeze(0).repeat(batch_size, 1)
        .float()
        .to(device)
    )
    init_mean_shape = (
        torch.from_numpy(mean_shape)
        .unsqueeze(0).repeat(batch_size, 1)
        .float()
        .to(device)
    )
    cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device)

    # initialize SMPLify
    smplify = SMPLify3D(
        smplxmodel=smplxmodel,
        batch_size=batch_size,
        joints_category=opt.joint_category,
        num_iters=opt.num_smplify_iters,
        device=device,
    )
    print("initialize SMPLify3D done!")

    print("Start SMPLify!")
    keypoints_3d = torch.Tensor(joints).to(device).float()

    # Per-joint confidence passed to the optimiser; uniform by default.
    confidence_input = torch.ones(opt.num_joints)
    # make sure the foot and ankle
    if fix_foot:
        for joint_idx in (7, 8, 10, 11):
            confidence_input[joint_idx] = 1.5

    # ----- from initial to fitting -------
    # Only the optimised pose and betas are needed afterwards; the other
    # returned values are discarded.
    _, _, new_opt_pose, new_opt_betas, _, _ = smplify(
        init_mean_pose.detach(),
        init_mean_shape.detach(),
        cam_trans_zero.detach(),
        keypoints_3d,
        conf_3d=confidence_input.to(device)
    )

    # fix shape
    # The fitted betas are replaced by zeros and joint 0 of each batch entry
    # is used as the translation when re-posing the model.
    betas = torch.zeros_like(new_opt_betas)
    root = keypoints_3d[:, 0, :]
    output = smplxmodel(
        betas=betas,
        global_orient=new_opt_pose[:, :3],
        body_pose=new_opt_pose[:, 3:],
        transl=root,
        return_verts=True
    )
    vertices = output.vertices.detach().cpu().numpy()

    # Shift coordinate 1 so that its minimum over all vertices becomes 0.
    floor_height = vertices[..., 1].min()
    vertices[..., 1] -= floor_height
    data['vertices'] = vertices

    with open(save_file, 'wb') as f:
        pickle.dump(data, f)
    print(f'vertices saved in {save_file}')