diff --git a/src/jaxsim/math/rotation.py b/src/jaxsim/math/rotation.py index 58d730ee7..bcbe98a5f 100644 --- a/src/jaxsim/math/rotation.py +++ b/src/jaxsim/math/rotation.py @@ -68,26 +68,17 @@ def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix: vector = vector.squeeze() - def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix: + theta = safe_norm(vector) - v = axis - theta = safe_norm(v) + s = jnp.sin(theta) + c = jnp.cos(theta) - s = jnp.sin(theta) - c = jnp.cos(theta) + c1 = 2 * jnp.sin(theta / 2.0) ** 2 - c1 = 2 * jnp.sin(theta / 2.0) ** 2 + safe_theta = jnp.where(theta == 0, 1.0, theta) + u = vector / safe_theta + u = jnp.vstack(u.squeeze()) - u = v / theta - u = jnp.vstack(u.squeeze()) + R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T - R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T - - return R.transpose() - - return jnp.where( - jnp.allclose(vector, 0.0), - # Return an identity rotation matrix when the input vector is zero. - jnp.eye(3), - theta_is_not_zero(axis=vector), - ) + return R.transpose()