Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First steps towards automatic differentiation of RBDAs #54

Merged
merged 3 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/jaxsim/math/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def derivative(
omega_in_body_fixed: bool = False,
K: float = 0.1,
) -> jtp.Vector:
w = omega.squeeze()
ω = omega.squeeze()
quaternion = quaternion.squeeze()

def Q_body(q: jtp.Vector) -> jtp.Matrix:
Expand Down Expand Up @@ -67,10 +67,20 @@ def Q_inertial(q: jtp.Vector) -> jtp.Matrix:
operand=quaternion,
)

norm_ω = jax.lax.cond(
pred=ω.dot(ω) < (1e-6) ** 2,
true_fun=lambda _: 1e-6,
false_fun=lambda _: jnp.linalg.norm(ω),
operand=None,
)

qd = 0.5 * (
Q
@ jnp.hstack(
[K * jnp.linalg.norm(w) * (1 - jnp.linalg.norm(quaternion)), w]
[
K * norm_ω * (1 - jnp.linalg.norm(quaternion)),
ω,
]
)
)

Expand Down
32 changes: 24 additions & 8 deletions src/jaxsim/physics/algos/soft_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ def contact_model(
m = tangential_deformation.squeeze()
= jnp.zeros_like(m)

# Note: all the small hardcoded tolerances in this method have been introduced
# to allow jax differentiating through this algorithm. They should not affect
# the accuracy of the simulation, although they might make it less readable.

# ========================
# Normal force computation
# ========================
Expand All @@ -249,7 +253,11 @@ def contact_model(

# Non-linear spring-damper model.
# This is the force magnitude along the direction normal to the terrain.
force_normal_mag = jnp.sqrt(δ) * (K * δ + D * δ̇)
force_normal_mag = jax.lax.select(
pred=δ >= 1e-9,
on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇),
on_false=jnp.array(0.0),
)

# Prevent negative normal forces that might occur when δ̇ is largely negative
force_normal_mag = jnp.maximum(0.0, force_normal_mag)
Expand Down Expand Up @@ -304,7 +312,7 @@ def below_terrain():
v_tangential = W_ṗ_C - v_normal

# Compute the tangential force. If inside the friction cone, the contact
f_tangential = -jnp.sqrt(δ) * (K * m + D * v_tangential)
f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)

def sticking_contact():
# Sum the normal and tangential forces, and create the 6D force
Expand All @@ -319,9 +327,17 @@ def sticking_contact():
return CW_f,

def slipping_contact():
# Clip the tangential force if too small, allowing jax to
# differentiate through the norm computation
f_tangential_no_nan = jax.lax.select(
pred=f_tangential.dot(f_tangential) >= 1e-9**2,
on_true=f_tangential,
on_false=jnp.array([1e-12, 0, 0]),
)

# Project the force to the friction cone boundary
f_tangential_projected = (μ * force_normal_mag) * (
f_tangential / jnp.linalg.norm(f_tangential)
f_tangential / jnp.linalg.norm(f_tangential_no_nan)
)

# Sum the normal and tangential forces, and create the 6D force
Expand All @@ -331,18 +347,18 @@ def slipping_contact():
# Correct the material deformation derivative for slipping contacts.
# Basically we compute ṁ such that we get `f_tangential` on the cone
# given the current (m, δ).
ε = 1e-6
α = -K * jnp.sqrt(δ)
ε = 1e-9
δε = jnp.maximum(δ, ε)
βε = -D * jnp.sqrt(δε)
= (f_tangential_projected - α * m) / βε
α = -K * jnp.sqrt(δε)
β = -D * jnp.sqrt(δε)
= (f_tangential_projected - α * m) / β

# Return the 6D force in the contact frame and
# the deformation derivative
return CW_f,

CW_f, = jax.lax.cond(
pred=jnp.linalg.norm(f_tangential) > μ * force_normal_mag,
pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2,
true_fun=lambda _: slipping_contact(),
false_fun=lambda _: sticking_contact(),
operand=None,
Expand Down
190 changes: 190 additions & 0 deletions tests/test_ad_physics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import jax.numpy as jnp
import numpy as np
import pytest
from jax.test_util import check_grads
from pytest import param as p

from jaxsim.high_level.common import VelRepr
from jaxsim.high_level.model import Model

from . import utils_models, utils_rng
from .utils_models import Robot


@pytest.mark.parametrize(
"robot, vel_repr",
[
p(*[Robot.Ur10, VelRepr.Inertial], id="Ur10-Inertial"),
p(*[Robot.AnymalC, VelRepr.Inertial], id="AnymalC-Inertial"),
p(*[Robot.Cassie, VelRepr.Inertial], id="Cassie-Inertial"),
],
)
def test_ad_physics(robot: utils_models.Robot, vel_repr: VelRepr) -> None:
"""Unit test of the application of Automatic Differentiation on RBD algorithms."""

robot = Robot.Ur10
vel_repr = VelRepr.Inertial

# Initialize the gravity
gravity = np.array([0, 0, -10.0])

# Get the URDF of the robot
urdf_file_path = utils_models.ModelFactory.get_model_description(robot=robot)

# Build the high-level model
model = Model.build_from_model_description(
model_description=urdf_file_path,
vel_repr=vel_repr,
gravity=gravity,
is_urdf=True,
).mutable(mutable=True, validate=True)

# Initialize the model with a random state
model.data.model_state = utils_rng.random_physics_model_state(
physics_model=model.physics_model
)

# Initialize the model with a random input
model.data.model_input = utils_rng.random_physics_model_input(
physics_model=model.physics_model
)

# ========================
# Extract state and inputs
# ========================

# Extract the physics model used in the low-level physics algorithms
physics_model = model.physics_model

# State
s = model.joint_positions()
= model.joint_velocities()
xfb = model.data.model_state.xfb()

# Inputs
f_ext = model.external_forces()
tau = model.joint_generalized_forces_targets()

# Perturbation used for computing finite differences
ε = jnp.finfo(jnp.array(0.0)).resolution ** (1 / 3)

# =====================================================
# Check first-order and second-order derivatives of ABA
# =====================================================

import jaxsim.physics.algos.aba

aba = lambda xfb, s, , tau, f_ext: jaxsim.physics.algos.aba.aba(
model=physics_model, xfb=xfb, q=s, qd=, tau=tau, f_ext=f_ext
)

check_grads(
f=aba,
args=(xfb, s, , tau, f_ext),
order=2,
modes=["rev", "fwd"],
eps=ε,
)

# ======================================================
# Check first-order and second-order derivatives of RNEA
# ======================================================

import jaxsim.physics.algos.rnea

W_v̇_WB = utils_rng.get_rng().uniform(size=6, low=-1)
= utils_rng.get_rng().uniform(size=physics_model.dofs(), low=-1)

rnea = lambda xfb, s, , , W_v̇_WB, f_ext: jaxsim.physics.algos.rnea.rnea(
model=physics_model, xfb=xfb, q=s, qd=, qdd=, a0fb=W_v̇_WB, f_ext=f_ext
)

check_grads(
f=rnea,
args=(xfb, s, , , W_v̇_WB, f_ext),
order=2,
modes=["rev", "fwd"],
eps=ε,
)

# ======================================================
# Check first-order and second-order derivatives of CRBA
# ======================================================

import jaxsim.physics.algos.crba

crba = lambda s: jaxsim.physics.algos.crba.crba(model=physics_model, q=s)

check_grads(
f=crba,
args=(s,),
order=2,
modes=["rev", "fwd"],
eps=ε,
)

# ====================================================
# Check first-order and second-order derivatives of FK
# ====================================================

import jaxsim.physics.algos.forward_kinematics

fk = (
lambda xfb, s: jaxsim.physics.algos.forward_kinematics.forward_kinematics_model(
model=physics_model, xfb=xfb, q=s
)
)

check_grads(
f=fk,
args=(xfb, s),
order=2,
modes=["rev", "fwd"],
eps=ε,
)

# ==========================================================
# Check first-order and second-order derivatives of Jacobian
# ==========================================================

import jaxsim.physics.algos.jacobian

link_indices = [l.index() for l in model.links()]

jacobian = lambda s: jaxsim.physics.algos.jacobian.jacobian(
model=physics_model, q=s, body_index=link_indices[-1]
)

check_grads(
f=jacobian,
args=(s,),
order=2,
modes=["rev", "fwd"],
eps=ε,
)

# =====================================================================
# Check first-order and second-order derivatives of soft contacts model
# =====================================================================

import jaxsim.physics.algos.soft_contacts

p = utils_rng.get_rng().uniform(size=3, low=-1)
v = utils_rng.get_rng().uniform(size=3, low=-1)
m = utils_rng.get_rng().uniform(size=3, low=-1)

parameters = jaxsim.physics.algos.soft_contacts.SoftContactsParams.build(
K=10_000, D=20.0, mu=0.5
)

soft_contacts = lambda p, v, m: jaxsim.physics.algos.soft_contacts.SoftContacts(
parameters=parameters
).contact_model(position=p, velocity=v, tangential_deformation=m)

check_grads(
f=soft_contacts,
args=(p, v, m),
order=2,
modes=["rev", "fwd"],
eps=ε,
)
Loading