Skip to content

Commit

Permalink
Merge pull request #346 from SciML/dds_mtk
Browse files Browse the repository at this point in the history
Updating MTK interface
  • Loading branch information
pogudingleb authored Aug 19, 2024
2 parents d6d9043 + 1115554 commit af0abd1
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 78 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/SpellCheck.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ jobs:
uses: actions/checkout@v4
- name: Check spelling
uses: crate-ci/[email protected]
with:
files: ./src ./docs
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ PrecompileTools = "1.2"
Primes = "0.5"
Random = "1.6, 1.7"
SpecialFunctions = "2"
SymbolicUtils = "2"
Symbolics = "5.30.1"
SymbolicUtils = "2, 3"
Symbolics = "5.30.1, 6"
Test = "1.6, 1.7"
TestSetExtensions = "2"
TimerOutputs = "0.5"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/discrete_time.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ eqs = [
R(k) ~ R(k - 1) + α * I(k - 1),
]
@mtkbuild sys = DiscreteSystem(eqs, t)
@named sys = DiscreteSystem(eqs, t)
assess_local_identifiability(sys, measured_quantities = [I])
```
Expand Down
111 changes: 47 additions & 64 deletions ext/ModelingToolkitSIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ function StructuralIdentifiability.eval_at_nemo(e::SymbolicUtils.BasicSymbolic,
return args[1]^args[2]
end
return 1 // args[1]^(-args[2])
# dirty way, assumes that all shifts should be just removed
elseif startswith(String(Symbol(Symbolics.operation(e))), "Shift")
return args[1]
end
throw(Base.ArgumentError("Function $(Symbolics.operation(e)) is not supported"))
elseif e isa Symbolics.Symbolic
Expand All @@ -71,12 +74,11 @@ function StructuralIdentifiability.eval_at_nemo(
end

function get_measured_quantities(ode::ModelingToolkit.ODESystem)
if any(ModelingToolkit.isoutput(eq.lhs) for eq in ModelingToolkit.equations(ode))
@info "Measured quantities are not provided, trying to find the outputs in input ODE."
return filter(
eq -> (ModelingToolkit.isoutput(eq.lhs)),
ModelingToolkit.equations(ode),
)
outputs = filter(eq -> ModelingToolkit.isoutput(eq.lhs), ModelingToolkit.equations(ode))
if !isempty(outputs)
return outputs
elseif !isempty(ModelingToolkit.observed(ode))
return ModelingToolkit.observed(ode)
else
throw(
error(
Expand All @@ -103,6 +105,9 @@ function StructuralIdentifiability.mtk_to_si(
de::ModelingToolkit.AbstractTimeDependentSystem,
measured_quantities::Array{ModelingToolkit.Equation},
)
if isempty(measured_quantities)
measured_quantities = get_measured_quantities(de)
end
return __mtk_to_si(
de,
[(replace(string(e.lhs), "(t)" => ""), e.rhs) for e in measured_quantities],
Expand Down Expand Up @@ -153,6 +158,20 @@ function preprocess_ode(
return mtk_to_si(de, measured_quantities)
end

#------------------------------------------------------------------------------
function clean_calls(funcs)
res = []
for f in funcs
if length(Symbolics.arguments(f)) == 1 &&
!Symbolics.iscall(first(Symbolics.arguments(f)))
push!(res, f)
else
push!(res, first(Symbolics.arguments(f)))
end
end
return res
end

#------------------------------------------------------------------------------
"""
function __mtk_to_si(de::ModelingToolkit.AbstractTimeDependentSystem, measured_quantities::Array{Tuple{String, SymbolicUtils.BasicSymbolic}})
Expand Down Expand Up @@ -186,11 +205,10 @@ function __mtk_to_si(
end

y_functions = [each[2] for each in measured_quantities]
inputs = filter(v -> ModelingToolkit.isinput(v), ModelingToolkit.unknowns(de))
state_vars = filter(
s -> !(ModelingToolkit.isinput(s) || ModelingToolkit.isoutput(s)),
ModelingToolkit.unknowns(de),
)
state_vars =
filter(s -> !ModelingToolkit.isoutput(s), clean_calls(map(e -> e.lhs, diff_eqs)))
all_funcs = collect(Set(clean_calls(ModelingToolkit.unknowns(de))))
inputs = filter(s -> !ModelingToolkit.isoutput(s), setdiff(all_funcs, state_vars))
params = ModelingToolkit.parameters(de)
t = ModelingToolkit.arguments(diff_eqs[1].lhs)[1]
params_from_measured_quantities = union(
Expand Down Expand Up @@ -240,7 +258,7 @@ function __mtk_to_si(
end
# -----------------------------------------------------------------------------
"""
function assess_local_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=Array{}[], prob_threshold::Float64=0.99, type=:SE, loglevel=Logging.Info)
function assess_local_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=ModelingToolkit.Equation[], funcs_to_check=Array{}[], prob_threshold::Float64=0.99, type=:SE, loglevel=Logging.Info)
Input:
- `ode` - the ODESystem object from ModelingToolkit
Expand All @@ -263,7 +281,7 @@ The return value is a tuple consisting of the array of bools and the number of e
"""
function StructuralIdentifiability.assess_local_identifiability(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
measured_quantities = ModelingToolkit.Equation[],
funcs_to_check = Array{}[],
prob_threshold::Float64 = 0.99,
type = :SE,
Expand All @@ -288,28 +306,13 @@ end
prob_threshold::Float64 = 0.99,
type = :SE,
)
if length(measured_quantities) == 0
if any(ModelingToolkit.isoutput(eq.lhs) for eq in ModelingToolkit.equations(ode))
@info "Measured quantities are not provided, trying to find the outputs in input ODE."
measured_quantities = filter(
eq -> (ModelingToolkit.isoutput(eq.lhs)),
ModelingToolkit.equations(ode),
)
else
throw(
error(
"Measured quantities (output functions) were not provided and no outputs were found.",
),
)
end
end
if length(funcs_to_check) == 0
funcs_to_check = vcat(
[e for e in ModelingToolkit.unknowns(ode) if !ModelingToolkit.isoutput(e)],
ModelingToolkit.parameters(ode),
)
end
ode, conversion = mtk_to_si(ode, measured_quantities)
@info "System parsed into $ode"
conversion_back = Dict(v => k for (k, v) in conversion)
if isempty(funcs_to_check)
funcs_to_check = [conversion_back[x] for x in [ode.x_vars..., ode.parameters...]]
end

funcs_to_check_ = [eval_at_nemo(x, conversion) for x in funcs_to_check]

if isequal(type, :SE)
Expand Down Expand Up @@ -340,7 +343,7 @@ end
# ------------------------------------------------------------------------------

"""
assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=[], known_ic=[], prob_threshold = 0.99, loglevel=Logging.Info)
assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=ModelingToolkit.Equation[], funcs_to_check=[], known_ic=[], prob_threshold = 0.99, loglevel=Logging.Info)
Input:
- `ode` - the ModelingToolkit.ODESystem object that defines the model
Expand All @@ -356,7 +359,7 @@ If known initial conditions are provided, the identifiability results for the st
"""
function StructuralIdentifiability.assess_identifiability(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
measured_quantities = ModelingToolkit.Equation[],
funcs_to_check = [],
known_ic = [],
prob_threshold = 0.99,
Expand All @@ -376,16 +379,13 @@ end

function _assess_identifiability(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
measured_quantities = ModelingToolkit.Equation[],
funcs_to_check = [],
known_ic = [],
prob_threshold = 0.99,
)
if isempty(measured_quantities)
measured_quantities = get_measured_quantities(ode)
end

ode, conversion = mtk_to_si(ode, measured_quantities)
@info "System parsed into $ode"
conversion_back = Dict(v => k for (k, v) in conversion)
if isempty(funcs_to_check)
funcs_to_check = [conversion_back[x] for x in [ode.x_vars..., ode.parameters...]]
Expand Down Expand Up @@ -470,43 +470,29 @@ function _assess_local_identifiability(
known_ic = Array{}[],
prob_threshold::Float64 = 0.99,
)
if length(measured_quantities) == 0
if any(ModelingToolkit.isoutput(eq.lhs) for eq in ModelingToolkit.equations(dds))
@info "Measured quantities are not provided, trying to find the outputs in input dynamical system."
measured_quantities = filter(
eq -> (ModelingToolkit.isoutput(eq.lhs)),
ModelingToolkit.equations(dds),
)
else
throw(
error(
"Measured quantities (output functions) were not provided and no outputs were found.",
),
)
end
end

# Converting the finite difference operator in the right-hand side to
# the corresponding shift operator
eqs = filter(eq -> !(ModelingToolkit.isoutput(eq.lhs)), ModelingToolkit.equations(dds))

dds_aux_ode, conversion = mtk_to_si(dds, measured_quantities)
dds_aux = StructuralIdentifiability.DDS{QQMPolyRingElem}(dds_aux_ode)
@info "Parsed into the following model: $dds_aux"
if length(funcs_to_check) == 0
params = parameters(dds)
params_from_measured_quantities = union(
[filter(s -> !iscall(s), get_variables(y)) for y in measured_quantities]...,
)
funcs_to_check = vcat(
[
x for x in unknowns(dds) if
x for x in clean_calls(unknowns(dds)) if
conversion[x] in StructuralIdentifiability.x_vars(dds_aux)
],
union(params, params_from_measured_quantities),
)
end
funcs_to_check_ = [eval_at_nemo(x, conversion) for x in funcs_to_check]
known_ic_ = [eval_at_nemo(x, conversion) for x in known_ic]
@info "Functions to check are $(["$f" for f in funcs_to_check_]) and initial conditions are known for $(["$f" for f in known_ic_])"

result = StructuralIdentifiability._assess_local_identifiability_discrete_aux(
dds_aux,
Expand Down Expand Up @@ -568,7 +554,7 @@ find_identifiable_functions(de, measured_quantities = [y1 ~ x0])
"""
function StructuralIdentifiability.find_identifiable_functions(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
measured_quantities = ModelingToolkit.Equation[],
known_ic = [],
prob_threshold::Float64 = 0.99,
seed = 42,
Expand All @@ -595,18 +581,15 @@ end

function _find_identifiable_functions(
ode::ModelingToolkit.ODESystem;
measured_quantities = Array{ModelingToolkit.Equation}[],
known_ic = Array{Symbolics.Num}[],
measured_quantities = ModelingToolkit.Equation[],
known_ic = Symbolics.Num[],
prob_threshold::Float64 = 0.99,
seed = 42,
with_states = false,
simplify = :standard,
rational_interpolator = :VanDerHoevenLecerf,
)
Random.seed!(seed)
if isempty(measured_quantities)
measured_quantities = get_measured_quantities(ode)
end
ode, conversion = mtk_to_si(ode, measured_quantities)
known_ic_ = [eval_at_nemo(each, conversion) for each in known_ic]
result = nothing
Expand Down
Loading

0 comments on commit af0abd1

Please sign in to comment.