Skip to content

Commit

Permalink
Allow differentiating through the SoftContacts algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Oct 10, 2023
1 parent 1b26141 commit f4db5dd
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions src/jaxsim/physics/algos/soft_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ def contact_model(
m = tangential_deformation.squeeze()
= jnp.zeros_like(m)

# Note: all the small hardcoded tolerances in this method have been introduced
# to allow jax differentiating through this algorithm. They should not affect
# the accuracy of the simulation, although they might make it less readable.

# ========================
# Normal force computation
# ========================
Expand All @@ -249,7 +253,11 @@ def contact_model(

# Non-linear spring-damper model.
# This is the force magnitude along the direction normal to the terrain.
force_normal_mag = jnp.sqrt(δ) * (K * δ + D * δ̇)
force_normal_mag = jax.lax.select(
pred=δ >= 1e-9,
on_true=jnp.sqrt(δ + 1e-12) * (K * δ + D * δ̇),
on_false=jnp.array(0.0),
)

# Prevent negative normal forces that might occur when δ̇ is largely negative
force_normal_mag = jnp.maximum(0.0, force_normal_mag)
Expand Down Expand Up @@ -304,7 +312,7 @@ def below_terrain():
v_tangential = W_ṗ_C - v_normal

# Compute the tangential force. If inside the friction cone, the contact
f_tangential = -jnp.sqrt(δ) * (K * m + D * v_tangential)
f_tangential = -jnp.sqrt(δ + 1e-12) * (K * m + D * v_tangential)

def sticking_contact():
# Sum the normal and tangential forces, and create the 6D force
Expand All @@ -319,9 +327,17 @@ def sticking_contact():
return CW_f,

def slipping_contact():
# Clip the tangential force if too small, allowing jax to
# differentiate through the norm computation
f_tangential_no_nan = jax.lax.select(
pred=f_tangential.dot(f_tangential) >= 1e-9**2,
on_true=f_tangential,
on_false=jnp.array([1e-12, 0, 0]),
)

# Project the force to the friction cone boundary
f_tangential_projected = (μ * force_normal_mag) * (
f_tangential / jnp.linalg.norm(f_tangential)
f_tangential / jnp.linalg.norm(f_tangential_no_nan)
)

# Sum the normal and tangential forces, and create the 6D force
Expand All @@ -331,18 +347,18 @@ def slipping_contact():
# Correct the material deformation derivative for slipping contacts.
# Basically we compute ṁ such that we get `f_tangential` on the cone
# given the current (m, δ).
ε = 1e-6
α = -K * jnp.sqrt(δ)
ε = 1e-9
δε = jnp.maximum(δ, ε)
βε = -D * jnp.sqrt(δε)
= (f_tangential_projected - α * m) / βε
α = -K * jnp.sqrt(δε)
β = -D * jnp.sqrt(δε)
= (f_tangential_projected - α * m) / β

# Return the 6D force in the contact frame and
# the deformation derivative
return CW_f,

CW_f, = jax.lax.cond(
pred=jnp.linalg.norm(f_tangential) > μ * force_normal_mag,
pred=f_tangential.dot(f_tangential) > (μ * force_normal_mag) ** 2,
true_fun=lambda _: slipping_contact(),
false_fun=lambda _: sticking_contact(),
operand=None,
Expand Down

0 comments on commit f4db5dd

Please sign in to comment.