Skip to content

Commit

Permalink
Merge pull request #2581 from AayushSabharwal/as/mtkparameters-tests
Browse files Browse the repository at this point in the history
test: add MTKParameters tests, fix bugs
  • Loading branch information
ChrisRackauckas authored Mar 27, 2024
2 parents df6b314 + 3b39362 commit 85e1863
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 115 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.11"
SymbolicUtils = "1.0"
Symbolics = "5.24"
Symbolics = "5.26"
URIs = "1"
UnPack = "0.1, 1.0"
Unitful = "1.1"
Expand Down
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ using PrecompileTools, Reexport
VariableSource, getname, variable, Connection, connect,
NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval,
initial_state, transition, activeState, entry,
ticksInState, timeInState
ticksInState, timeInState, fixpoint_sub, fast_substitute
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
jacobian_sparsity, isaffine, islinear, _iszero, _isone,
tosymbol, lower_varname, diff2term, var_from_nested_derivative,
Expand Down
4 changes: 2 additions & 2 deletions src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module StructuralTransformations
using Setfield: @set!, @set
using UnPack: @unpack

using Symbolics: unwrap, linear_expansion
using Symbolics: unwrap, linear_expansion, fast_substitute
using SymbolicUtils
using SymbolicUtils.Code
using SymbolicUtils.Rewriters
Expand All @@ -23,7 +23,7 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
IncrementalCycleTracker, add_edge_checked!, topological_sort,
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
fast_substitute, get_fullvars, has_equations, observed,
get_fullvars, has_equations, observed,
Schedule

using ModelingToolkit.BipartiteGraphs
Expand Down
8 changes: 4 additions & 4 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ function eq_derivative!(ts::TearingState{ODESystem}, ieq::Int)
end

function tearing_sub(expr, dict, s)
expr = ModelingToolkit.fixpoint_sub(expr, dict)
expr = Symbolics.fixpoint_sub(expr, dict)
s ? simplify(expr) : expr
end

Expand Down Expand Up @@ -439,7 +439,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
order, lv = var_order(iv)
dx = D(simplify_shifts(lower_varname_withshift(
fullvars[lv], idep, order - 1)))
eq = dx ~ simplify_shifts(ModelingToolkit.fixpoint_sub(
eq = dx ~ simplify_shifts(Symbolics.fixpoint_sub(
Symbolics.solve_for(neweqs[ieq],
fullvars[iv]),
total_sub; operator = ModelingToolkit.Shift))
Expand Down Expand Up @@ -467,7 +467,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
@warn "Tearing: solving $eq for $var is singular!"
else
rhs = -b / a
neweq = var ~ ModelingToolkit.fixpoint_sub(
neweq = var ~ Symbolics.fixpoint_sub(
simplify ?
Symbolics.simplify(rhs) : rhs,
total_sub; operator = ModelingToolkit.Shift)
Expand All @@ -481,7 +481,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
if !(eq.lhs isa Number && eq.lhs == 0)
rhs = eq.rhs - eq.lhs
end
push!(alge_eqs, 0 ~ ModelingToolkit.fixpoint_sub(rhs, total_sub))
push!(alge_eqs, 0 ~ Symbolics.fixpoint_sub(rhs, total_sub))
push!(algeeq_idxs, ieq)
end
end
Expand Down
7 changes: 6 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,12 @@ function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
end

SymbolicIndexingInterface.default_values(sys::AbstractSystem) = get_defaults(sys)
function SymbolicIndexingInterface.default_values(sys::AbstractSystem)
return merge(
Dict(eq.lhs => eq.rhs for eq in observed(sys)),
defaults(sys)
)
end

SymbolicIndexingInterface.is_time_dependent(::AbstractTimeDependentSystem) = true
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeIndependentSystem) = false
Expand Down
10 changes: 0 additions & 10 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,13 +462,3 @@ function observed2graph(eqs, unknowns)

return graph, assigns
end

function fixpoint_sub(x, dict; operator = Nothing)
y = fast_substitute(x, dict; operator)
while !isequal(x, y)
y = x
x = fast_substitute(y, dict; operator)
end

return x
end
1 change: 0 additions & 1 deletion src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, paramm
trueu0map[var] = defs[root]
end
end
@show trueu0map u0map
if has_index_cache(sys) && get_index_cache(sys) !== nothing
u0, defs = get_u0(sys, trueu0map, parammap)
p = MTKParameters(sys, parammap, trueu0map)
Expand Down
6 changes: 4 additions & 2 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,16 @@ end
function check_index_map(idxmap, sym)
if (idx = get(idxmap, sym, nothing)) !== nothing
return idx
elseif hasname(sym) && (idx = get(idxmap, getname(sym), nothing)) !== nothing
elseif !isa(sym, Symbol) && (!istree(sym) || operation(sym) !== getindex) &&
hasname(sym) && (idx = get(idxmap, getname(sym), nothing)) !== nothing
return idx
end
dsym = default_toterm(sym)
isequal(sym, dsym) && return nothing
if (idx = get(idxmap, dsym, nothing)) !== nothing
idx
elseif hasname(dsym) && (idx = get(idxmap, getname(dsym), nothing)) !== nothing
elseif !isa(dsym, Symbol) && (!istree(dsym) || operation(dsym) !== getindex) &&
hasname(dsym) && (idx = get(idxmap, getname(dsym), nothing)) !== nothing
idx
else
nothing
Expand Down
8 changes: 4 additions & 4 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,13 @@ function SymbolicIndexingInterface.parameter_values(p::MTKParameters, pind::Para
@unpack portion, idx = pind
i, j, k... = idx
if portion isa SciMLStructures.Tunable
return p.tunable[i][j][k...]
return isempty(k) ? p.tunable[i][j] : p.tunable[i][j][k...]
elseif portion isa SciMLStructures.Discrete
return p.discrete[i][j][k...]
return isempty(k) ? p.discrete[i][j] : p.discrete[i][j][k...]
elseif portion isa SciMLStructures.Constants
return p.constant[i][j][k...]
return isempty(k) ? p.constant[i][j] : p.constant[i][j][k...]
elseif portion === DEPENDENT_PORTION
return p.dependent[i][j][k...]
return isempty(k) ? p.dependent[i][j] : p.dependent[i][j][k...]
elseif portion === NONNUMERIC_PORTION
return isempty(k) ? p.nonnumeric[i][j] : p.nonnumeric[i][j][k...]
else
Expand Down
80 changes: 0 additions & 80 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -799,86 +799,6 @@ function fold_constants(ex)
end
end

# Symbolics needs to call unwrap on the substitution rules, but most of the time
# we don't want to do that in MTK.
const Eq = Union{Equation, Inequality}
function fast_substitute(eq::Eq, subs; operator = Nothing)
if eq isa Inequality
Inequality(fast_substitute(eq.lhs, subs; operator),
fast_substitute(eq.rhs, subs; operator),
eq.relational_op)
else
Equation(fast_substitute(eq.lhs, subs; operator),
fast_substitute(eq.rhs, subs; operator))
end
end
function fast_substitute(eq::T, subs::Pair; operator = Nothing) where {T <: Eq}
T(fast_substitute(eq.lhs, subs; operator), fast_substitute(eq.rhs, subs; operator))
end
function fast_substitute(eqs::AbstractArray, subs; operator = Nothing)
fast_substitute.(eqs, (subs,); operator)
end
function fast_substitute(eqs::AbstractArray, subs::Pair; operator = Nothing)
fast_substitute.(eqs, (subs,); operator)
end
for (exprType, subsType) in Iterators.product((Num, Symbolics.Arr), (Any, Pair))
@eval function fast_substitute(expr::$exprType, subs::$subsType; operator = Nothing)
fast_substitute(value(expr), subs; operator)
end
end
function fast_substitute(expr, subs; operator = Nothing)
if (_val = get(subs, expr, nothing)) !== nothing
return _val
end
istree(expr) || return expr
op = fast_substitute(operation(expr), subs; operator)
args = SymbolicUtils.unsorted_arguments(expr)
if !(op isa operator)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
map(args) do x
x′ = fast_substitute(x, subs; operator)
canfold[] = canfold[] && !(x′ isa Symbolic)
x′
end
end
canfold[] && return op(args...)
end
similarterm(expr,
op,
args,
symtype(expr);
metadata = metadata(expr))
end
function fast_substitute(expr, pair::Pair; operator = Nothing)
a, b = pair
isequal(expr, a) && return b
if a isa AbstractArray
for (ai, bi) in zip(a, b)
expr = fast_substitute(expr, ai => bi; operator)
end
end
istree(expr) || return expr
op = fast_substitute(operation(expr), pair; operator)
args = SymbolicUtils.unsorted_arguments(expr)
if !(op isa operator)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
map(args) do x
x′ = fast_substitute(x, pair; operator)
canfold[] = canfold[] && !(x′ isa Symbolic)
x′
end
end
canfold[] && return op(args...)
end
similarterm(expr,
op,
args,
symtype(expr);
metadata = metadata(expr))
end

normalize_to_differential(s) = s

function restrict_array_to_union(arr)
Expand Down
72 changes: 72 additions & 0 deletions test/mtkparameters.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
using SymbolicIndexingInterface
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants

@parameters a b c d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
@named sys = ODESystem(
Equation[], t, [], [a, c, d, e, f, g, h], parameter_dependencies = [b => 2a],
continuous_events = [[a ~ 0] => [c ~ 0]], defaults = Dict(a => 0.0))
sys = complete(sys)

ivs = Dict(c => 3a, d => 4, e => [5.0, 6.0, 7.0],
f => ones(Int, 3, 3), g => [0.1, 0.2, 0.3], h => "foo")

ps = MTKParameters(sys, ivs)
@test_nowarn copy(ps)
# dependent initialization, also using defaults
@test getp(sys, a)(ps) == getp(sys, b)(ps) == getp(sys, c)(ps) == 0.0
@test getp(sys, d)(ps) isa Int

ivs[a] = 1.0
ps = MTKParameters(sys, ivs)
@test_broken getp(sys, g) # SII bug
for (p, val) in ivs
isequal(p, g) && continue # broken
if isequal(p, c)
val = 3ivs[a]
end
idx = parameter_index(sys, p)
# ensure getindex with `ParameterIndex` works
@test ps[idx] == getp(sys, p)(ps) == val
end

# ensure setindex! with `ParameterIndex` works
ps[parameter_index(sys, a)] = 3.0
@test getp(sys, a)(ps) == 3.0
setp(sys, a)(ps, 1.0)

@test getp(sys, a)(ps) == getp(sys, b)(ps) / 2 == getp(sys, c)(ps) / 3 == 1.0

for (portion, values) in [(Tunable(), vcat(ones(9), [1.0, 4.0, 5.0, 6.0, 7.0]))
(Discrete(), [3.0])
(Constants(), [0.1, 0.2, 0.3])]
buffer, repack, alias = canonicalize(portion, ps)
@test alias
@test sort(collect(buffer)) == values
@test all(isone,
canonicalize(portion, SciMLStructures.replace(portion, ps, ones(length(buffer))))[1])
# make sure it is out-of-place
@test sort(collect(buffer)) == values
SciMLStructures.replace!(portion, ps, ones(length(buffer)))
# make sure it is in-place
@test all(isone, canonicalize(portion, ps)[1])
repack(zeros(length(buffer)))
@test all(iszero, canonicalize(portion, ps)[1])
end

setp(sys, a)(ps, 2.0) # test set_parameter!
@test getp(sys, a)(ps) == 2.0

setp(sys, e)(ps, 5ones(3)) # with an array
@test getp(sys, e)(ps) == 5ones(3)

setp(sys, f[2, 2])(ps, 42) # with a sub-index
@test getp(sys, f[2, 2])(ps) == 42

# SII bug
@test_broken setp(sys, g)(ps, ones(100)) # with non-fixed-length array
@test_broken getp(sys, g)(ps) == ones(100)

setp(sys, h)(ps, "bar") # with a non-numeric
@test getp(sys, h)(ps) == "bar"
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ end
@safetestset "Parsing Test" include("variable_parsing.jl")
@safetestset "Simplify Test" include("simplify.jl")
@safetestset "Direct Usage Test" include("direct.jl")
@safetestset "SymbolicIndeingInterface test" include("symbolic_indexing_interface.jl")
@safetestset "System Linearity Test" include("linearity.jl")
@safetestset "Input Output Test" include("input_output_handling.jl")
@safetestset "Clock Test" include("clock.jl")
Expand Down Expand Up @@ -72,6 +71,11 @@ end
end
end

if GROUP == "All" || GROUP == "InterfaceI" || GROUP == "SymbolicIndexingInterface"
@safetestset "SymbolicIndexingInterface test" include("symbolic_indexing_interface.jl")
@safetestset "MTKParameters Test" include("mtkparameters.jl")
end

if GROUP == "All" || GROUP == "InterfaceII"
println("C compilation test requires gcc available in the path!")
@safetestset "C Compilation Test" include("ccompile.jl")
Expand Down
21 changes: 13 additions & 8 deletions test/symbolic_indexing_interface.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
using ModelingToolkit, SymbolicIndexingInterface, SciMLBase
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters t a b
@variables x(t)=1.0 y(t)=2.0
D = Differential(t)
@parameters a b
@variables x(t)=1.0 y(t)=2.0 xy(t)
eqs = [D(x) ~ a * y + t, D(y) ~ b * t]
@named odesys = ODESystem(eqs, t, [x, y], [a, b])
@named odesys = ODESystem(eqs, t, [x, y], [a, b]; observed = [xy ~ x + y])

@test all(is_variable.((odesys,), [x, y, 1, 2, :x, :y]))
@test all(.!is_variable.((odesys,), [a, b, t, 3, 0, :a, :b]))
Expand All @@ -24,6 +24,11 @@ eqs = [D(x) ~ a * y + t, D(y) ~ b * t]
@test !isempty(default_values(odesys))
@test default_values(odesys)[x] == 1.0
@test default_values(odesys)[y] == 2.0
@test isequal(default_values(odesys)[xy], x + y)

@named odesys = ODESystem(
eqs, t, [x, y], [a, b]; defaults = [xy => 3.0], observed = [xy ~ x + y])
@test default_values(odesys)[xy] == 3.0

@variables x y z
@parameters σ ρ β
Expand All @@ -36,10 +41,10 @@ eqs = [0 ~ σ * (y - x),
@test !is_time_dependent(ns)

@parameters x
@variables t u(..)
@variables u(..)
Dxx = Differential(x)^2
Dtt = Differential(t)^2
Dt = Differential(t)
Dt = D

#2D PDE
C = 1
Expand All @@ -60,10 +65,10 @@ domains = [t ∈ (0.0, 1.0),
@test pde_system.ps == SciMLBase.NullParameters()
@test parameter_symbols(pde_system) == []

@parameters t x
@parameters x
@constants h = 1
@variables u(..)
Dt = Differential(t)
Dt = D
Dxx = Differential(x)^2
eq = Dt(u(t, x)) ~ h * Dxx(u(t, x))
bcs = [u(0, x) ~ -h * x * (x - 1) * sin(x),
Expand Down

0 comments on commit 85e1863

Please sign in to comment.