-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdatasets.py
130 lines (111 loc) · 5.17 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from torch.utils.data import DataLoader, Dataset
import torch
import articulate
from config import paths, joint_set
import config
from utils import normalize_and_concat
import os
# class OwnDatasets(Dataset):
# def __init__(self, filepath, use_joint=[0, 1, 2, 3, 4, 5]):
# super(OwnDatasets, self).__init__()
# data = torch.load(filepath)
# self.use_joint = use_joint
# self.pose = data['pose']
# self.tran = data['tran']
# self.ori = data['ori']
# self.acc = data['acc']
# self.point = data['jp']
# def __getitem__(self, idx):
# nn_pose = self.pose[idx].float()
# if self.tran[idx] is not None:
# tran = self.tran[idx].float()
# ori = self.ori[idx][:, self.use_joint].float()
# acc = self.acc[idx][:, self.use_joint].float()
# joint_pos = self.point[idx].float()
# root_ori = ori[:, -1] # 最后一组为胯部
# imu = normalize_and_concat(acc, ori)
# # 世界速度->本地速度
# if self.tran[idx] is not None:
# velocity = tran
# velocity_local = root_ori.transpose(1, 2).bmm(
# torch.cat((torch.zeros(1, 3), velocity[1:] - velocity[:-1])).unsqueeze(-1)).squeeze(-1) * 60 / config.vel_scale
# else:
# velocity_local = torch.zeros((len(imu), 3))
# # 支撑腿
# stable_threshold = 0.008
# diff = joint_pos - torch.cat((joint_pos[:1], joint_pos[:-1]))
# stable = (diff[:, [7, 8]].norm(dim=2) < stable_threshold).float()
# # 关节位置
# nn_jtr = joint_pos - joint_pos[:, :1]
# leaf_jtr = nn_jtr[:, joint_set.leaf]
# full_jtr = nn_jtr[:, joint_set.full]
# return imu, nn_pose.flatten(1),leaf_jtr.flatten(1), full_jtr.flatten(1), stable, velocity_local, root_ori
# def __len__(self):
# return len(self.ori)
class OwnDatasets(Dataset):
    """Per-sequence dataset of IMU readings and pose targets loaded from one file.

    ``filepath`` is loaded with ``torch.load`` and must provide the keys
    ``'pose'``, ``'tran'``, ``'ori'``, ``'acc'`` and ``'jp'``. Each of them is
    indexed per sequence (``data['pose'][idx]`` etc.). Entries of ``'tran'``
    may be ``None``; for those sequences the root-velocity target is all zeros.

    Each item is the tuple
    ``(imu, nn_pose, leaf_jtr, full_jtr, stable, velocity_local, root_ori)``.
    """

    # IMU slots used when the caller does not pass ``use_joint``. The last
    # selected slot is treated as the hip/root sensor in ``__getitem__``.
    _DEFAULT_USE_JOINT = (0, 1, 2, 3, 4, 5)

    def __init__(self, filepath, use_joint=None, isMatrix=True, no_norm=False, onlyori=False):
        """Load the data file and build the SMPL body model used for FK/IK.

        Args:
            filepath: Path to a file readable by ``torch.load``.
            use_joint: Indices of the IMU slots to read from ``'ori'``/``'acc'``.
                ``None`` (the default) selects ``[0, 1, 2, 3, 4, 5]``. The last
                index is used as the root (hip) sensor.
            isMatrix, no_norm, onlyori: Forwarded unchanged to
                ``normalize_and_concat``.
        """
        super().__init__()
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # trusted files.
        data = torch.load(filepath)
        # A fresh list per instance avoids sharing one mutable default object.
        self.use_joint = list(self._DEFAULT_USE_JOINT) if use_joint is None else use_joint
        self.pose = data['pose']
        self.tran = data['tran']
        self.ori = data['ori']
        self.acc = data['acc']
        self.point = data['jp']
        self.isMatrix = isMatrix
        self.no_norm = no_norm
        self.onlyori = onlyori
        self.m = articulate.ParametricModel(paths.male_smpl_file)
        self.global_to_local_pose = self.m.inverse_kinematics_R

    def __getitem__(self, idx):
        """Build the network input and targets for sequence ``idx``.

        Returns:
            imu: ``normalize_and_concat`` applied to the selected acc/ori.
            nn_pose: ``self.pose[idx]`` as float, flattened from dim 1.
            leaf_jtr, full_jtr: root-relative positions of the
                ``joint_set.leaf`` / ``joint_set.full`` joints obtained by
                forward kinematics of the reconstructed full pose, flattened
                from dim 1.
            stable: per-frame 0/1 support flags for joints 7 and 8 of ``'jp'``.
            velocity_local: per-frame root velocity expressed in the root-sensor
                frame; zeros when the sequence has no translation.
            root_ori: orientation of the last selected IMU slot per frame.
        """
        nn_pose = self.pose[idx].float()
        ori = self.ori[idx][:, self.use_joint].float()
        acc = self.acc[idx][:, self.use_joint].float()
        joint_pos = self.point[idx].float()
        root_ori = ori[:, -1]  # the last selected group is the hip (root)
        imu = normalize_and_concat(acc, ori, len(self.use_joint), self.isMatrix, self.no_norm,
                                   onlyori=self.onlyori)

        # World velocity -> local (root-frame) velocity.
        tran = self.tran[idx]
        if tran is not None:
            velocity_local = self._root_local_velocity(root_ori, tran.float())
        else:
            velocity_local = torch.zeros((len(nn_pose), 3))

        # Support-leg flags.
        stable = self._stable_flags(joint_pos)

        # Joint positions from forward kinematics of the reconstructed pose.
        full_pose = self._reduced_glb_6d_to_full_local_mat(root_ori, nn_pose)
        _, joint_global = self.m.forward_kinematics(full_pose)
        nn_jtr = joint_global - joint_global[:, :1]  # make positions root-relative
        leaf_jtr = nn_jtr[:, joint_set.leaf]
        full_jtr = nn_jtr[:, joint_set.full]
        return (imu, nn_pose.flatten(1), leaf_jtr.flatten(1), full_jtr.flatten(1),
                stable, velocity_local, root_ori)

    def __len__(self):
        """Return the number of sequences (length of the ``'ori'`` entry)."""
        return len(self.ori)

    @staticmethod
    def _root_local_velocity(root_ori, tran):
        """Rotate per-frame translation deltas into the root frame and scale them."""
        # Frame 0 has no predecessor, so its delta is zero.
        delta = torch.cat((torch.zeros(1, 3), tran[1:] - tran[:-1]))
        # root_ori^T maps world vectors into the root frame (assumes root_ori is
        # (T, 3, 3) rotation matrices -- TODO confirm for isMatrix=False data).
        # 60 is presumably the frame rate (delta -> per-second) -- confirm.
        return root_ori.transpose(1, 2).bmm(delta.unsqueeze(-1)).squeeze(-1) * 60 / config.vel_scale

    @staticmethod
    def _stable_flags(joint_pos, stable_threshold=0.008):
        """Return 1.0 where joints 7 and 8 moved less than ``stable_threshold``.

        The displacement is measured against the previous frame; frame 0 is
        compared with itself and is therefore always flagged stable.
        """
        diff = joint_pos - torch.cat((joint_pos[:1], joint_pos[:-1]))
        # NOTE(review): 7/8 are presumably the SMPL left/right ankles -- confirm.
        return (diff[:, [7, 8]].norm(dim=2) < stable_threshold).float()

    def _reduced_glb_6d_to_full_local_mat(self, root_rotation, glb_reduced_pose):
        """Convert a reduced global 6D pose into a full 24-joint local rotation pose.

        Args:
            root_rotation: Root orientation, reshaped to ``(-1, 3, 3)`` and
                written into joint 0 of the result.
            glb_reduced_pose: 6D global rotations of the ``joint_set.reduced``
                joints. They are converted with ``r6d_to_rotation_matrix`` and
                viewed as ``(-1, joint_set.n_reduced, 3, 3)``.

        Returns:
            ``(N, 24, 3, 3)`` local rotation matrices. Joints in
            ``joint_set.ignored`` are identity, and joint 0 is ``root_rotation``.
        """
        glb_reduced_pose = articulate.math.r6d_to_rotation_matrix(glb_reduced_pose).view(-1, joint_set.n_reduced, 3, 3)
        # Joints outside the reduced set start with identity global rotation.
        global_full_pose = torch.eye(3, device=glb_reduced_pose.device).repeat(glb_reduced_pose.shape[0], 24, 1, 1)
        global_full_pose[:, joint_set.reduced] = glb_reduced_pose
        # Global -> local rotations via the SMPL model's inverse kinematics.
        pose = self.global_to_local_pose(global_full_pose).view(-1, 24, 3, 3)
        pose[:, joint_set.ignored] = torch.eye(3, device=pose.device)
        pose[:, 0] = root_rotation.view(-1, 3, 3)
        return pose
if __name__ == "__main__":
    # Smoke test: load "veri.pt" and print the shape of every tensor returned
    # for the first sequence.
    dataset = OwnDatasets(os.path.join(paths.dipimu_dir, "veri.pt"))
    if len(dataset) == 0:
        print("dataset is empty")
    else:
        # __getitem__ always returns a tensor for velocity_local (zeros when the
        # sequence has no translation), so no None check is needed.
        sample = dataset[0]
        for tensor in sample:
            print(tensor.shape)