Skip to content

Commit

Permalink
Merge pull request #2934 from AayushSabharwal/as/better-pdeps
Browse files Browse the repository at this point in the history
refactor: remove parameter dependencies from MTKParameters
  • Loading branch information
ChrisRackauckas authored Aug 12, 2024
2 parents b1852bb + 97a5c4b commit ba86ee6
Show file tree
Hide file tree
Showing 18 changed files with 196 additions and 274 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ SimpleNonlinearSolve = "0.1.0, 1"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.26"
SymbolicIndexingInterface = "0.3.28"
SymbolicUtils = "2.1"
Symbolics = "5.32"
URIs = "1"
Expand Down
4 changes: 3 additions & 1 deletion src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
end
process = get_postprocess_fbody(sys)
f = build_function(rhss, args...; postprocess_fbody = process,
expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps), kwargs...)
expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps) .∘
wrap_parameter_dependencies(sys, false),
kwargs...)
f = eval_or_rgf.(f; eval_expression, eval_module)
(; f, dvs, ps, io_sys = sys)
end
Expand Down
112 changes: 58 additions & 54 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function calculate_hessian end

"""
```julia
generate_tgrad(sys::AbstractTimeDependentSystem, dvs = unknowns(sys), ps = full_parameters(sys),
generate_tgrad(sys::AbstractTimeDependentSystem, dvs = unknowns(sys), ps = parameters(sys),
expression = Val{true}; kwargs...)
```
Expand All @@ -93,7 +93,7 @@ function generate_tgrad end

"""
```julia
generate_gradient(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
generate_gradient(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
expression = Val{true}; kwargs...)
```
Expand All @@ -104,7 +104,7 @@ function generate_gradient end

"""
```julia
generate_jacobian(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
generate_jacobian(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
expression = Val{true}; sparse = false, kwargs...)
```
Expand All @@ -115,7 +115,7 @@ function generate_jacobian end

"""
```julia
generate_factorized_W(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
generate_factorized_W(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
expression = Val{true}; sparse = false, kwargs...)
```
Expand All @@ -126,7 +126,7 @@ function generate_factorized_W end

"""
```julia
generate_hessian(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
generate_hessian(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
expression = Val{true}; sparse = false, kwargs...)
```
Expand All @@ -137,7 +137,7 @@ function generate_hessian end

"""
```julia
generate_function(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
generate_function(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
expression = Val{true}; kwargs...)
```
Expand All @@ -148,7 +148,7 @@ function generate_function end
"""
```julia
generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
ps = parameters(sys); kwargs...)
```
Generate a function to evaluate `exprs`. `exprs` is a symbolic expression or
Expand Down Expand Up @@ -187,7 +187,8 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
postprocess_fbody,
states,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
wrap_array_vars(sys, exprs; dvs),
wrap_array_vars(sys, exprs; dvs) .∘
wrap_parameter_dependencies(sys, isscalar),
expression = Val{true}
)
else
Expand All @@ -198,7 +199,8 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
postprocess_fbody,
states,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
wrap_array_vars(sys, exprs; dvs),
wrap_array_vars(sys, exprs; dvs) .∘
wrap_parameter_dependencies(sys, isscalar),
expression = Val{true}
)
end
Expand All @@ -223,6 +225,10 @@ function wrap_assignments(isscalar, assignments; let_block = false)
end
end

function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
wrap_assignments(isscalar, [eq.lhs eq.rhs for eq in parameter_dependencies(sys)])
end

function wrap_array_vars(
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing)
isscalar = !(exprs isa AbstractArray)
Expand Down Expand Up @@ -757,7 +763,7 @@ function SymbolicIndexingInterface.get_all_timeseries_indexes(sys::AbstractSyste
end

function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
return full_parameters(sys)
return parameters(sys)
end

function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
Expand Down Expand Up @@ -1214,11 +1220,6 @@ function namespace_guesses(sys)
Dict(unknowns(sys, k) => namespace_expr(v, sys) for (k, v) in guess)
end

function namespace_parameter_dependencies(sys)
pdeps = parameter_dependencies(sys)
Dict(parameters(sys, k) => namespace_expr(v, sys) for (k, v) in pdeps)
end

function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys))
eqs = equations(sys)
isempty(eqs) && return Equation[]
Expand Down Expand Up @@ -1325,25 +1326,11 @@ function parameters(sys::AbstractSystem)
ps = first.(ps)
end
systems = get_systems(sys)
result = unique(isempty(systems) ? ps :
[ps; reduce(vcat, namespace_parameters.(systems))])
if has_parameter_dependencies(sys) &&
(pdeps = parameter_dependencies(sys)) !== nothing
filter(result) do sym
!haskey(pdeps, sym)
end
else
result
end
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
end

function dependent_parameters(sys::AbstractSystem)
if has_parameter_dependencies(sys) &&
!isempty(parameter_dependencies(sys))
collect(keys(parameter_dependencies(sys)))
else
[]
end
return map(eq -> eq.lhs, parameter_dependencies(sys))
end

"""
Expand All @@ -1353,17 +1340,19 @@ Get the parameter dependencies of the system `sys` and its subsystems.
See also [`defaults`](@ref) and [`ModelingToolkit.get_parameter_dependencies`](@ref).
"""
function parameter_dependencies(sys::AbstractSystem)
pdeps = get_parameter_dependencies(sys)
if isnothing(pdeps)
pdeps = Dict()
if !has_parameter_dependencies(sys)
return Equation[]
end
pdeps = get_parameter_dependencies(sys)
systems = get_systems(sys)
isempty(systems) && return pdeps
for subsys in systems
pdeps = merge(pdeps, namespace_parameter_dependencies(subsys))
end
# @info pdeps
return pdeps
# put pdeps after those of subsystems to maintain topological sorted order
return vcat(
reduce(vcat,
[map(eq -> namespace_equation(eq, s), parameter_dependencies(s))
for s in systems];
init = Equation[]),
pdeps
)
end

function full_parameters(sys::AbstractSystem)
Expand Down Expand Up @@ -2317,7 +2306,7 @@ function linearization_function(sys::AbstractSystem, inputs,
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
initprobmap = build_explicit_observed_function(
initsys, unknowns(sys); eval_expression, eval_module)
ps = full_parameters(sys)
ps = parameters(sys)
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
Expand Down Expand Up @@ -2420,7 +2409,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
kwargs...)
sts = unknowns(sys)
t = get_iv(sys)
ps = full_parameters(sys)
ps = parameters(sys)
p = reorder_parameters(sys, ps)

fun_expr = generate_function(sys, sts, ps; expression = Val{true})[1]
Expand Down Expand Up @@ -2852,7 +2841,7 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol = nam
eqs = union(get_eqs(basesys), get_eqs(sys))
sts = union(get_unknowns(basesys), get_unknowns(sys))
ps = union(get_ps(basesys), get_ps(sys))
dep_ps = union_nothing(parameter_dependencies(basesys), parameter_dependencies(sys))
dep_ps = union(parameter_dependencies(basesys), parameter_dependencies(sys))
obs = union(get_observed(basesys), get_observed(sys))
cevs = union(get_continuous_events(basesys), get_continuous_events(sys))
devs = union(get_discrete_events(basesys), get_discrete_events(sys))
Expand Down Expand Up @@ -2956,15 +2945,28 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
end

function process_parameter_dependencies(pdeps, ps)
pdeps === nothing && return pdeps, ps
if pdeps isa Vector && eltype(pdeps) <: Pair
pdeps = Dict(pdeps)
elseif !(pdeps isa Dict)
error("parameter_dependencies must be a `Dict` or `Vector{<:Pair}`")
if pdeps === nothing || isempty(pdeps)
return Equation[], ps
elseif eltype(pdeps) <: Pair
pdeps = [lhs ~ rhs for (lhs, rhs) in pdeps]
end

if !(eltype(pdeps) <: Equation)
error("Parameter dependencies must be a `Dict`, `Vector{Pair}` or `Vector{Equation}`")
end
lhss = BasicSymbolic[]
for p in pdeps
if !isparameter(p.lhs)
error("LHS of parameter dependency must be a single parameter. Found $(p.lhs).")
end
syms = vars(p.rhs)
if !all(isparameter, syms)
error("RHS of parameter dependency must only include parameters. Found $(p.rhs)")
end
push!(lhss, p.lhs)
end
pdeps = topsort_equations(pdeps, union(ps, lhss))
ps = filter(ps) do p
!haskey(pdeps, p)
!any(isequal(p), lhss)
end
return pdeps, ps
end
Expand Down Expand Up @@ -2997,12 +2999,14 @@ function dump_parameters(sys::AbstractSystem)
end
meta
end
pdep_metas = map(collect(keys(pdeps))) do sym
val = pdeps[sym]
pdep_metas = map(pdeps) do eq
sym = eq.lhs
val = eq.rhs
meta = dump_variable_metadata(sym)
defs[eq.lhs] = eq.rhs
meta = merge(meta,
(; dependency = pdeps[sym],
default = symbolic_evaluate(pdeps[sym], merge(defs, pdeps))))
(; dependency = val,
default = symbolic_evaluate(val, defs)))
return meta
end
return vcat(metas, pdep_metas)
Expand Down
25 changes: 16 additions & 9 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
end
expr = build_function(
condit, u, t, p...; expression = Val{true},
wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps),
wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps) .∘
wrap_parameter_dependencies(sys, !(condit isa AbstractArray)),
kwargs...)
if expression == Val{true}
return expr
Expand Down Expand Up @@ -497,7 +498,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
pre = get_preprocess_constants(rhss)
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
wrap_code = add_integrator_header(sys, integ, outvar) .∘
wrap_array_vars(sys, rhss; dvs, ps = _ps),
wrap_array_vars(sys, rhss; dvs, ps = _ps) .∘
wrap_parameter_dependencies(sys, false),
outputidxs = update_inds,
postprocess_fbody = pre,
kwargs...)
Expand All @@ -513,7 +515,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
end

function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
ps = parameters(sys); kwargs...)
cbs = continuous_events(sys)
isempty(cbs) && return nothing
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
Expand All @@ -524,7 +526,7 @@ generate_rootfinding_callback and thus we can produce a ContinuousCallback inste
"""
function generate_single_rootfinding_callback(
eq, cb, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
ps = parameters(sys); kwargs...)
if !isequal(eq.lhs, 0)
eq = 0 ~ eq.lhs - eq.rhs
end
Expand All @@ -547,7 +549,7 @@ end

function generate_vector_rootfinding_callback(
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
ps = parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
num_eqs = length.(eqs)
# fuse equations to create VectorContinuousCallback
Expand Down Expand Up @@ -617,7 +619,7 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
end

function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
ps = parameters(sys); kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
num_eqs = length.(eqs)
total_eqs = sum(num_eqs)
Expand Down Expand Up @@ -660,10 +662,15 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
v_inds = map(sym -> dvs_ind[sym], unknowns(affect))

if has_index_cache(sys) && get_index_cache(sys) !== nothing
p_inds = [parameter_index(sys, sym) for sym in parameters(affect)]
p_inds = [if (pind = parameter_index(sys, sym)) === nothing
sym
else
pind
end
for sym in parameters(affect)]
else
ps_ind = Dict(reverse(en) for en in enumerate(ps))
p_inds = map(sym -> ps_ind[sym], parameters(affect))
p_inds = map(sym -> get(ps_ind, sym, sym), parameters(affect))
end
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
# (MTK should keep these symbols)
Expand Down Expand Up @@ -711,7 +718,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
end

function generate_discrete_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
ps = parameters(sys); kwargs...)
has_discrete_events(sys) || return nothing
symcbs = discrete_events(sys)
isempty(symcbs) && return nothing
Expand Down
Loading

0 comments on commit ba86ee6

Please sign in to comment.