Improving Zygote autodiff performance #291

Closed
wants to merge 13 commits into from


SouthEndMusic
Member

Fixes #289.

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the contributor guidelines, in particular the SciML Style Guide and COLPRAC.
  • Any new documentation only uses public API

@SouthEndMusic
Member Author

Trying the example in the issue I get this error and I can't find out why:

ERROR: BoundsError: attempt to access Float64 at index [2]
Stacktrace:
  [1] macro expansion
    @ C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context{false}, f::typeof(throw), args::BoundsError)
    @ Zygote C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:87
  [3] indexed_iterate
    @ .\tuple.jl:101 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(Base.indexed_iterate), ::Float64, ::Int64, ::Nothing)
    @ Zygote C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
  [5] literal_indexed_iterate
    @ C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\tools\builtins.jl:14 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::typeof(Zygote.literal_indexed_iterate), ::Float64, ::Val{2}, ::Nothing)
    @ Zygote C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
  [7] _interpolate
    @ c:\Users\konin_bt\SciML\DataInterpolations\src\interpolation_methods.jl:4 [inlined]
  [8] _pullback(::Zygote.Context{…}, ::typeof(DataInterpolations._interpolate), ::LinearInterpolation{…}, ::Float64)
    @ Zygote C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
  [9] AbstractInterpolation
    @ c:\Users\konin_bt\SciML\DataInterpolations\src\DataInterpolations.jl:25 [inlined]
 [10] _pullback(ctx::Zygote.Context{…}, f::LinearInterpolation{…}, args::Float64)
    @ Zygote C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [11] (::Zygote.var"#1365#1370"{Zygote.Context{…}, LinearInterpolation{…}})(x::Float64)
    @ Zygote C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\lib\broadcast.jl:215
 [12] _broadcast_getindex_evalf
    @ .\broadcast.jl:709 [inlined]
 [13] _broadcast_getindex
    @ .\broadcast.jl:682 [inlined]
 [14] getindex
    @ .\broadcast.jl:636 [inlined]
 [15] copy
    @ .\broadcast.jl:942 [inlined]
 [16] materialize
    @ .\broadcast.jl:903 [inlined]
 [17] _broadcast
    @ C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\lib\broadcast.jl:189 [inlined]
 [18] _broadcast_generic
    @ C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\lib\broadcast.jl:215 [inlined]
 [19] adjoint
    @ C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\lib\broadcast.jl:205 [inlined]
 [20] _pullback
    @ C:\Users\konin_bt\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:67 [inlined]
 [21] _apply
    @ .\boot.jl:838 [inlined]
 [22] adjoint
    @ C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\lib\lib.jl:203 [inlined]
 [23] _pullback
    @ C:\Users\konin_bt\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:67 [inlined]
 [24] broadcasted
    @ .\broadcast.jl:1341 [inlined]
 [25] di_spline
    @ c:\Users\konin_bt\SciML\runner.jl:13 [inlined]
 [26] _pullback(::Zygote.Context{false}, ::typeof(di_spline), ::Vector{Float64}, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [27] #31
    @ c:\Users\konin_bt\SciML\runner.jl:17 [inlined]
 [28] _pullback(ctx::Zygote.Context{false}, f::var"#31#32"{Vector{…}, Vector{…}, Vector{…}}, args::Vector{Float64})
    @ Zygote C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [29] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})  
    @ Zygote C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\compiler\interface.jl:90
 [30] pullback
    @ C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\compiler\interface.jl:88 [inlined]
 [31] gradient(f::Function, args::Vector{Float64})
    @ Zygote C:\Users\konin_bt\.julia\packages\Zygote\nsBv0\src\compiler\interface.jl:147
 [32] var"##core#402"(y#397::Vector{…}, y#398::Vector{…}, x#399::Vector{…}, x1#400::Vector{…}, y#401::Vector{…})
    @ Main C:\Users\konin_bt\.julia\packages\BenchmarkTools\QNsku\src\execution.jl:561
 [33] var"##sample#403"(::NTuple{5, Vector{Float64}}, __params::BenchmarkTools.Parameters)
    @ Main C:\Users\konin_bt\.julia\packages\BenchmarkTools\QNsku\src\execution.jl:570
 [34] _lineartrial(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; maxevals::Int64, kwargs::@Kwargs{})
    @ BenchmarkTools C:\Users\konin_bt\.julia\packages\BenchmarkTools\QNsku\src\execution.jl:187
 [35] _lineartrial(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters)  
    @ BenchmarkTools C:\Users\konin_bt\.julia\packages\BenchmarkTools\QNsku\src\execution.jl:182
 [36] #invokelatest#2
    @ .\essentials.jl:892 [inlined]
 [37] invokelatest
    @ .\essentials.jl:889 [inlined]
 [38] #lineartrial#46
    @ C:\Users\konin_bt\.julia\packages\BenchmarkTools\QNsku\src\execution.jl:51 [inlined]
 [39] lineartrial
    @ C:\Users\konin_bt\.julia\packages\BenchmarkTools\QNsku\src\execution.jl:50 [inlined]
 [40] tune!(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, verbose::Bool, pad::String, kwargs::@Kwargs{})
    @ BenchmarkTools C:\Users\konin_bt\.julia\packages\BenchmarkTools\QNsku\src\execution.jl:300
 [41] tune!
    @ C:\Users\konin_bt\.julia\packages\BenchmarkTools\QNsku\src\execution.jl:289 [inlined]
 [42] tune!(b::BenchmarkTools.Benchmark)
    @ BenchmarkTools C:\Users\konin_bt\.julia\packages\BenchmarkTools\QNsku\src\execution.jl:289
 [43] top-level scope
    @ C:\Users\konin_bt\.julia\packages\BenchmarkTools\QNsku\src\execution.jl:447
Some type information was truncated. Use `show(err)` to see complete types. 
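
Frames [3] to [7] of the trace are consistent with a two-element destructuring at `interpolation_methods.jl:4` receiving a `Float64` where a tuple was expected. A minimal sketch of that failure mode in plain Julia follows; `unpack_pair` is a hypothetical name and is not the DataInterpolations code.

```julia
# Numbers in Julia iterate as one-element collections, so asking a scalar
# for a second element during destructuring raises a BoundsError.
function unpack_pair(x)
    # Lowers to Base.indexed_iterate(x, 1) and Base.indexed_iterate(x, 2, state).
    a, b = x
    return a, b
end

# A tuple destructures as expected.
@assert unpack_pair((1.0, 2.0)) == (1.0, 2.0)

# A bare Float64 has no second element, so the same code throws.
err = try
    unpack_pair(1.0)
catch e
    e
end
@assert err isa BoundsError
```

If this is the cause, the thing to check is which call on that line returns a scalar under Zygote where the primal code returns a pair.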

@SouthEndMusic
Member Author

SouthEndMusic commented Jul 7, 2024

At least calling Zygote.gradient w.r.t. A.u no longer throws an error; it just returns (nothing,) :') @ChrisRackauckas @marcobonici is that probably because an rrule for the constructor is needed?

@SouthEndMusic
Member Author

I finally got there for LinearInterpolation 🥳

using DataInterpolations
using Zygote
using Plots
using Random
using ColorSchemes

Random.seed!(3)
pl = plot()

n = 10
t = collect(1.0:10)
u_exact = rand(n)
A_exact = LinearInterpolation(u_exact, t)


t_data = 1 .+ 9rand(10n)
u_data = A_exact.(t_data) .+ (rand(10n) .- 0.5) / 8
scatter!(t_data, u_data; label = "perturbed data")

u_fit = rand(n)

function loss(u_fit)
    A = LinearInterpolation(u_fit, t)
    values = A.(t_data)
    return sum((values - u_data) .^ 2)
end

lr = 1e-2
N = 10

for (i, color) in enumerate(cgrad(:jet, range(0, 1, length = N)))
    ∇loss = only(gradient(loss, u_fit))
    u_fit .-= lr * ∇loss
    loss_it = loss(u_fit)  # loss takes only u_fit; t, t_data and u_data are globals
    plot!(LinearInterpolation(u_fit, t); color, label = "Iteration $i, loss = $(round(loss_it, digits = 3))")
end

pl

[Image: plot]
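
For reference, the gradient that Zygote should return for this loss can be written by hand: within a segment the interpolated value is `(1 - w) * u[i] + w * u[i + 1]`, so each data point contributes only to the two neighbouring knots. The sketch below uses hypothetical helpers (`toy_interp`, `toy_loss`, `toy_loss_gradient`) in plain Julia, not the DataInterpolations implementation, and checks the analytic gradient against central finite differences.

```julia
# Hypothetical stand-in for linear interpolation at a scalar tq.
# Assumes t is sorted and t[1] <= tq <= t[end].
function toy_interp(u, t, tq)
    i = clamp(searchsortedlast(t, tq), 1, length(t) - 1)
    w = (tq - t[i]) / (t[i + 1] - t[i])
    return (1 - w) * u[i] + w * u[i + 1]
end

# Sum of squared residuals, mirroring the loss in the comment above.
toy_loss(u, t, t_data, u_data) =
    sum((toy_interp(u, t, tq) - uq)^2 for (tq, uq) in zip(t_data, u_data))

# Analytic gradient w.r.t. u: each residual touches only knots i and i + 1.
function toy_loss_gradient(u, t, t_data, u_data)
    g = zero(u)
    for (tq, uq) in zip(t_data, u_data)
        i = clamp(searchsortedlast(t, tq), 1, length(t) - 1)
        w = (tq - t[i]) / (t[i + 1] - t[i])
        r = (1 - w) * u[i] + w * u[i + 1] - uq
        g[i] += 2 * r * (1 - w)
        g[i + 1] += 2 * r * w
    end
    return g
end

# Central finite differences as an independent check.
function finite_difference_gradient(f, u; h = 1e-6)
    g = zero(u)
    for k in eachindex(u)
        up = copy(u); up[k] += h
        um = copy(u); um[k] -= h
        g[k] = (f(up) - f(um)) / (2 * h)
    end
    return g
end

t = collect(1.0:5.0)
u = [0.3, 0.9, 0.1, 0.7, 0.5]
t_data = [1.2, 2.5, 2.9, 4.0, 4.6]
u_data = [0.4, 0.5, 0.2, 0.6, 0.6]

g_exact = toy_loss_gradient(u, t, t_data, u_data)
g_fd = finite_difference_gradient(v -> toy_loss(v, t, t_data, u_data), u)
@assert isapprox(g_exact, g_fd; atol = 1e-6)
```

Because the loss is quadratic in `u`, the finite-difference result matches the analytic one up to rounding.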

@SouthEndMusic
Member Author

@ChrisRackauckas @marcobonici My findings so far:

Calculating the gradient w.r.t. u of the data points currently throws an error on the master branch, but I found a way to make it work. Weirdly, the error is that a dual number ends up being used as an index, as an argument to e.g. linear_interpolation_parameters. I was able to fix this by writing a custom rrule for linear_interpolation_parameters.
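
As a rough illustration of the shape such a rule can take, here is a sketch using ChainRulesCore. The function `toy_slope` and its signature are hypothetical stand-ins and are not the actual `linear_interpolation_parameters` from DataInterpolations. The point is that the integer index is marked as non-differentiable with `NoTangent()`, so no derivative information is ever attached to it.

```julia
using ChainRulesCore

# Hypothetical parameter function: slope of the segment starting at index idx.
function toy_slope(u, t, idx)
    return (u[idx + 1] - u[idx]) / (t[idx + 1] - t[idx])
end

function ChainRulesCore.rrule(::typeof(toy_slope), u, t, idx)
    dt_seg = t[idx + 1] - t[idx]
    slope = (u[idx + 1] - u[idx]) / dt_seg

    function toy_slope_pullback(slope_bar)
        s = unthunk(slope_bar)
        # Only the two knots of this segment receive a contribution.
        u_bar = zero(u)
        u_bar[idx] = -s / dt_seg
        u_bar[idx + 1] = s / dt_seg
        t_bar = zero(t)
        t_bar[idx] = s * slope / dt_seg
        t_bar[idx + 1] = -s * slope / dt_seg
        # The function itself and the integer index carry no derivative.
        return NoTangent(), u_bar, t_bar, NoTangent()
    end

    return slope, toy_slope_pullback
end

# Calling the rule directly, without any AD package.
slope, back = ChainRulesCore.rrule(toy_slope, [1.0, 3.0], [0.0, 2.0], 1)
_, u_bar, t_bar, _ = back(1.0)
```

Zygote picks up `rrule` definitions made through ChainRulesCore, so a rule of this shape replaces tracing through the function body.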

Apart from some minor refactors, the main thing I had to do to make the code compatible with Zygote is get rid of ReadOnlyArray. Zygote kept asking for an adjoint of the constructor of ReadOnlyArray, but I couldn't make that work and that also sounds like it would need an upstream fix.

@SouthEndMusic
Member Author

Also got it to work for QuadraticSpline and gradients agree with ForwardDiff:

using DataInterpolations
using Zygote
using ForwardDiff
using Plots
using Random
using ColorSchemes

Random.seed!(3)
pl = plot(legendfontsize = 5)
method = QuadraticSpline

n = 10
t = collect(1.0:10)
u_exact = rand(n)
A_exact = method(u_exact, t)


t_data = 1 .+ 9rand(10n)
u_data = A_exact.(t_data) .+ (rand(10n) .- 0.5) / 10
scatter!(t_data, u_data; label = "perturbed data")

u_fit = rand(n)

function loss(u_fit)
    A = method(u_fit, t)
    values = A.(t_data)
    return sum((values - u_data) .^ 2)
end

lr = 1e-3
N = 200
plot_update = 10

for (i, color) in enumerate(cgrad(:jet, range(0, 1, length = N)))
    ∇loss = only(Zygote.gradient(loss, u_fit))
    ∇loss_fd = ForwardDiff.gradient(loss, u_fit)
    u_fit .-= lr * ∇loss
    loss_it = loss(u_fit)
    if i % plot_update == 1
        plot!(
            method(u_fit, t);
            color,
            label = "Iteration $i, loss = $(round(loss_it, digits = 3))",
        )
    end
    @assert ∇loss ≈ ∇loss_fd
end

pl

[Image: plot]

@SouthEndMusic
Member Author

SouthEndMusic commented Jul 9, 2024

@marcobonici this is already an order of magnitude better than what you reported for the gradient in your example:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  200.800 μs …  44.739 ms  ┊ GC (min … max): 0.00% … 98.93%
 Time  (median):     217.100 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   248.329 μs ± 487.808 μs  ┊ GC (mean ± σ):  9.93% ±  9.11%

  █▅▃▂                                                          ▁
  █████▇▆▆▅▄▅▅▃▃▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▄▅ █
  201 μs        Histogram: log(frequency) by time       1.88 ms <

 Memory estimate: 516.98 KiB, allocs estimate: 3412.

However, it still leaves much to be desired. Most of the time is spent calculating the cached parameters in the constructor, which Zygote apparently cannot handle efficiently.
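
To make the trade-off concrete, here is a sketch with two hypothetical types (`CachedLinear` and `LazyLinear`, neither taken from DataInterpolations). When the interpolation object is rebuilt inside the loss, as in the examples above, an eager constructor computes the parameters of every segment on each call, whereas a lazy variant computes only the slope of the segment being evaluated. Both return the same values.

```julia
# Eager variant: slopes for every segment are computed in the constructor.
struct CachedLinear
    u::Vector{Float64}
    t::Vector{Float64}
    slopes::Vector{Float64}
end
CachedLinear(u, t) = CachedLinear(u, t, diff(u) ./ diff(t))

# Lazy variant: nothing is precomputed.
struct LazyLinear
    u::Vector{Float64}
    t::Vector{Float64}
end

# Index of the segment containing tq; assumes t is sorted.
segment(t, tq) = clamp(searchsortedlast(t, tq), 1, length(t) - 1)

function (A::CachedLinear)(tq)
    i = segment(A.t, tq)
    return A.u[i] + A.slopes[i] * (tq - A.t[i])
end

function (A::LazyLinear)(tq)
    i = segment(A.t, tq)
    # The slope is computed on demand for this one segment only.
    slope = (A.u[i + 1] - A.u[i]) / (A.t[i + 1] - A.t[i])
    return A.u[i] + slope * (tq - A.t[i])
end

t = collect(1.0:10.0)
u = sin.(t)
t_data = [1.5, 3.25, 9.9]

A_cached = CachedLinear(u, t)
A_lazy = LazyLinear(u, t)
@assert A_cached.(t_data) ≈ A_lazy.(t_data)
```

Which variant is faster under reverse-mode AD depends on how many evaluations share one constructor call; this sketch only shows that the two are interchangeable in value.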

SouthEndMusic marked this pull request as draft July 13, 2024 07:57
@SouthEndMusic
Member Author

SouthEndMusic commented Jul 13, 2024

Have some more speedup:

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  70.500 μs …  42.179 ms  ┊ GC (min … max):  0.00% … 99.54%
 Time  (median):     75.700 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   92.829 μs ± 448.420 μs  ┊ GC (mean ± σ):  14.29% ±  5.78%

  ▃▇███▇▅▅▄▃▂▁▁ ▁ ▁    ▁▁▁▁▂▂▂▂▁                               ▂
  ███████████████████████████████▇▇▆▇▆▆▆▅▆▅▃▆▅▄▆▄▂▅▅▂▄▃▃▄▅▄▅▅▅ █
  70.5 μs       Histogram: log(frequency) by time       144 μs <

 Memory estimate: 196.05 KiB, allocs estimate: 2191.

@SouthEndMusic
Member Author

Closed in favor of #315

Successfully merging this pull request may close these issues.

Improving Zygote performance