Skip to content

Commit

Permalink
Merge pull request #2825 from SciML/eval
Browse files Browse the repository at this point in the history
Make eval great again
  • Loading branch information
ChrisRackauckas authored Jun 28, 2024
2 parents 5f2a594 + adf98ba commit 80def4c
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 69 deletions.
1 change: 1 addition & 0 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pages = [
"basics/MTKModel_Connector.md",
"basics/Validation.md",
"basics/DependencyGraphs.md",
"basics/Precompilation.md",
"basics/FAQ.md"],
"System Types" => Any["systems/ODESystem.md",
"systems/SDESystem.md",
Expand Down
117 changes: 117 additions & 0 deletions docs/src/basics/Precompilation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Working with Precompilation and Binary Building

## tl;dr, I just want precompilation to work

The tl;dr is, if you want to make precompilation work then instead of

```julia
ODEProblem(sys, u0, tspan, p)
```

use:

```julia
ODEProblem(sys, u0, tspan, p, eval_module = @__MODULE__, eval_expression = true)
```

As a full example, here's an example of a module that would precompile effectively:

```julia
module PrecompilationMWE
using ModelingToolkit

@variables x(ModelingToolkit.t_nounits)
@named sys = ODESystem([ModelingToolkit.D_nounits(x) ~ -x + 1], ModelingToolkit.t_nounits)
prob = ODEProblem(structural_simplify(sys), [x => 30.0], (0, 100), [],
eval_expression = true, eval_module = @__MODULE__)

end
```

If you use that in your package's code then 99% of the time that's the right answer to get
precompilation working.

## I'm doing something fancier and need a bit more of an explanation

Oh you dapper soul, time for the bigger explanation. Julia's `eval` function evaluates a
function into a module at a specified world-age. If you evaluate a function within a function
and try to call it from within that same function, you will hit a world-age error. This looks like:

```julia
function worldageerror()
f = eval(:((x) -> 2x))
f(2)
end
```

```
julia> worldageerror()
ERROR: MethodError: no method matching (::var"#5#6")(::Int64)
Closest candidates are:
(::var"#5#6")(::Any) (method too new to be called from this world context.)
@ Main REPL[12]:2
```

This is done for many reasons, in particular if the code that is called within a function could change
at any time, then Julia functions could not ever properly optimize because the meaning of any function
or dispatch could always change and you would lose performance by guarding against that. For a full
discussion of world-age, see [this paper](https://arxiv.org/abs/2010.07516).

However, this would be greatly inhibiting to standard ModelingToolkit usage because then something as
simple as building an ODEProblem in a function and then using it would get a world age error:

```julia
function wouldworldage()
prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob)
end
```

The reason is because `prob.f` would be constructed via `eval`, and thus `prob.f` could not be called
in the function, which means that no solve could ever work in the same function that generated the
problem. That does mean that:

```julia
function wouldworldage()
prob = ODEProblem(sys, [], (0.0, 1.0))
end
sol = solve(prob)
```

is fine, or putting

```julia
prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob)
```

at the top level of a module is perfectly fine too. They just cannot happen in the same function.

This would be a major limitation to ModelingToolkit, and thus we developed
[RuntimeGeneratedFunctions](https://github.com/SciML/RuntimeGeneratedFunctions.jl) to get around
this limitation. It will not be described beyond that, it is dark art and should not be investigated.
But it does the job. But that does mean that it plays... oddly with Julia's compilation.

There are ways to force RuntimeGeneratedFunctions to perform their evaluation and caching within
a given module, but that is not recommended because it does not play nicely with Julia v1.9's
introduction of package images for binary caching.

Thus when trying to make things work with precompilation, we recommend using `eval`. This is
done by simply adding `eval_expression=true` to the problem constructor. However, this is not
a silver bullet because the moment you start using eval, all potential world-age restrictions
apply, and thus it is recommended this is simply used for evaluating at the top level of modules
for the purpose of precompilation and ensuring binaries of your MTK functions are built correctly.

However, there is one caveat that `eval` in Julia works depending on the module that it is given.
If you have `MyPackage` that you are precompiling into, or say you are using `juliac` or PackageCompiler
or some other static ahead-of-time (AOT) Julia compiler, then you don't want to accidentally `eval`
that function to live in ModelingToolkit and instead want to make sure it is `eval`'d to live in `MyPackage`
(since otherwise it will not cache into the binary). ModelingToolkit cannot know that in advance, and thus
you have to pass in the module you wish for the functions to "live" in. This is done via the `eval_module`
argument.

Hence `ODEProblem(sys, u0, tspan, p, eval_module=@__MODULE__, eval_expression=true)` will work if you
are running this expression in the scope of the module you wish to be precompiling. However, if you are
attempting to AOT compile a different module, this means that `eval_module` needs to be appropriately
chosen. And, because `eval_expression=true`, all caveats of world-age apply.
14 changes: 8 additions & 6 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ end
function generate_discrete_affect(
osys::AbstractODESystem, syss, inputs, continuous_id, id_to_clock;
checkbounds = true,
eval_module = @__MODULE__, eval_expression = true)
eval_module = @__MODULE__, eval_expression = false)
@static if VERSION < v"1.7"
error("The `generate_discrete_affect` function requires at least Julia 1.7")
end
Expand Down Expand Up @@ -412,15 +412,17 @@ function generate_discrete_affect(
push!(svs, sv)
end
if eval_expression
affects = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), affect_funs)
inits = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), init_funs)
else
affects = map(affect_funs) do a
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
drop_expr(RuntimeGeneratedFunction(
eval_module, eval_module, toexpr(LiteralExpr(a))))
end
inits = map(init_funs) do a
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
drop_expr(RuntimeGeneratedFunction(
eval_module, eval_module, toexpr(LiteralExpr(a))))
end
else
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
inits = map(a -> toexpr(LiteralExpr(a)), init_funs)
end
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
return affects, inits, clocks, svs, appended_parameters, defaults
Expand Down
51 changes: 26 additions & 25 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
version = nothing, tgrad = false,
jac = false, p = nothing,
t = nothing,
eval_expression = true,
eval_expression = false,
sparse = false, simplify = false,
eval_module = @__MODULE__,
steady_state = false,
Expand All @@ -327,12 +327,12 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
end
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
f_gen = generate_function(sys, dvs, ps; expression = Val{true},
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
f_gen
f_oop, f_iip = eval_expression ? eval_module.eval.(f_gen) :
(drop_expr(RuntimeGeneratedFunction(eval_module, eval_module, ex)) for ex in f_gen)

f(u, p, t) = f_oop(u, p, t)
f(du, u, p, t) = f_iip(du, u, p, t)
f(u, p::Tuple{Vararg{Number}}, t) = f_oop(u, p, t)
Expand All @@ -352,12 +352,12 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
if tgrad
tgrad_gen = generate_tgrad(sys, dvs, ps;
simplify = simplify,
expression = Val{eval_expression},
expression = Val{true},
expression_module = eval_module,
checkbounds = checkbounds, kwargs...)
tgrad_oop, tgrad_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in tgrad_gen) :
tgrad_gen
tgrad_oop, tgrad_iip = eval_expression ? eval_module.eval.(tgrad_gen) :
(drop_expr(RuntimeGeneratedFunction(
eval_module, eval_module, ex)) for ex in tgrad_gen)
if p isa Tuple
__tgrad(u, p, t) = tgrad_oop(u, p..., t)
__tgrad(J, u, p, t) = tgrad_iip(J, u, p..., t)
Expand All @@ -374,12 +374,13 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
if jac
jac_gen = generate_jacobian(sys, dvs, ps;
simplify = simplify, sparse = sparse,
expression = Val{eval_expression},
expression = Val{true},
expression_module = eval_module,
checkbounds = checkbounds, kwargs...)
jac_oop, jac_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
jac_gen
jac_oop, jac_iip = eval_expression ? eval_module.eval.(jac_gen) :
(drop_expr(RuntimeGeneratedFunction(
eval_module, eval_module, ex)) for ex in jac_gen)

_jac(u, p, t) = jac_oop(u, p, t)
_jac(J, u, p, t) = jac_iip(J, u, p, t)
_jac(u, p::Tuple{Vararg{Number}}, t) = jac_oop(u, p, t)
Expand Down Expand Up @@ -474,7 +475,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
ddvs = map(diff2term Differential(get_iv(sys)), dvs),
version = nothing, p = nothing,
jac = false,
eval_expression = true,
eval_expression = false,
sparse = false, simplify = false,
eval_module = @__MODULE__,
checkbounds = false,
Expand All @@ -485,12 +486,11 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
end
f_gen = generate_function(sys, dvs, ps; implicit_dae = true,
expression = Val{eval_expression},
expression = Val{true},
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
f_gen
f_oop, f_iip = eval_expression ? eval_module.eval.(f_gen) :
(drop_expr(RuntimeGeneratedFunction(eval_module, eval_module, ex)) for ex in f_gen)
f(du, u, p, t) = f_oop(du, u, p, t)
f(du, u, p::MTKParameters, t) = f_oop(du, u, p..., t)
f(out, du, u, p, t) = f_iip(out, du, u, p, t)
Expand All @@ -499,12 +499,13 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
if jac
jac_gen = generate_dae_jacobian(sys, dvs, ps;
simplify = simplify, sparse = sparse,
expression = Val{eval_expression},
expression = Val{true},
expression_module = eval_module,
checkbounds = checkbounds, kwargs...)
jac_oop, jac_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
jac_gen
jac_oop, jac_iip = eval_expression ? eval_module.eval.(jac_gen) :
(drop_expr(RuntimeGeneratedFunction(
eval_module, eval_module, ex)) for ex in jac_gen)

_jac(du, u, p, ˍ₋gamma, t) = jac_oop(du, u, p, ˍ₋gamma, t)
_jac(du, u, p::MTKParameters, ˍ₋gamma, t) = jac_oop(du, u, p..., ˍ₋gamma, t)

Expand Down Expand Up @@ -555,7 +556,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
expression = Val{true},
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
f_oop, f_iip = (drop_expr(RuntimeGeneratedFunction(eval_module, eval_module, ex)) for ex in f_gen)
f(u, h, p, t) = f_oop(u, h, p, t)
f(u, h, p::MTKParameters, t) = f_oop(u, h, p..., t)
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
Expand All @@ -580,7 +581,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
expression = Val{true},
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
f_oop, f_iip = (drop_expr(RuntimeGeneratedFunction(eval_module, eval_module, ex)) for ex in f_gen)
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
isdde = true, kwargs...)
g_oop, g_iip = (drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen)
Expand Down Expand Up @@ -770,7 +771,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
checkbounds = false, sparse = false,
simplify = false,
linenumbers = true, parallel = SerialForm(),
eval_expression = true,
eval_expression = false,
use_union = true,
tofloat = true,
symbolic_u0 = false,
Expand Down
41 changes: 20 additions & 21 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,21 +407,21 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
ps = parameters(sys),
u0 = nothing;
version = nothing, tgrad = false, sparse = false,
jac = false, Wfact = false, eval_expression = true,
jac = false, Wfact = false, eval_expression = false,
checkbounds = false,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
end
dvs = scalarize.(dvs)

f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, kwargs...)
f_oop, f_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in f_gen) : f_gen
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{eval_expression},
f_gen = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)
f_oop, f_iip = eval_expression ? eval_module.eval.(f_gen) :
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in f_gen)
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
kwargs...)
g_oop, g_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen) : g_gen
g_oop, g_iip = eval_expression ? eval_module.eval.(g_gen) :
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen)

f(u, p, t) = f_oop(u, p, t)
f(u, p::MTKParameters, t) = f_oop(u, p..., t)
Expand All @@ -433,11 +433,11 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
g(du, u, p::MTKParameters, t) = g_iip(du, u, p..., t)

if tgrad
tgrad_gen = generate_tgrad(sys, dvs, ps; expression = Val{eval_expression},
tgrad_gen = generate_tgrad(sys, dvs, ps; expression = Val{true},
kwargs...)
tgrad_oop, tgrad_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tgrad_gen) :
tgrad_gen
tgrad_oop, tgrad_iip = eval_expression ? eval_module.eval.(tgrad_gen) :
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tgrad_gen)

_tgrad(u, p, t) = tgrad_oop(u, p, t)
_tgrad(u, p::MTKParameters, t) = tgrad_oop(u, p..., t)
_tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
Expand All @@ -447,11 +447,11 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
end

if jac
jac_gen = generate_jacobian(sys, dvs, ps; expression = Val{eval_expression},
jac_gen = generate_jacobian(sys, dvs, ps; expression = Val{true},
sparse = sparse, kwargs...)
jac_oop, jac_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in jac_gen) :
jac_gen
jac_oop, jac_iip = eval_expression ? eval_module.eval.(jac_gen) :
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in jac_gen)

_jac(u, p, t) = jac_oop(u, p, t)
_jac(u, p::MTKParameters, t) = jac_oop(u, p..., t)
_jac(J, u, p, t) = jac_iip(J, u, p, t)
Expand All @@ -463,12 +463,11 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
if Wfact
tmp_Wfact, tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true;
expression = Val{true}, kwargs...)
Wfact_oop, Wfact_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tmp_Wfact) :
tmp_Wfact
Wfact_oop_t, Wfact_iip_t = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tmp_Wfact_t) :
tmp_Wfact_t
Wfact_oop, Wfact_iip = eval_expression ? eval_module.eval.(tmp_Wfact) :
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tmp_Wfact)
Wfact_oop_t, Wfact_iip_t = eval_expression ? eval_module.eval.(tmp_Wfact_t) :
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tmp_Wfact_t)

_Wfact(u, p, dtgamma, t) = Wfact_oop(u, p, dtgamma, t)
_Wfact(u, p::MTKParameters, dtgamma, t) = Wfact_oop(u, p..., dtgamma, t)
_Wfact(W, u, p, dtgamma, t) = Wfact_iip(W, u, p, dtgamma, t)
Expand Down
Loading

0 comments on commit 80def4c

Please sign in to comment.