diff --git a/genesis/utils/geom.py b/genesis/utils/geom.py index acfd6518..aeddc2fe 100644 --- a/genesis/utils/geom.py +++ b/genesis/utils/geom.py @@ -366,11 +366,8 @@ def wxyz_to_xyzw(wxyz): def transform_quat_by_quat(v, u): if isinstance(v, torch.Tensor) and isinstance(u, torch.Tensor): assert v.shape == u.shape, f"{v.shape} != {u.shape}" - shape = u.shape - u = u.reshape(-1, 4) - v = v.reshape(-1, 4) - w1, x1, y1, z1 = u[:, 0], u[:, 1], u[:, 2], u[:, 3] - w2, x2, y2, z2 = v[:, 0], v[:, 1], v[:, 2], v[:, 3] + w1, x1, y1, z1 = u[..., 0], u[..., 1], u[..., 2], u[..., 3] + w2, x2, y2, z2 = v[..., 0], v[..., 1], v[..., 2], v[..., 3] ww = (z1 + x1) * (x2 + y2) yy = (w1 - y1) * (w2 + z2) zz = (w1 + y1) * (w2 - z2) @@ -380,15 +377,12 @@ def transform_quat_by_quat(v, u): x = qq - xx + (x1 + w1) * (x2 + w2) y = qq - yy + (w1 - x1) * (y2 + z2) z = qq - zz + (z1 + y1) * (w2 - x2) - quat = torch.stack([w, x, y, z], dim=-1).view(shape) + quat = torch.stack([w, x, y, z], dim=-1) return quat elif isinstance(v, np.ndarray) and isinstance(u, np.ndarray): assert v.shape == u.shape, f"{v.shape} != {u.shape}" - shape = u.shape - u = u.reshape(-1, 4) - v = v.reshape(-1, 4) - w1, x1, y1, z1 = u[:, 0], u[:, 1], u[:, 2], u[:, 3] - w2, x2, y2, z2 = v[:, 0], v[:, 1], v[:, 2], v[:, 3] + w1, x1, y1, z1 = u[..., 0], u[..., 1], u[..., 2], u[..., 3] + w2, x2, y2, z2 = v[..., 0], v[..., 1], v[..., 2], v[..., 3] # This method transforms quat_v by quat_u # This is equivalent to quatmul(quat_u, quat_v) or R_u @ R_v ww = (z1 + x1) * (x2 + y2) @@ -400,7 +394,7 @@ def transform_quat_by_quat(v, u): x = qq - xx + (x1 + w1) * (x2 + w2) y = qq - yy + (w1 - x1) * (y2 + z2) z = qq - zz + (z1 + y1) * (w2 - x2) - quat = np.stack([w, x, y, z], axis=-1).reshape(shape) + quat = np.stack([w, x, y, z], axis=-1) return quat else: gs.raise_exception(f"both of the inputs must be torch.Tensor or np.ndarray. got: {type(v)=} and {type(u)=}") @@ -639,33 +633,15 @@ def xyz_to_quat(euler_xyz): def transform_by_quat(v, quat): if isinstance(v, torch.Tensor) and isinstance(quat, torch.Tensor): - shape = v.shape - quat = quat.reshape(-1, 4) - v = v.reshape(-1, 3) - qvec = quat[:, 1:] + qvec = quat[..., 1:] t = qvec.cross(v, dim=-1) * 2 - return v + quat[:, :1] * t + qvec.cross(t, dim=-1) + return v + quat[..., :1] * t + qvec.cross(t, dim=-1) elif isinstance(v, np.ndarray) and isinstance(quat, np.ndarray): return transform_by_R(v, quat_to_R(quat)) else: gs.raise_exception(f"both of the inputs must be torch.Tensor or np.ndarray. got: {type(v)=} and {type(quat)=}") -def transform_by_quat_yaw(v, quat): - if isinstance(v, torch.Tensor) and isinstance(quat, torch.Tensor): - quat_yaw = quat.clone().view(-1, 4) - quat_yaw[:, 1:3] = 0.0 - quat_yaw = normalize(quat_yaw) - return transform_by_quat(v, quat_yaw) - elif isinstance(v, np.ndarray) and isinstance(quat, np.ndarray): - quat_yaw = quat.copy().reshape(-1, 4) - quat_yaw[:, 1:3] = 0.0 - quat_yaw = normalize(quat_yaw) - return transform_by_quat(v, quat_yaw) - else: - gs.raise_exception(f"both of the inputs must be torch.Tensor or np.ndarray. got: {type(v)=} and {type(quat)=}") - - def axis_angle_to_quat(angle, axis): if isinstance(angle, torch.Tensor) and isinstance(axis, torch.Tensor): theta = (angle / 2).unsqueeze(-1)