diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 09bea2d152..412d308c87 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -231,6 +231,9 @@ function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys), if isdde eqs = delay_to_function(sys, eqs) end + if eqs isa AbstractMatrix && isdiag(eqs) + eqs = diag(eqs) + end u = map(x -> time_varying_as_func(value(x), sys), dvs) p = if has_index_cache(sys) && get_index_cache(sys) !== nothing reorder_parameters(get_index_cache(sys), ps) @@ -403,14 +406,14 @@ function Girsanov_transform(sys::SDESystem, u; θ0 = 1.0) checks = false) end -function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys), +function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(sys), ps = parameters(sys), u0 = nothing; version = nothing, tgrad = false, sparse = false, jac = false, Wfact = false, eval_expression = false, eval_module = @__MODULE__, checkbounds = false, - kwargs...) where {iip} + kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`") end @@ -480,7 +483,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = unknowns(sys), observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) - SDEFunction{iip}(f, g, + SDEFunction{iip, specialize}(f, g, sys = sys, jac = _jac === nothing ? nothing : _jac, tgrad = _tgrad === nothing ? nothing : _tgrad, @@ -505,6 +508,16 @@ function DiffEqBase.SDEFunction(sys::SDESystem, args...; kwargs...) SDEFunction{true}(sys, args...; kwargs...) end +function DiffEqBase.SDEFunction{true}(sys::SDESystem, args...; + kwargs...) + SDEFunction{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...) +end + +function DiffEqBase.SDEFunction{false}(sys::SDESystem, args...; + kwargs...) + SDEFunction{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...) +end + """ ```julia DiffEqBase.SDEFunctionExpr{iip}(sys::AbstractODESystem, dvs = unknowns(sys), @@ -583,14 +596,16 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...) SDEFunctionExpr{true}(sys, args...; kwargs...) end -function DiffEqBase.SDEProblem{iip}(sys::SDESystem, u0map = [], tspan = get_tspan(sys), +function DiffEqBase.SDEProblem{iip, specialize}( + sys::SDESystem, u0map = [], tspan = get_tspan(sys), parammap = DiffEqBase.NullParameters(); sparsenoise = nothing, check_length = true, - callback = nothing, kwargs...) where {iip} + callback = nothing, kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblem`") end - f, u0, p = process_DEProblem(SDEFunction{iip}, sys, u0map, parammap; check_length, + f, u0, p = process_DEProblem( + SDEFunction{iip, specialize}, sys, u0map, parammap; check_length, kwargs...) cbs = process_events(sys; callback, kwargs...) sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false)) @@ -628,6 +643,21 @@ function DiffEqBase.SDEProblem(sys::SDESystem, args...; kwargs...) SDEProblem{true}(sys, args...; kwargs...) end +function DiffEqBase.SDEProblem(sys::SDESystem, + u0map::StaticArray, + args...; + kwargs...) + SDEProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...) +end + +function DiffEqBase.SDEProblem{true}(sys::SDESystem, args...; kwargs...) + SDEProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...) +end + +function DiffEqBase.SDEProblem{false}(sys::SDESystem, args...; kwargs...) + SDEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...) +end + """ ```julia DiffEqBase.SDEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan, diff --git a/test/sdesystem.jl b/test/sdesystem.jl index b18ab648e7..b1786aa721 100644 --- a/test/sdesystem.jl +++ b/test/sdesystem.jl @@ -614,3 +614,31 @@ sys2 = complete(sys2) prob = SDEProblem(sys1, sts .=> [1.0, 0.0, 0.0], (0.0, 100.0), ps .=> (10.0, 26.0)) solve(prob, LambaEulerHeun(), seed = 1) + +# SDEProblem construction with StaticArrays +# Issue#2814 +@parameters p d +@variables x(tt) +@brownian a +eqs = [D(x) ~ p - d * x + a * sqrt(p)] +@mtkbuild sys = System(eqs, tt) +u0 = @SVector[x => 10.0] +tspan = (0.0, 10.0) +ps = @SVector[p => 5.0, d => 0.5] +sprob = SDEProblem(sys, u0, tspan, ps) +@test !isinplace(sprob) +@test !isinplace(sprob.f) +@test_nowarn solve(sprob, ImplicitEM()) + +# Ensure diagonal noise generates vector noise function +@variables y(tt) +@brownian b +eqs = [D(x) ~ p - d * x + a * sqrt(p) + D(y) ~ p - d * y + b * sqrt(d)] +@mtkbuild sys = System(eqs, tt) +u0 = @SVector[x => 10.0, y => 20.0] +tspan = (0.0, 10.0) +ps = @SVector[p => 5.0, d => 0.5] +sprob = SDEProblem(sys, u0, tspan, ps) +@test sprob.f.g(sprob.u0, sprob.p, sprob.tspan[1]) isa SVector{2, Float64} +@test_nowarn solve(sprob, ImplicitEM())