Skip to content

Commit

Permalink
Merge pull request #2629 from SebastianM-C/extend
Browse files Browse the repository at this point in the history
fix: propagate parameter dependencies in `extend`
  • Loading branch information
ChrisRackauckas authored Apr 12, 2024
2 parents d54f877 + 9c9ecd9 commit 3847ec6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
20 changes: 13 additions & 7 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2225,6 +2225,10 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol = nam
eqs = union(get_eqs(basesys), get_eqs(sys))
sts = union(get_unknowns(basesys), get_unknowns(sys))
ps = union(get_ps(basesys), get_ps(sys))
base_deps = get_parameter_dependencies(basesys)
deps = get_parameter_dependencies(sys)
dep_ps = isnothing(base_deps) ? deps :
isnothing(deps) ? base_deps : union(base_deps, deps)
obs = union(get_observed(basesys), get_observed(sys))
cevs = union(get_continuous_events(basesys), get_continuous_events(sys))
devs = union(get_discrete_events(basesys), get_discrete_events(sys))
Expand All @@ -2233,11 +2237,12 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol = nam

if length(ivs) == 0
T(eqs, sts, ps, observed = obs, defaults = defs, name = name, systems = syss,
continuous_events = cevs, discrete_events = devs, gui_metadata = gui_metadata)
continuous_events = cevs, discrete_events = devs, gui_metadata = gui_metadata,
parameter_dependencies = dep_ps)
elseif length(ivs) == 1
T(eqs, ivs[1], sts, ps, observed = obs, defaults = defs, name = name,
systems = syss, continuous_events = cevs, discrete_events = devs,
gui_metadata = gui_metadata)
gui_metadata = gui_metadata, parameter_dependencies = dep_ps)
end
end

Expand Down Expand Up @@ -2395,8 +2400,8 @@ end
"""
is_diff_equation(eq)
Returns `true` if the input is a differential equation, i.e. is an equatation that contain some
form of differential.
Return `true` if the input is a differential equation, i.e. an equation that contains a
differential term.
Example:
```julia
Expand All @@ -2421,7 +2426,7 @@ end
"""
is_alg_equation(eq)
Returns `true` if the input is an algebraic equation, i.e. is an equatation that does not contain
Return `true` if the input is an algebraic equation, i.e. an equation that does not contain
any differentials.
Example:
Expand Down Expand Up @@ -2603,8 +2608,9 @@ has_alg_eqs(sys::AbstractSystem) = any(is_alg_equation, get_eqs(sys))
"""
has_diff_eqs(sys::AbstractSystem)
For a system, returns true if it contain at least one differential equation (i.e. that contain a
differential) in its *top-level system*.
Return `true` if a system contains at least one differential equation (i.e. an equation with a
differential term). Note that this does not consider subsystems, and only takes into account
equations in the top-level system.
Example:
```julia
Expand Down
18 changes: 18 additions & 0 deletions test/parameter_dependencies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,24 @@ using NonlinearSolve
@test integ.ps[p2] == 10.0
end

@testset "extend" begin
@parameters p1=1.0 p2=1.0
@variables x(t)

@mtkbuild sys1 = ODESystem(
[D(x) ~ p1 * t + p2],
t
)
@named sys2 = ODESystem(
[],
t;
parameter_dependencies = [p2 => 2p1]
)
sys = extend(sys2, sys1)
@test isequal(only(parameters(sys)), p1)
@test Set(full_parameters(sys)) == Set([p1, p2])
end

@testset "Clock system" begin
dt = 0.1
@variables x(t) y(t) u(t) yd(t) ud(t) r(t) z(t)
Expand Down

0 comments on commit 3847ec6

Please sign in to comment.