Improve caching and dispatch of LinearizingSavingCallback
This adds a new type, `LinearizingSavingCallbackCache` and some
sub-types to allow for efficient re-use of memory as the callback
executes over the course of a solve, as well as re-use of that memory in
future solves when operating on a large ensemble simulation.

The top-level `LinearizingSavingCallbackCache` creates thread-safe cache
pool objects that are then used to acquire thread-unsafe cache pool
objects to be used within a single solve.  Those thread-unsafe cache
pool objects can then be released and acquired anew by the next solve.
The thread-unsafe pool objects allow for acquisition of pieces of memory
such as temporary `u` vectors (the recursive nature of the
`LinearizingSavingCallback` means that we must allocate unknown numbers
of temporary `u` vectors) and chunks of `u` blocks that are then
compacted into a single large matrix in the finalize method of the
callback.  All of this memory lives in the thread-unsafe caches, which
are released back to the top-level thread-safe cache pool so that the
next solve can acquire and reuse it.
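
The two-level scheme described above can be sketched roughly as follows.
This is a minimal illustration only: `LocalPool`, `SharedPool`,
`acquire!`, `reset!`, `acquire_local!` and `release_local!` are
hypothetical names chosen for this sketch, not the types or functions
added by this commit, and the real cache also handles chunked `u` blocks.

```julia
# Hypothetical sketch; none of these names come from the commit itself.

# Thread-unsafe pool: owned by exactly one solve at a time.
mutable struct LocalPool{T}
    free::Vector{Vector{T}}   # buffers ready for reuse
    used::Vector{Vector{T}}   # buffers handed out during the current solve
end
LocalPool{T}() where {T} = LocalPool{T}(Vector{T}[], Vector{T}[])

# Take a buffer of length `n`, reusing a freed one when available.
function acquire!(pool::LocalPool{T}, n::Int) where {T}
    buf = isempty(pool.free) ? Vector{T}(undef, n) : resize!(pop!(pool.free), n)
    push!(pool.used, buf)
    return buf
end

# Mark every handed-out buffer as reusable; called once a solve has finished.
function reset!(pool::LocalPool)
    append!(pool.free, pool.used)
    empty!(pool.used)
    return pool
end

# Thread-safe pool of local pools, shared across an ensemble.
struct SharedPool{T}
    lock::ReentrantLock
    pools::Vector{LocalPool{T}}
end
SharedPool{T}() where {T} = SharedPool{T}(ReentrantLock(), LocalPool{T}[])

# Hand a local pool to one solve, creating a fresh one if none is idle.
function acquire_local!(shared::SharedPool{T}) where {T}
    lock(shared.lock) do
        isempty(shared.pools) ? LocalPool{T}() : pop!(shared.pools)
    end
end

# Return a local pool (and all the memory it holds) for the next solve.
function release_local!(shared::SharedPool{T}, pool::LocalPool{T}) where {T}
    reset!(pool)
    lock(shared.lock) do
        push!(shared.pools, pool)
    end
    return nothing
end

# Usage: two consecutive "solves" sharing one top-level pool.
shared = SharedPool{Float64}()
pool1 = acquire_local!(shared)
u1 = acquire!(pool1, 3)          # temporary `u` vector for the first solve
release_local!(shared, pool1)
pool2 = acquire_local!(shared)   # the second solve picks up the same pool
u2 = acquire!(pool2, 3)          # and reuses the buffer instead of allocating
```

Only the hand-off between solves takes a lock; every allocation inside a
solve goes through the unlocked local pool, which is presumably why the
split helps when per-simulation work is small.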

Using these techniques, the solve time of a large ensemble simulation
with low per-simulation computation has dropped dramatically.  The
simulation solves a 3rd-order Butterworth filter circuit over a certain
timespan, swept across different stimulus frequencies and circuit
parameters.  The parameter sweep results in a 13500-element ensemble
simulation that, when run with 8 threads on an M1 Pro, previously took:

```
48.364827 seconds (625.86 M allocations: 19.472 GiB, 41.81% gc time, 0.17% compilation time)
```

Now, after these caching optimizations, we solve the same ensemble in:
```
13.208123 seconds (166.76 M allocations: 7.621 GiB, 22.21% gc time, 0.61% compilation time)
```

As a side note, the raw linearized solution data itself occupies
`1.04 GB`.  In general, we expect to allocate roughly 2-3x the size of
the final output data to account for temporaries and inefficient
sharing, so while there is still some more work to be done, this gets
us significantly closer to minimal overhead.

This also adds a package extension on `Sundials`, as `IDA` requires that
state vectors are `NVector` types, rather than `Vector{S}` types in
order to not allocate.
staticfloat committed Feb 21, 2024
1 parent be310b8 commit f8f8396
Showing 7 changed files with 529 additions and 122 deletions.
3 changes: 3 additions & 0 deletions Project.toml
@@ -21,6 +21,9 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"

[extensions]
DiffEqCallbacksSundialsExt = "Sundials"

[compat]
Aqua = "0.8"
DataInterpolations = "4"
12 changes: 12 additions & 0 deletions ext/DiffEqCallbacksSundialsExt.jl
@@ -0,0 +1,12 @@
module DiffEqCallbacksSundialsExt

using Sundials: NVector, IDA
import DiffEqCallbacks: solver_state_alloc, solver_state_type

# Allocator; `U` is typically something like `Vector{Float64}`
solver_state_alloc(solver::IDA, U::DataType, num_us::Int) = () -> NVector(U(undef, num_us))

# Type of `solver_state_alloc`, which is just `NVector`
solver_state_type(solver::IDA, U::DataType) = NVector


end # module
2 changes: 1 addition & 1 deletion src/domain.jl
@@ -94,7 +94,7 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S,
if dtcache == dt
if integrator.opts.verbose
@warn("Could not restrict values to domain. Iteration was canceled since ",
- "proposed time step dt = ", dt," could not be reduced.")
+ "proposed time step dt = ", dt, " could not be reduced.")
end
break
end