Mooncake Backend doesn't handle functions with StaticArrays output #642

Open
jmurphy6895 opened this issue Nov 25, 2024 · 4 comments · Fixed by #643
Labels
backend Related to one or more autodiff backends

Comments

@jmurphy6895

If a function returns a static array or static vector (e.g. an SVector from StaticArrays), the AutoMooncake backend errors. Minimal example:

using StaticArrays
using DifferentiationInterface
using ForwardDiff
using Mooncake

function MWE(x::AbstractVector)

    z = SVector{3}(x.^2)

    return z

end

test = rand(3)

f_ad, df_ad = value_and_jacobian(
    MWE,
    AutoMooncake(; config=nothing),
    test
)

f_ad, df_ad = value_and_jacobian(
    MWE,
    AutoForwardDiff(),
    test
)

The AutoMooncake call gives the error:

ERROR: MethodError: no method matching copyto!(::Mooncake.Tangent{@NamedTuple{data::Tuple{Float64, Float64, Float64}}}, ::SVector{3, Float64})
The function copyto! exists, but no method is defined for this combination of argument types.

Closest candidates are:
  copyto!(::IndexStyle, ::AbstractArray, ::IndexStyle, ::AbstractArray)
   @ Base abstractarray.jl:1064
  copyto!(::Zygote.Buffer, ::Any)
   @ Zygote C:\Users\jmurp.julia\packages\Zygote\nyzjS\src\tools\buffer.jl:54
  copyto!(::PermutedDimsArray, ::AbstractArray)
   @ Base permuteddimsarray.jl:295
  ...

Stacktrace:
 [1] copyto!!(dst::Mooncake.Tangent{@NamedTuple{data::Tuple{Float64, Float64, Float64}}}, src::SVector{3, Float64})
   @ DifferentiationInterfaceMooncakeExt C:\Users\jmurp.julia\packages\DifferentiationInterface\gSdHF\ext\DifferentiationInterfaceMooncakeExt\DifferentiationInterfaceMooncakeExt.jl:29
 [2] value_and_pullback(::Function, ::DifferentiationInterfaceMooncakeExt.MooncakeOneArgPullbackPrep{…}, ::AutoMooncake{…}, ::Vector{…}, ::Tuple{…})
   @ DifferentiationInterfaceMooncakeExt C:\Users\jmurp.julia\packages\DifferentiationInterface\gSdHF\ext\DifferentiationInterfaceMooncakeExt\onearg.jl:35
 [3] prepare_pullback(::Function, ::AutoMooncake{Nothing}, ::Vector{Float64}, ::Tuple{SVector{3, Float64}})
   @ DifferentiationInterfaceMooncakeExt C:\Users\jmurp.julia\packages\DifferentiationInterface\gSdHF\ext\DifferentiationInterfaceMooncakeExt\onearg.jl:22
 [4] _prepare_jacobian_aux(::DifferentiationInterface.PushforwardSlow, ::DifferentiationInterface.BatchSizeSettings{…}, ::SVector{…}, ::Tuple{…}, ::AutoMooncake{…}, ::Vector{…})
   @ DifferentiationInterface C:\Users\jmurp.julia\packages\DifferentiationInterface\gSdHF\src\first_order\jacobian.jl:167
 [5] prepare_jacobian(::typeof(MWE), ::AutoMooncake{Nothing}, ::Vector{Float64})
   @ DifferentiationInterface C:\Users\jmurp.julia\packages\DifferentiationInterface\gSdHF\src\first_order\jacobian.jl:108
 [6] value_and_jacobian(::typeof(MWE), ::AutoMooncake{Nothing}, ::Vector{Float64})
   @ DifferentiationInterface C:\Users\jmurp.julia\packages\DifferentiationInterface\gSdHF\src\fallbacks\no_prep.jl:60
 [7] top-level scope
   @ c:\Users\jmurp.julia\dev\SatelliteToolboxGravityModels\test\differentiability.jl:69
Some type information was truncated. Use show(err) to see complete types.

gdalle added the bug and backend labels on Nov 25, 2024
@gdalle
Member

gdalle commented Nov 25, 2024

Thanks for reporting this! Can you try it out with the branch from #643?

@jmurphy6895
Author

Thanks for the quick reply! It looks like it's still giving the same error on my end.

@gdalle
Member

gdalle commented Nov 26, 2024

Oh right, I had misread the error. Modulo my hotfix, I now think this happens because DI expects an array but Mooncake wraps it into a Tangent. @willtebbutt any idea how we should handle this?

@willtebbutt
Member

willtebbutt commented Nov 26, 2024

So this looks to me like it's happening on the line

dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)

when we prepare the DI-friendly gradient to be passed into the reverse pass of Mooncake.
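For context: names ending in !! conventionally (as in BangBang.jl) mean "mutate the destination in place if you can, otherwise return a new value". The stacktrace above shows the conversion ending in a plain copyto!, which cannot work when the destination is an immutable Tangent. The following is a minimal Base-only sketch of the intended behaviour; Vec3, copyto_sketch!! and to_righttype are hypothetical names for illustration, not DI or Mooncake APIs.

```julia
# Hypothetical immutable type standing in for a Tangent / SVector{3,Float64}.
struct Vec3
    data::NTuple{3,Float64}
end

# Mutable destination: copy into it in place and return it.
copyto_sketch!!(dst::Vector{Float64}, src::AbstractVector) = copyto!(dst, src)

# Immutable destination: copyto! is impossible, so build and return a new value.
copyto_sketch!!(::Vec3, src::AbstractVector) = Vec3(Tuple(src))

# Same shape as the DI line: pass dy through if it already has the expected
# type, otherwise convert it with the help of a preallocated destination.
to_righttype(dy, dst) = dy isa typeof(dst) ? dy : copyto_sketch!!(dst, dy)

buf = zeros(3)
a = to_righttype([1.0, 2.0, 3.0], buf)                    # already a Vector{Float64}: returned as is
b = to_righttype(1:3, buf)                                # converted in place, so b === buf
c = to_righttype([4.0, 5.0, 6.0], Vec3((0.0, 0.0, 0.0)))  # immutable target: a new Vec3 is returned
```

The second method is the piece that is missing for the Tangent/SVector pair in this issue: there is nothing to mutate, so the only option is to construct a fresh value.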

If I'm not mistaken, we just need to add a method of copyto!! which knows how to translate a static array into a Tangent. There's enough information in the call to do this. It would just be something like

function copyto!!(dst::T, src::SVector{3, Float64}) where {T<:Mooncake.Tangent{@NamedTuple{data::Tuple{Float64, Float64, Float64}}}}
    # Rebuild the Tangent out of place from the SVector's internal data tuple.
    return T((data = getfield(src, :data), ))
end

in this case.

We'll need translation rules like this for most types so, thinking forwards to more general types, we'll never have a complete solution to this translation problem unless DI places some restrictions on the set of types that users are permitted to work with. My understanding is that you're not keen to restrict users in this way, so probably the best thing to do is to define a catch-all method of copyto!! which throws an informative error message saying something like "we didn't expect you to pass this type, so we don't have a conversion rule for it. Please open an issue."
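A minimal sketch of this two-tier approach: one specific rule for the static-vector case plus a catch-all method that throws an informative error. FakeTangent and Vec3 are hypothetical stand-ins for Mooncake.Tangent and SVector{3, Float64}, and convert_cotangent is an illustrative name, not DI's actual copyto!!.

```julia
# Hypothetical stand-in for Mooncake.Tangent: an immutable wrapper around a NamedTuple.
struct FakeTangent{F<:NamedTuple}
    fields::F
end

# Hypothetical stand-in for SVector{3, Float64}, which stores its entries in `data`.
struct Vec3
    data::NTuple{3,Float64}
end

# Specific rule, mirroring the copyto!! method proposed above.
convert_cotangent(::T, src::Vec3) where {T<:FakeTangent} = T((data = getfield(src, :data),))

# Catch-all rule: fail loudly, telling the user what went wrong and what to do.
function convert_cotangent(dst, src)
    throw(ArgumentError(
        "no conversion rule from $(typeof(src)) to $(typeof(dst)): " *
        "we didn't expect this type, please open an issue"))
end

T3 = FakeTangent{@NamedTuple{data::NTuple{3,Float64}}}
dst = T3((data = (0.0, 0.0, 0.0),))
t = convert_cotangent(dst, Vec3((1.0, 2.0, 3.0)))  # t.fields.data == (1.0, 2.0, 3.0)
err = try
    convert_cotangent(dst, [1.0, 2.0, 3.0])        # no rule for Vector: hits the catch-all
catch e
    e
end
```

Julia's dispatch picks the more specific (::FakeTangent, ::Vec3) method whenever it applies, so the catch-all only fires for genuinely unsupported combinations.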

That being said, it might be that we can do something more general which says "the thing I'm trying to copyto!! into is a Tangent, therefore I just need to recursively pull out the fields of src and build NamedTuples out of them", and do a similar thing for MutableTangent but in place. This might be functionality that I should provide as part of the tangent interface in Mooncake.
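The more general idea can be sketched in plain Julia: walk the fields of src recursively and rebuild them as NamedTuples wrapped in a tangent type. This is only a sketch under assumptions. FakeTangent is a hypothetical stand-in for Mooncake.Tangent; numbers and arrays are passed through unchanged and tuples are converted element-wise, which is an assumption rather than Mooncake's documented behaviour; and the in-place MutableTangent case is not covered.

```julia
# Hypothetical stand-in for Mooncake.Tangent.
struct FakeTangent{F<:NamedTuple}
    fields::F
end

# Leaves (assumption): numbers and arrays are kept as they are.
to_tangent(x::Number) = x
to_tangent(x::AbstractArray) = x
# Tuples: convert element-wise.
to_tangent(x::Tuple) = map(to_tangent, x)
# Any other struct: recurse into its fields and wrap them in a NamedTuple.
function to_tangent(x)
    names = fieldnames(typeof(x))
    vals = map(n -> to_tangent(getfield(x, n)), names)
    return FakeTangent(NamedTuple{names}(vals))
end

# Example primal types (hypothetical).
struct Vec3
    data::NTuple{3,Float64}
end
struct Particle
    pos::Vec3
    mass::Float64
end

t = to_tangent(Particle(Vec3((1.0, 2.0, 3.0)), 4.0))
t.fields.mass             # 4.0
t.fields.pos.fields.data  # (1.0, 2.0, 3.0)
```

Nested structs become nested FakeTangent values, so a single generic method covers the static-vector case and arbitrary user structs without per-type rules.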

Either way, since static arrays are something people are quite interested in currently, I would suggest just adding a conversion rule for this case, and punting the more general fix down the line -- I'll open an issue on Mooncake which references this issue.

Additionally, note that the thing that Mooncake will return in this instance is another Tangent (assuming that the argument to the function being differentiated is itself a static array), which is probably not what users want. It might be worth thinking a bit about whether we want to apply some specific translation rules to make the types "more user friendly" in a uniform way. Or I could think a bit about how to do something predictable on Mooncake's end.
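One possible uniform translation rule, sketched under explicit assumptions: rebuild a value of the primal type by passing the tangent's fields, in declaration order, to that type's default constructor. FakeTangent, Vec3 and to_primal_type are hypothetical; real types such as SVector have their own constructors, and the rule only makes sense when the primal type's fields line up with the tangent's fields.

```julia
# Hypothetical stand-ins for Mooncake.Tangent and SVector{3, Float64}.
struct FakeTangent{F<:NamedTuple}
    fields::F
end
struct Vec3
    data::NTuple{3,Float64}
end

# Hypothetical rule: check that the field names match, then call P's default constructor.
function to_primal_type(::Type{P}, t::FakeTangent) where {P}
    fieldnames(P) == keys(t.fields) ||
        throw(ArgumentError("tangent fields $(keys(t.fields)) do not match fields of $P"))
    return P(values(t.fields)...)
end

g = FakeTangent((data = (0.5, 1.0, 1.5),))
v = to_primal_type(Vec3, g)  # v.data == (0.5, 1.0, 1.5)
```

The field-name check keeps the rule from silently producing nonsense for types whose tangent layout differs from their constructor arguments.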
