Skip to content

Commit

Permalink
fix transform dimension issue (#518)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenjia-xu authored Jan 9, 2025
1 parent aa970d0 commit 0a5e8a8
Showing 1 changed file with 8 additions and 32 deletions.
40 changes: 8 additions & 32 deletions genesis/utils/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,8 @@ def wxyz_to_xyzw(wxyz):
def transform_quat_by_quat(v, u):
if isinstance(v, torch.Tensor) and isinstance(u, torch.Tensor):
assert v.shape == u.shape, f"{v.shape} != {u.shape}"
shape = u.shape
u = u.reshape(-1, 4)
v = v.reshape(-1, 4)
w1, x1, y1, z1 = u[:, 0], u[:, 1], u[:, 2], u[:, 3]
w2, x2, y2, z2 = v[:, 0], v[:, 1], v[:, 2], v[:, 3]
w1, x1, y1, z1 = u[..., 0], u[..., 1], u[..., 2], u[..., 3]
w2, x2, y2, z2 = v[..., 0], v[..., 1], v[..., 2], v[..., 3]
ww = (z1 + x1) * (x2 + y2)
yy = (w1 - y1) * (w2 + z2)
zz = (w1 + y1) * (w2 - z2)
Expand All @@ -380,15 +377,12 @@ def transform_quat_by_quat(v, u):
x = qq - xx + (x1 + w1) * (x2 + w2)
y = qq - yy + (w1 - x1) * (y2 + z2)
z = qq - zz + (z1 + y1) * (w2 - x2)
quat = torch.stack([w, x, y, z], dim=-1).view(shape)
quat = torch.stack([w, x, y, z], dim=-1)
return quat
elif isinstance(v, np.ndarray) and isinstance(u, np.ndarray):
assert v.shape == u.shape, f"{v.shape} != {u.shape}"
shape = u.shape
u = u.reshape(-1, 4)
v = v.reshape(-1, 4)
w1, x1, y1, z1 = u[:, 0], u[:, 1], u[:, 2], u[:, 3]
w2, x2, y2, z2 = v[:, 0], v[:, 1], v[:, 2], v[:, 3]
w1, x1, y1, z1 = u[..., 0], u[..., 1], u[..., 2], u[..., 3]
w2, x2, y2, z2 = v[..., 0], v[..., 1], v[..., 2], v[..., 3]
# This method transforms quat_v by quat_u
# This is equivalent to quatmul(quat_u, quat_v) or R_u @ R_v
ww = (z1 + x1) * (x2 + y2)
Expand All @@ -400,7 +394,7 @@ def transform_quat_by_quat(v, u):
x = qq - xx + (x1 + w1) * (x2 + w2)
y = qq - yy + (w1 - x1) * (y2 + z2)
z = qq - zz + (z1 + y1) * (w2 - x2)
quat = np.stack([w, x, y, z], axis=-1).reshape(shape)
quat = np.stack([w, x, y, z], axis=-1)
return quat
else:
gs.raise_exception(f"both of the inputs must be torch.Tensor or np.ndarray. got: {type(v)=} and {type(u)=}")
Expand Down Expand Up @@ -639,33 +633,15 @@ def xyz_to_quat(euler_xyz):

def transform_by_quat(v, quat):
if isinstance(v, torch.Tensor) and isinstance(quat, torch.Tensor):
shape = v.shape
quat = quat.reshape(-1, 4)
v = v.reshape(-1, 3)
qvec = quat[:, 1:]
qvec = quat[..., 1:]
t = qvec.cross(v, dim=-1) * 2
return v + quat[:, :1] * t + qvec.cross(t, dim=-1)
return v + quat[..., :1] * t + qvec.cross(t, dim=-1)
elif isinstance(v, np.ndarray) and isinstance(quat, np.ndarray):
return transform_by_R(v, quat_to_R(quat))
else:
gs.raise_exception(f"both of the inputs must be torch.Tensor or np.ndarray. got: {type(v)=} and {type(quat)=}")


def transform_by_quat_yaw(v, quat):
if isinstance(v, torch.Tensor) and isinstance(quat, torch.Tensor):
quat_yaw = quat.clone().view(-1, 4)
quat_yaw[:, 1:3] = 0.0
quat_yaw = normalize(quat_yaw)
return transform_by_quat(v, quat_yaw)
elif isinstance(v, np.ndarray) and isinstance(quat, np.ndarray):
quat_yaw = quat.copy().reshape(-1, 4)
quat_yaw[:, 1:3] = 0.0
quat_yaw = normalize(quat_yaw)
return transform_by_quat(v, quat_yaw)
else:
gs.raise_exception(f"both of the inputs must be torch.Tensor or np.ndarray. got: {type(v)=} and {type(quat)=}")


def axis_angle_to_quat(angle, axis):
if isinstance(angle, torch.Tensor) and isinstance(axis, torch.Tensor):
theta = (angle / 2).unsqueeze(-1)
Expand Down

0 comments on commit 0a5e8a8

Please sign in to comment.