Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JAX Jacobian of from_axis_angle producing NaNs #339

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

ConnorTingley
Copy link

@ConnorTingley ConnorTingley commented Jan 15, 2025

Taking the JAX Jacobian of the from_axis_angle function was producing NaNs when vector = [0,0,0]

I found this: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
which lays out how to fix this issue by making the internal function safe when the input is 0.

I just added an extra jnp.where() to check before dividing.


📚 Documentation preview 📚: https://jaxsim--339.org.readthedocs.build//339/

@xela-95
Copy link
Member

xela-95 commented Jan 15, 2025

Thanks a lot for you interest in Jaxsim and you contribution @ConnorTingley!! We will review soon you PR.

Looking a bit at past history of this function, @diegoferigo updated this function in #277 to address the same AD issue, then we updated it more recently in #319, maybe introducing some sort of regression. What do you think @flferretti ?

@CarlottaSartore
Copy link
Contributor

HI @ConnorTingley!, Thanks so much for opening this PR and contributing to JAXSim! We really appreciate you taking the time to dive into this issue and propose a fix.

Would you mind sharing a minimal working example that reproduces the error you're seeing?

When we call from_axis_angle with vector = [0,0,0], it seems to correctly return the identity matrix, as far as we understood your issue is more related to the usage of such a function when computing gradients. Having an example of your error would help us better understand the context and ensure your fix works smoothly for all use cases.

@CarlottaSartore CarlottaSartore self-requested a review January 15, 2025 09:28
@flferretti
Copy link
Collaborator

Hi @ConnorTingley, thanks for reporting the issue! I tried to make an example for which the issue is verified:

import jax
from jaxsim.math import Rotation

vector = jax.numpy.zeros(3)

jac = jax.jacobian(Rotation.from_axis_angle)(vector)

print(jac)

>>> Array([[[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]],

       [[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]],

       [[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]]], dtype=float64)

You're right, a matrix of NaNs is produced.

We will work on that to see if the issue is propagated to other parts of the code and will get back to you with a potential low-level solution.

Thanks again for working on this! We'll get back to you ASAP

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants