diff --git a/src/jaxsim/high_level/model.py b/src/jaxsim/high_level/model.py index 7c743125d..ef1db1651 100644 --- a/src/jaxsim/high_level/model.py +++ b/src/jaxsim/high_level/model.py @@ -960,22 +960,27 @@ def coriolis_matrix(self) -> jtp.Matrix: xfb=self.data.model_state.xfb(), ) - if self.velocity_representation is VelRepr.Inertial: - return C - - elif self.velocity_representation is VelRepr.Body: + def body(C): W_H_B = self.base_transform() B_X_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint() - return B_X_W.T @ C @ B_X_W + C = B_X_W.T @ C @ B_X_W - elif self.velocity_representation is VelRepr.Mixed: + def mixed(C): W_H_B = self.base_transform() W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) BW_X_W = sixd.se3.SE3.from_matrix(W_H_BW).inverse().adjoint() - return BW_X_W.T @ C @ BW_X_W + C = BW_X_W.T @ C @ BW_X_W - else: - raise ValueError(self.velocity_representation) + to_active = { + VelRepr.Body: body, + VelRepr.Mixed: mixed, + VelRepr.Inertial: lambda x: x, + } + + try: + C = to_active[self.velocity_representation](C) + except ValueError as e: + raise e return M, M_dot, C