Skip to content

Commit

Permalink
Merge pull request #2834 from AayushSabharwal/as/sde-sarray
Browse files Browse the repository at this point in the history
fix: infer oop form for SDEProblem/SDEFunction with StaticArrays
  • Loading branch information
ChrisRackauckas authored Jul 3, 2024
2 parents f3b040d + 54df3cc commit d64f973
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 6 deletions.
42 changes: 36 additions & 6 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys),
if isdde
eqs = delay_to_function(sys, eqs)
end
if eqs isa AbstractMatrix && isdiag(eqs)
eqs = diag(eqs)
end
u = map(x -> time_varying_as_func(value(x), sys), dvs)
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
reorder_parameters(get_index_cache(sys), ps)
Expand Down Expand Up @@ -403,14 +406,14 @@ function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0)
checks = false)
end

function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),
function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(sys),
ps = parameters(sys),
u0 = nothing;
version = nothing, tgrad = false, sparse = false,
jac = false, Wfact = false, eval_expression = false,
eval_module = @__MODULE__,
checkbounds = false,
kwargs...) where {iip}
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`")
end
Expand Down Expand Up @@ -480,7 +483,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys),

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)

SDEFunction{iip}(f, g,
SDEFunction{iip, specialize}(f, g,
sys = sys,
jac = _jac === nothing ? nothing : _jac,
tgrad = _tgrad === nothing ? nothing : _tgrad,
Expand All @@ -505,6 +508,16 @@ function DiffEqBase.SDEFunction(sys::SDESystem, args...; kwargs...)
SDEFunction{true}(sys, args...; kwargs...)
end

function DiffEqBase.SDEFunction{true}(sys::SDESystem, args...;
kwargs...)
SDEFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function DiffEqBase.SDEFunction{false}(sys::SDESystem, args...;
kwargs...)
SDEFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

"""
```julia
DiffEqBase.SDEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
Expand Down Expand Up @@ -583,14 +596,16 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...)
SDEFunctionExpr{true}(sys, args...; kwargs...)
end

function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map = [], tspan = get_tspan(sys),
function DiffEqBase.SDEProblem{iip, specialize}(
sys::SDESystem, u0map = [], tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
sparsenoise = nothing, check_length = true,
callback = nothing, kwargs...) where {iip}
callback = nothing, kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblem`")
end
f, u0, p = process_DEProblem(SDEFunction{iip}, sys, u0map, parammap; check_length,
f, u0, p = process_DEProblem(
SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
kwargs...)
cbs = process_events(sys; callback, kwargs...)
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))
Expand Down Expand Up @@ -628,6 +643,21 @@ function DiffEqBase.SDEProblem(sys::SDESystem, args...; kwargs...)
SDEProblem{true}(sys, args...; kwargs...)
end

function DiffEqBase.SDEProblem(sys::SDESystem,
u0map::StaticArray,
args...;
kwargs...)
SDEProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
end

function DiffEqBase.SDEProblem{true}(sys::SDESystem, args...; kwargs...)
SDEProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function DiffEqBase.SDEProblem{false}(sys::SDESystem, args...; kwargs...)
SDEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

"""
```julia
DiffEqBase.SDEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,
Expand Down
28 changes: 28 additions & 0 deletions test/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -614,3 +614,31 @@ sys2 = complete(sys2)
prob = SDEProblem(sys1, sts .=> [1.0, 0.0, 0.0],
(0.0, 100.0), ps .=> (10.0, 26.0))
solve(prob, LambaEulerHeun(), seed = 1)

# SDEProblem construction with StaticArrays
# Issue#2814
@parameters p d
@variables x(tt)
@brownian a
eqs = [D(x) ~ p - d * x + a * sqrt(p)]
@mtkbuild sys = System(eqs, tt)
u0 = @SVector[x => 10.0]
tspan = (0.0, 10.0)
ps = @SVector[p => 5.0, d => 0.5]
sprob = SDEProblem(sys, u0, tspan, ps)
@test !isinplace(sprob)
@test !isinplace(sprob.f)
@test_nowarn solve(sprob, ImplicitEM())

# Ensure diagonal noise generates vector noise function
@variables y(tt)
@brownian b
eqs = [D(x) ~ p - d * x + a * sqrt(p)
D(y) ~ p - d * y + b * sqrt(d)]
@mtkbuild sys = System(eqs, tt)
u0 = @SVector[x => 10.0, y => 20.0]
tspan = (0.0, 10.0)
ps = @SVector[p => 5.0, d => 0.5]
sprob = SDEProblem(sys, u0, tspan, ps)
@test sprob.f.g(sprob.u0, sprob.p, sprob.tspan[1]) isa SVector{2, Float64}
@test_nowarn solve(sprob, ImplicitEM())

0 comments on commit d64f973

Please sign in to comment.