Skip to content

Commit

Permalink
Merge pull request #2927 from AayushSabharwal/as/fix-indexing-ci
Browse files Browse the repository at this point in the history
fix: fix downstream indexing tests
  • Loading branch information
ChrisRackauckas authored Aug 6, 2024
2 parents 50a4b12 + 750e82f commit 29040fc
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 303 deletions.
5 changes: 4 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,11 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
if idx.portion isa SciMLStructures.Discrete &&
idx.idx[2] == idx.idx[3] == nothing
return nothing
elseif idx.portion isa SciMLStructures.Tunable
return ParameterIndex(
idx.portion, idx.idx[arguments(sym)[(begin + 1):end]...])
else
ParameterIndex(
return ParameterIndex(
idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...))
end
else
Expand Down
4 changes: 2 additions & 2 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ namespace_affects(::Nothing, s) = nothing
function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
SymbolicContinuousCallback(
namespace_equation.(equations(cb), (s,)),
namespace_affects(affects(cb), s),
namespace_affects(affect_negs(cb), s))
namespace_affects(affects(cb), s);
affect_neg = namespace_affects(affect_negs(cb), s))
end

"""
Expand Down
138 changes: 0 additions & 138 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,141 +195,3 @@ function split_system(ci::ClockInference{S}) where {S}
end
return tss, inputs, continuous_id, id_to_clock
end

function generate_discrete_affect(
osys::AbstractODESystem, syss, inputs, continuous_id, id_to_clock;
checkbounds = true,
eval_module = @__MODULE__, eval_expression = false)
@static if VERSION < v"1.7"
error("The `generate_discrete_affect` function requires at least Julia 1.7")
end
has_index_cache(osys) && get_index_cache(osys) !== nothing ||
error("Hybrid systems require `split = true`")
out = Sym{Any}(:out)
appended_parameters = full_parameters(syss[continuous_id])
offset = length(appended_parameters)
param_to_idx = Dict{Any, ParameterIndex}(p => parameter_index(osys, p)
for p in appended_parameters)
affect_funs = []
clocks = TimeDomain[]
for (i, (sys, input)) in enumerate(zip(syss, inputs))
i == continuous_id && continue
push!(clocks, id_to_clock[i])
subs = get_substitutions(sys)
assignments = map(s -> Assignment(s.lhs, s.rhs), subs.subs)
let_body = SetArray(!checkbounds, out, rhss(equations(sys)))
let_block = Let(assignments, let_body, false)
needed_cont_to_disc_obs = map(v -> arguments(v)[1], input)
# TODO: filter the needed ones
fullvars = Set{Any}(eq.lhs for eq in observed(sys))
for s in unknowns(sys)
push!(fullvars, s)
end
needed_disc_to_cont_obs = []
disc_to_cont_idxs = ParameterIndex[]
for v in inputs[continuous_id]
_v = arguments(v)[1]
if _v in fullvars
push!(needed_disc_to_cont_obs, _v)
push!(disc_to_cont_idxs, param_to_idx[v])
continue
end

# If the held quantity is calculated through observed
# it will be shifted forward by 1
_v = Shift(get_iv(sys), 1)(_v)
if _v in fullvars
push!(needed_disc_to_cont_obs, _v)
push!(disc_to_cont_idxs, param_to_idx[v])
continue
end
end
append!(appended_parameters, input)
cont_to_disc_obs = build_explicit_observed_function(
osys,
needed_cont_to_disc_obs,
throw = false,
expression = true,
output_type = SVector)
disc_to_cont_obs = build_explicit_observed_function(sys, needed_disc_to_cont_obs,
throw = false,
expression = true,
output_type = SVector,
op = Shift,
ps = reorder_parameters(osys, appended_parameters))
ni = length(input)
ns = length(unknowns(sys))
disc = Func(
[
out,
DestructuredArgs(unknowns(osys)),
DestructuredArgs.(reorder_parameters(osys, full_parameters(osys)))...,
get_iv(sys)
],
[],
let_block) |> toexpr
cont_to_disc_idxs = [parameter_index(osys, sym) for sym in input]
disc_range = [parameter_index(osys, sym) for sym in unknowns(sys)]
save_expr = :($(SciMLBase.save_discretes!)(integrator, $i))
empty_disc = isempty(disc_range)

# @show disc_to_cont_idxs
# @show cont_to_disc_idxs
# @show disc_range
affect! = :(function (integrator)
@unpack u, p, t = integrator
c2d_obs = $cont_to_disc_obs
d2c_obs = $disc_to_cont_obs
# TODO: find a way to do this without allocating
disc_unknowns = [$(parameter_values)(p, i) for i in $disc_range]
disc = $disc

# Write continuous into to discrete: handles `Sample`
# Write discrete into to continuous
# Update discrete unknowns

# At a tick, c2d must come first
# state update comes in the middle
# d2c comes last
# @show t
# @show "incoming", p
result = c2d_obs(u, p..., t)
for (val, i) in zip(result, $cont_to_disc_idxs)
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
end
$(if !empty_disc
quote
disc(disc_unknowns, u, p..., t)
for (val, i) in zip(disc_unknowns, $disc_range)
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
end
end
end)
# @show "after c2d", p
# @show "after state update", p
result = d2c_obs(disc_unknowns, p..., t)
for (val, i) in zip(result, $disc_to_cont_idxs)
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
end

$save_expr

# @show "after d2c", p
discretes, repack, _ = $(SciMLStructures.canonicalize)(
$(SciMLStructures.Discrete()), p)
repack(discretes)
end)

push!(affect_funs, affect!)
end
if eval_expression
affects = map(a -> eval_module.eval(toexpr(LiteralExpr(a))), affect_funs)
else
affects = map(affect_funs) do a
drop_expr(RuntimeGeneratedFunction(
eval_module, eval_module, toexpr(LiteralExpr(a))))
end
end
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
return affects, clocks, appended_parameters, defaults
end
115 changes: 3 additions & 112 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -782,12 +782,6 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
varlist = collect(map(unwrap, dvs))
missingvars = setdiff(varlist, collect(keys(varmap)))

# Append zeros to the variables which are determined by the initialization system
# This essentially bypasses the check for if initial conditions are defined for DAEs
# since they will be checked in the initialization problem's construction
# TODO: make check for if a DAE cheaper than calculating the mass matrix a second time!
ci = infer_clocks!(ClockInference(TearingState(sys)))

if eltype(parammap) <: Pair
parammap = Dict(unwrap(k) => v for (k, v) in todict(parammap))
elseif parammap isa AbstractArray
Expand All @@ -798,38 +792,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
end
end

if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing
clockedparammap = Dict()
defs = ModelingToolkit.get_defaults(sys)
for v in ps
v = unwrap(v)
is_discrete_domain(v) || continue
op = operation(v)
if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() &&
haskey(parammap, v)
error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).")
end
shiftedv = StructuralTransformations.simplify_shifts(Shift(iv, -1)(v))
if parammap != SciMLBase.NullParameters() &&
(val = get(parammap, shiftedv, nothing)) !== nothing
clockedparammap[v] = val
elseif op isa Shift
root = arguments(v)[1]
haskey(defs, root) || error("Initial condition for $v not provided.")
clockedparammap[v] = defs[root]
end
end
parammap = if parammap == SciMLBase.NullParameters()
clockedparammap
else
merge(parammap, clockedparammap)
end
end
# TODO: make it work with clocks
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
if sys isa ODESystem && build_initializeprob &&
(((implicit_dae || !isempty(missingvars)) &&
all(==(Continuous), ci.var_domain) &&
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
!isempty(initialization_equations(sys))) && t !== nothing
if eltype(u0map) <: Number
Expand Down Expand Up @@ -1010,29 +975,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
t = tspan !== nothing ? tspan[1] : tspan,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
inits = []
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks = ModelingToolkit.generate_discrete_affect(
sys, dss...; eval_expression, eval_module)
discrete_cbs = map(affects, clocks) do affect, clock
@match clock begin
PeriodicClock(dt, _...) => PeriodicCallback(affect, dt;
final_affect = true, initial_affect = true)
&SolverStepClock => DiscreteCallback(Returns(true), affect,
initialize = (c, u, t, integrator) -> affect(integrator))
_ => error("$clock is not a supported clock type.")
end
end
if cbs === nothing
if length(discrete_cbs) == 1
cbs = only(discrete_cbs)
else
cbs = CallbackSet(discrete_cbs...)
end
else
cbs = CallbackSet(cbs, discrete_cbs...)
end
end

kwargs = filter_kwargs(kwargs)
pt = something(get_metadata(sys), StandardODEProblem())

Expand Down Expand Up @@ -1112,40 +1055,14 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
h(p, t) = h_oop(p, t)
h(p::MTKParameters, t) = h_oop(p..., t)
u0 = h(p, tspan[1])

cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks = ModelingToolkit.generate_discrete_affect(
sys, dss...; eval_expression, eval_module)
discrete_cbs = map(affects, clocks) do affect, clock
@match clock begin
PeriodicClock(dt, _...) => PeriodicCallback(affect, dt;
final_affect = true, initial_affect = true)
&SolverStepClock => DiscreteCallback(Returns(true), affect,
initialize = (c, u, t, integrator) -> affect(integrator))
_ => error("$clock is not a supported clock type.")
end
end
if cbs === nothing
if length(discrete_cbs) == 1
cbs = only(discrete_cbs)
else
cbs = CallbackSet(discrete_cbs...)
end
else
cbs = CallbackSet(cbs, discrete_cbs)
end
else
svs = nothing
end
kwargs = filter_kwargs(kwargs)

kwargs1 = (;)
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end
if svs !== nothing
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
end
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
end

Expand Down Expand Up @@ -1175,40 +1092,14 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
h(p::MTKParameters, t) = h_oop(p..., t)
h(out, p::MTKParameters, t) = h_iip(out, p..., t)
u0 = h(p, tspan[1])

cbs = process_events(sys; callback, eval_expression, eval_module, kwargs...)
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks = ModelingToolkit.generate_discrete_affect(
sys, dss...; eval_expression, eval_module)
discrete_cbs = map(affects, clocks) do affect, clock
@match clock begin
PeriodicClock(dt, _...) => PeriodicCallback(affect, dt;
final_affect = true, initial_affect = true)
&SolverStepClock => DiscreteCallback(Returns(true), affect,
initialize = (c, u, t, integrator) -> affect(integrator))
_ => error("$clock is not a supported clock type.")
end
end
if cbs === nothing
if length(discrete_cbs) == 1
cbs = only(discrete_cbs)
else
cbs = CallbackSet(discrete_cbs...)
end
else
cbs = CallbackSet(cbs, discrete_cbs)
end
else
svs = nothing
end
kwargs = filter_kwargs(kwargs)

kwargs1 = (;)
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end
if svs !== nothing
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
end

noiseeqs = get_noiseeqs(sys)
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
Expand Down
9 changes: 0 additions & 9 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,6 @@ function build_explicit_observed_function(sys, ts;
dep_vars = scalarize(setdiff(vars, ivs))

obs = param_only ? Equation[] : observed(sys)
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
# each subsystem is topologically sorted independently. We can append the
# equations to override the `lhs ~ 0` equations in `observed(sys)`
syss, _, continuous_id, _... = dss
for (i, subsys) in enumerate(syss)
i == continuous_id && continue
append!(obs, observed(subsys))
end
end

cs = collect_constants(obs)
if !isempty(cs) > 0
Expand Down
Loading

0 comments on commit 29040fc

Please sign in to comment.