Skip to content

Commit

Permalink
Merge pull request #213 from JuliaDiffEq/performance
Browse files Browse the repository at this point in the history
speed up MTK OOP vector usage
  • Loading branch information
ChrisRackauckas authored Dec 20, 2019
2 parents 988fde4 + 5bb1631 commit f2cbd38
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 31 deletions.
56 changes: 30 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Each operation builds an `Operation` type, and thus `eqs` is an array of
analyzed by other programs. We can turn this into a `ODESystem` via:

```julia
de = ODESystem(eqs)
de = ODESystem(eqs, t, [x,y,z], [σ,ρ,β])
```

where we tell it the variable types and ordering in the first version, or let it
Expand All @@ -54,49 +54,53 @@ generated code via:

```julia
using MacroTools
myode_oop = generate_function(de, [x,y,z], [σ,ρ,β])[1] # first one is the out-of-place function
myode_oop = generate_function(de)[1] # first one is the out-of-place function
MacroTools.striplines(myode_oop) # print without line numbers

#=
:((u, p, t)->begin
@inbounds begin
X = @inbounds(begin
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
(σ * (y - x), x * (ρ - z) - y, x * y - β * z)
end
end)
end
if u isa Array
return @inbounds(begin
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
[σ * (y - x), x * (ρ - z) - y, x * y - β * z]
end
end)
else
X = @inbounds(begin
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
(σ * (y - x), x * (ρ - z) - y, x * y - β * z)
end
end)
end
T = promote_type(map(typeof, X)...)
convert.(T, X)
map(T, X)
construct = if u isa ModelingToolkit.StaticArrays.StaticArray
ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X))
else
x->begin
du = similar(u, T, 3)
vec(du) .= x
du
convert(typeof(u), x)
end
end
construct(X)
end)
=#

myode_iip = generate_function(de, [x,y,z], [σ,ρ,β])[2] # second one is the in-place function
myode_iip = generate_function(de)[2] # second one is the in-place function
MacroTools.striplines(myode_iip) # print without line numbers

#=
(var"##MTIIPVar#409", u, p, t)->begin
@inbounds begin
@inbounds begin
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
var"##MTIIPVar#409"[1] = σ * (y - x)
var"##MTIIPVar#409"[2] = x * (ρ - z) - y
var"##MTIIPVar#409"[3] = x * y - β * z
end
end
end
nothing
end
:((var"##MTIIPVar#793", u, p, t)->begin
@inbounds begin
@inbounds begin
let (x, y, z, σ, ρ, β) = (u[1], u[2], u[3], p[1], p[2], p[3])
var"##MTIIPVar#793"[1] = σ * (y - x)
var"##MTIIPVar#793"[2] = x * (ρ - z) - y
var"##MTIIPVar#793"[3] = x * y - β * z
end
end
end
nothing
end)
=#
```

Expand Down
15 changes: 10 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,26 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, ex
ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) enumerate(rhss)]
ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))

sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
let_expr = Expr(:let, var_eqs, sys_expr)
tuple_sys_expr = build_expr(:tuple, [conv(rhs) for rhs rhss])
vector_sys_expr = build_expr(:vect, [conv(rhs) for rhs rhss])
let_expr = Expr(:let, var_eqs, tuple_sys_expr)
vector_let_expr = Expr(:let, var_eqs, vector_sys_expr)
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
vector_bounds_block = checkbounds ? vector_let_expr : :(@inbounds begin $vector_let_expr end)
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)

fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))

oop_ex = :(
($(fargs.args...),) -> begin
@inbounds begin
if $(fargs.args[1]) isa Array
return $vector_bounds_block
else
X = $bounds_block
end
T = promote_type(map(typeof,X)...)
convert.(T,X)
construct = $(constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->(du=similar(u, T, $(size(rhss)...)); vec(du) .= x; du)) : constructor)
map(T,X)
construct = $(constructor === nothing ? :(u isa ModelingToolkit.StaticArrays.StaticArray ? ModelingToolkit.StaticArrays.similar_type(typeof(u), eltype(X)) : x->convert(typeof(u),x)) : constructor)
construct(X)
end
)
Expand Down

0 comments on commit f2cbd38

Please sign in to comment.