Skip to content

Commit

Permalink
feat: use observed equations for guesses of observed variables
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 22, 2024
1 parent b5c5b34 commit cd2518e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,16 @@ function generate_initializesystem(sys::ODESystem;
[p for p in parameters(sys) if !haskey(paramsubs, p)]
)

eqs_ics = Symbolics.substitute.([eqs_ics; observed(sys)], (paramsubs,))
# 7) use observed equations for guesses of observed variables if not provided
obseqs = observed(sys)
for eq in obseqs
haskey(defs, eq.lhs) && continue
any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue

defs[eq.lhs] = eq.rhs
end

eqs_ics = Symbolics.substitute.([eqs_ics; obseqs], (paramsubs,))
vars = [vars; collect(values(paramsubs))]
for k in keys(defs)
defs[k] = substitute(defs[k], paramsubs)
Expand Down
8 changes: 8 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -836,3 +836,11 @@ end
integ = init(prob, Rosenbrock23())
@test integ[y] -0.5
end

@testset "Use observed equations for guesses of observed variables" begin
@variables x(t) y(t) [state_priority = 100]
@mtkbuild sys = ODESystem(
[D(x) ~ x + t, y ~ 2x + 1], t; initialization_eqs = [x^3 + y^3 ~ 1])
isys = ModelingToolkit.generate_initializesystem(sys)
@test isequal(defaults(isys)[y], 2x + 1)
end

0 comments on commit cd2518e

Please sign in to comment.