Skip to content

Commit

Permalink
docs and minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Sep 3, 2024
1 parent 9a28474 commit 261fc4f
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 23 deletions.
48 changes: 34 additions & 14 deletions src/ManifoldDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,44 @@ abstract type AbstractManifoldDiffEqAdaptiveAlgorithm <: AbstractManifoldDiffEqA
isadaptive(::AbstractManifoldDiffEqAdaptiveAlgorithm) = true

"""
struct ManifoldODESolution end
struct ManifoldODESolution{T} end
Counterpart of `SciMLBase.ODESolution`. It doesn't use the `N` parameter (because it
is not a generic manifold concept) and fields `u_analytic`, `errors`, `alg_choice`,
`original` and `resid` (because we don't use them currently in `ManifoldDiffEq.jl`).
`original`, `tslocation` and `resid` (because we don't use them currently in
`ManifoldDiffEq.jl`).
Type parameter `T` denotes scalar floating point type of the solution
Fields:
* `u`: the representation of the ODE solution. Uses a nested power manifold representation.
* `t`: time point at which values in `u` were calculated.
* `k`: the representation of the `f` function evaluations at time points `k`. Uses a nested
power manifold representation.
* `prob`: original problem that was solved.
* `alg`: [`AbstractManifoldDiffEqAlgorithm`](@ref) used to obtain the solution.
* `interp` [`ManifoldInterpolationData`](@ref)
* `dense`: `true` if ODE solution is saved at every step and `false` otherwise.
* `stats`: [`DEStats`](https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/#SciMLBase.DEStats) of solver
* `retcode`: [`ReturnCode`}(https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes) of the solution.
"""
mutable struct ManifoldODESolution{T,uType,tType,rateType,P,A,IType,S}
struct ManifoldODESolution{
T<:Real,
uType,
tType,
rateType,
P,
A<:AbstractManifoldDiffEqAlgorithm,
IType,
S,
}
u::uType
t::tType
k::rateType
prob::P
alg::A
interp::IType
dense::Bool
tslocation::Int
stats::S
retcode::ReturnCode.T
end
Expand All @@ -107,10 +130,9 @@ function ManifoldODESolution{T}(
alg,
interp,
dense,
tslocation,
stats,
retcode,
) where {T}
) where {T<:Real}
return ManifoldODESolution{
T,
typeof(u),
Expand All @@ -128,13 +150,12 @@ function ManifoldODESolution{T}(
alg,
interp,
dense,
tslocation,
stats,
retcode,
)
end

constructorof(::Type{<:ManifoldODESolution{T}}) where {T} = ManifoldODESolution{T}
constructorof(::Type{<:ManifoldODESolution{T}}) where {T<:Real} = ManifoldODESolution{T}

function solution_new_retcode(sol::ManifoldODESolution, retcode)
return @set sol.retcode = retcode
Expand Down Expand Up @@ -186,7 +207,7 @@ function SciMLBase.__init(
prob.tspan[1] in saveat,
save_end = nothing,
callback = nothing,
dense = save_everystep && isempty(saveat),
dense::Bool = save_everystep && isempty(saveat),
calck = (callback !== nothing && callback !== CallbackSet()) ||
(dense) ||
!isempty(saveat), # and no dense output
Expand All @@ -196,8 +217,8 @@ function SciMLBase.__init(
force_dtmin = false,
adaptive = isadaptive(alg),
gamma = gamma_default(alg),
abstol = nothing,
reltol = nothing,
abstol::Union{Nothing,Real} = nothing,
reltol::Union{Nothing,Real} = nothing,
qmin = qmin_default(alg),
qmax = qmax_default(alg),
qsteady_min = qsteady_min_default(alg),
Expand Down Expand Up @@ -334,7 +355,6 @@ function SciMLBase.__init(

ts = ts_init === () ? tType[] : convert(Vector{tType}, ts_init)
ks = ks_init === () ? ksEltype[] : convert(Vector{ksEltype}, ks_init)
alg_choice = nothing

if (!adaptive || !isadaptive(_alg)) && save_everystep && tspan[2] - tspan[1] != Inf
if dt == 0
Expand Down Expand Up @@ -376,14 +396,14 @@ function SciMLBase.__init(
k = rateType[]

if uses_uprev(_alg, adaptive) || calck
uprev = recursivecopy(u)
uprev = copy(M, u)
else
# Some algorithms do not use `uprev` explicitly. In that case, we can save
# some memory by aliasing `uprev = u`, e.g. for "2N" low storage methods.
uprev = u

Check warning on line 403 in src/ManifoldDiffEq.jl

View check run for this annotation

Codecov / codecov/patch

src/ManifoldDiffEq.jl#L403

Added line #L403 was not covered by tests
end
if allow_extrapolation
uprev2 = recursivecopy(u)
uprev2 = copy(M, u)

Check warning on line 406 in src/ManifoldDiffEq.jl

View check run for this annotation

Codecov / codecov/patch

src/ManifoldDiffEq.jl#L406

Added line #L406 was not covered by tests
else
uprev2 = uprev
end
Expand Down
4 changes: 2 additions & 2 deletions src/frozen_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ function perform_step!(integrator, ::ManifoldEulerCache, repeat_step = false)
t = integrator.t
alg = integrator.alg

k = integrator.f(u, integrator.p, t)
retract!(alg.manifold, u, u, k, integrator.dt, alg.retraction_method)
integrator.k[1] = integrator.f(u, integrator.p, t)
retract!(alg.manifold, u, u, integrator.k[1], integrator.dt, alg.retraction_method)

return integrator.stats.nf += 1
end
Expand Down
1 change: 0 additions & 1 deletion src/problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ function build_solution(
alg,
manifold_interp,
dense,
0,
stats,
retcode,
)
Expand Down
13 changes: 7 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ function test_solver_frozen(manifold_to_alg; expected_order = nothing, adaptive
@test alg_order(alg) == expected_order
end

@testset "Sphere" begin
M = Sphere(2)
M = Sphere(2)
alg = manifold_to_alg(M)
@testset "$alg on sphere" begin
A = FrozenManifoldDiffEqOperator{Float64}() do u, p, t
return cross(u, [1.0, 0.0, 0.0])
end
u0 = [0.0, 1.0, 0.0]
alg = manifold_to_alg(M)
prob = ManifoldODEProblem(A, u0, (0, 2.0), M)
sol1 = if adaptive
solve(prob, alg)
Expand All @@ -30,13 +30,14 @@ function test_solver_frozen(manifold_to_alg; expected_order = nothing, adaptive
@test is_point(M, sol1(1.0))
end

@testset "Product manifold" begin
M = ProductManifold(Sphere(2), Euclidean(3))
M = ProductManifold(Sphere(2), Euclidean(3))
alg = manifold_to_alg(M)
@testset "$alg on product manifold" begin

A = FrozenManifoldDiffEqOperator{Float64}() do u, p, t
return ArrayPartition(cross(u.x[1], [1.0, 0.0, 0.0]), u.x[2])
end
u0 = ArrayPartition([0.0, 1.0, 0.0], [1.0, 0.0, 0.0])
alg = manifold_to_alg(M)
prob = ManifoldODEProblem(A, u0, (0, 2.0), M)
sol1 = if adaptive
solve(prob, alg)
Expand Down

0 comments on commit 261fc4f

Please sign in to comment.