Skip to content

Commit

Permalink
Merge pull request #118 from SciML/secondorder
Browse files Browse the repository at this point in the history
Remove automatic FoR `soadtype` creations
  • Loading branch information
Vaibhavdixit02 authored Oct 5, 2024
2 parents 1cb8a90 + 40208b6 commit e10bed6
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 116 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationBase"
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "2.2.1"
version = "2.3.0"


[deps]
Expand Down
10 changes: 8 additions & 2 deletions ext/OptimizationZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ using ADTypes, SciMLBase
import Zygote

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, x, adtype::ADTypes.AutoZygote,
f::OptimizationFunction{true}, x,
adtype::Union{ADTypes.AutoZygote,
DifferentiationInterface.SecondOrder{
<:ADTypes.AbstractADType, <:ADTypes.AutoZygote}},
p = SciMLBase.NullParameters(), num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
Expand Down Expand Up @@ -280,7 +283,10 @@ function OptimizationBase.instantiate_function(
end

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AutoZygote},
f::OptimizationFunction{true}, x,
adtype::ADTypes.AutoSparse{<:Union{ADTypes.AutoZygote,
DifferentiationInterface.SecondOrder{
<:ADTypes.AbstractADType, <:ADTypes.AutoZygote}}},
p = SciMLBase.NullParameters(), num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
cons_j = false, cons_vjp = false, cons_jvp = false, cons_h = false,
Expand Down
114 changes: 36 additions & 78 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,102 +220,60 @@ Hessian is not defined via Zygote.
AutoZygote

"""
    generate_adtype(adtype)

Normalize `adtype` into an `(adtype, soadtype)` pair, where `adtype` is the
AD type used for first-order derivatives and `soadtype` is the AD type used
for second-order derivatives (Hessians).

- Plain first-order AD type: wrapped as `SecondOrder(adtype, adtype)`
  (same backend for inner and outer differentiation).
- Explicit `DifferentiationInterface.SecondOrder`: kept as the second-order
  type; its `inner` backend is used for first-order work.
- `SciMLBase.NoAD`: passed through unchanged for both.

NOTE(review): reconstructed from a diff-fused source; the displayed version
had an unbalanced `end` and stale pre-merge branches using `ADTypes.mode`.
"""
function generate_adtype(adtype)
    if !(adtype isa SciMLBase.NoAD || adtype isa DifferentiationInterface.SecondOrder)
        # First-order backend: reuse it for both layers of the Hessian.
        soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
    elseif adtype isa DifferentiationInterface.SecondOrder
        # Explicit second-order type: first-order calls use its inner backend.
        soadtype = adtype
        adtype = adtype.inner
    else
        # SciMLBase.NoAD: nothing to wrap; user supplies derivatives manually.
        soadtype = adtype
    end
    return adtype, soadtype
end

function generate_sparse_adtype(adtype)
"""
    spadtype_to_spsoadtype(adtype)

Derive the sparse second-order AD type from a sparse first-order `adtype`:
the dense backend is wrapped in `SecondOrder(dense, dense)` while the
sparsity detector and coloring algorithm are carried over. If the dense
backend is already a `SecondOrder` (or `SciMLBase.NoAD`), `adtype` is
returned unchanged.
"""
function spadtype_to_spsoadtype(adtype)
    dense = adtype.dense_ad
    # Already second-order (or no AD at all): nothing to wrap.
    if dense isa SciMLBase.NoAD || dense isa DifferentiationInterface.SecondOrder
        return adtype
    end
    return AutoSparse(
        DifferentiationInterface.SecondOrder(dense, dense);
        sparsity_detector = adtype.sparsity_detector,
        coloring_algorithm = adtype.coloring_algorithm)
end

"""
    filled_spad(adtype)

Fill in the defaults of a sparse AD type: replace a `NoSparsityDetector`
with `TracerSparsityDetector()` and a `NoColoringAlgorithm` with
`GreedyColoringAlgorithm()`, preserving any detector/algorithm the user
already supplied. Returns the (possibly rebuilt) `AutoSparse` type; an
`adtype` with both fields already set is returned unchanged.

NOTE(review): the displayed source was a fused diff — each branch contained
dead `soadtype` computations whose value leaked out as the implicit return
instead of the rebuilt `adtype`. This version removes the dead code and
returns `adtype` explicitly.
"""
function filled_spad(adtype)
    has_detector = !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector)
    has_coloring = !(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm)
    if !(has_detector && has_coloring)
        # Rebuild, substituting defaults only for the missing pieces.
        adtype = AutoSparse(adtype.dense_ad;
            sparsity_detector = has_detector ? adtype.sparsity_detector :
                                TracerSparsityDetector(),
            coloring_algorithm = has_coloring ? adtype.coloring_algorithm :
                                 GreedyColoringAlgorithm())
    end
    return adtype
end

"""
    generate_sparse_adtype(adtype)

Normalize a sparse `adtype` into an `(adtype, soadtype)` pair for first- and
second-order sparse differentiation.

- Dense backend is first-order: fill sparsity defaults via [`filled_spad`](@ref)
  and derive the second-order type via [`spadtype_to_spsoadtype`](@ref).
- Dense backend is an explicit `SecondOrder`: the input (with defaults filled)
  is the second-order type, and the first-order type is rebuilt around the
  `SecondOrder`'s inner backend with the same detector/coloring settings.

NOTE(review): the displayed source fused old diff lines into the `else`
branch — an inner `if` computed `soadtype` values that were unconditionally
overwritten immediately afterwards. That dead code is removed here.
"""
function generate_sparse_adtype(adtype)
    if !(adtype.dense_ad isa DifferentiationInterface.SecondOrder)
        adtype = filled_spad(adtype)
        soadtype = spadtype_to_spsoadtype(adtype)
    else
        # Explicit SecondOrder backend: it already is the second-order type;
        # first-order work uses its inner backend with the same sparsity setup.
        soadtype = adtype
        adtype = AutoSparse(adtype.dense_ad.inner;
            sparsity_detector = soadtype.sparsity_detector,
            coloring_algorithm = soadtype.coloring_algorithm)
        adtype = filled_spad(adtype)
        soadtype = filled_spad(soadtype)
    end
    return adtype, soadtype
end
8 changes: 8 additions & 0 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;

num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)

if !(prob.f.adtype isa DifferentiationInterface.SecondOrder) &&
(SciMLBase.requireshessian(opt) || SciMLBase.requiresconshess(opt) ||
SciMLBase.requireslagh(opt))
@warn "The selected optimization algorithm requires second order derivatives, but `SecondOrder` ADtype was not provided.
So a `SecondOrder` with $adtype for both inner and outer will be created, this can be suboptimal and not work in some cases so
an explicit `SecondOrder` ADtype is recommended."
end

f = OptimizationBase.instantiate_function(
prob.f, reinit_cache, prob.f.adtype, num_cons;
g = SciMLBase.requiresgradient(opt), h = SciMLBase.requireshessian(opt),
Expand Down
55 changes: 35 additions & 20 deletions test/adtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ optprob.cons_h(H3, x0)
μ = randn(1)
σ = rand()
optprob.lag_h(H4, x0, σ, μ)
@test H4 ≈ σ * H1 + μ[1] * H3[1] rtol=1e-6
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6

G2 = Array{Float64}(undef, 2)
H2 = Array{Float64}(undef, 2, 2)
Expand Down Expand Up @@ -142,7 +142,7 @@ optprob.cons_h(H3, x0)
μ = randn(1)
σ = rand()
optprob.lag_h(H4, x0, σ, μ)
@test H4 ≈ σ * H1 + μ[1] * H3[1] rtol=1e-6
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6

G2 = Array{Float64}(undef, 2)
H2 = Array{Float64}(undef, 2, 2)
Expand Down Expand Up @@ -179,7 +179,7 @@ optprob.cons_h(H3, x0)
μ = randn(1)
σ = rand()
optprob.lag_h(H4, x0, σ, μ)
@test H4 ≈ σ * H1 + μ[1] * H3[1] rtol=1e-6
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6

G2 = Array{Float64}(undef, 2)
H2 = Array{Float64}(undef, 2, 2)
Expand Down Expand Up @@ -217,14 +217,15 @@ optprob.cons_h(H3, x0)
μ = randn(1)
σ = rand()
optprob.lag_h(H4, x0, σ, μ)
@test H4 ≈ σ * H1 + μ[1] * H3[1] rtol=1e-6
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6

G2 = Array{Float64}(undef, 2)
H2 = Array{Float64}(undef, 2, 2)

optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote(), cons = cons)
optf = OptimizationFunction(
rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()), cons = cons)
optprob = OptimizationBase.instantiate_function(
optf, x0, OptimizationBase.AutoZygote(),
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
nothing, 1, g = true, h = true, hv = true,
cons_j = true, cons_h = true, cons_vjp = true,
cons_jvp = true, lag_h = true)
Expand Down Expand Up @@ -254,14 +255,19 @@ optprob.cons_h(H3, x0)
μ = randn(1)
σ = rand()
optprob.lag_h(H4, x0, σ, μ)
@test H4 ≈ σ * H1 + μ[1] * H3[1] rtol=1e-6
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6

G2 = Array{Float64}(undef, 2)
H2 = Array{Float64}(undef, 2, 2)

optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoFiniteDiff(), cons = cons)
optf = OptimizationFunction(rosenbrock,
DifferentiationInterface.SecondOrder(
ADTypes.AutoFiniteDiff(), ADTypes.AutoReverseDiff()),
cons = cons)
optprob = OptimizationBase.instantiate_function(
optf, x0, OptimizationBase.AutoFiniteDiff(),
optf, x0,
DifferentiationInterface.SecondOrder(
ADTypes.AutoFiniteDiff(), ADTypes.AutoReverseDiff()),
nothing, 1, g = true, h = true, hv = true,
cons_j = true, cons_h = true, cons_vjp = true,
cons_jvp = true, lag_h = true)
Expand All @@ -287,11 +293,12 @@ optprob.cons_h(H3, x0)
H3 = [Array{Float64}(undef, 2, 2)]
optprob.cons_h(H3, x0)
@test H3[1] ≈ [2.0 0.0; 0.0 2.0] rtol=1e-5
Random.seed!(123)
H4 = Array{Float64}(undef, 2, 2)
μ = randn(1)
σ = rand()
optprob.lag_h(H4, x0, σ, μ)
@test H4 ≈ σ * H1 + μ[1] * H3[1] rtol=1e-6
@test H4 ≈ σ * H2 + μ[1] * H3[1] rtol=1e-6
end

@testset "two constraints tests" begin
Expand Down Expand Up @@ -448,9 +455,10 @@ end
G2 = Array{Float64}(undef, 2)
H2 = Array{Float64}(undef, 2, 2)

optf = OptimizationFunction(rosenbrock, OptimizationBase.AutoZygote(), cons = con2_c)
optf = OptimizationFunction(
rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()), cons = con2_c)
optprob = OptimizationBase.instantiate_function(
optf, x0, OptimizationBase.AutoZygote(),
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
nothing, 2, g = true, h = true, hv = true,
cons_j = true, cons_h = true, cons_vjp = true,
cons_jvp = true, lag_h = true)
Expand Down Expand Up @@ -486,9 +494,13 @@ end
H2 = Array{Float64}(undef, 2, 2)

optf = OptimizationFunction(
rosenbrock, OptimizationBase.AutoFiniteDiff(), cons = con2_c)
rosenbrock, DifferentiationInterface.SecondOrder(
ADTypes.AutoFiniteDiff(), ADTypes.AutoReverseDiff()),
cons = con2_c)
optprob = OptimizationBase.instantiate_function(
optf, x0, OptimizationBase.AutoFiniteDiff(),
optf, x0,
DifferentiationInterface.SecondOrder(
ADTypes.AutoFiniteDiff(), ADTypes.AutoReverseDiff()),
nothing, 2, g = true, h = true, hv = true,
cons_j = true, cons_h = true, cons_vjp = true,
cons_jvp = true, lag_h = true)
Expand Down Expand Up @@ -734,12 +746,15 @@ end
@test lag_H ≈ lag_H_expected
@test nnz(lag_H) == 5

optf = OptimizationFunction(sparse_objective, OptimizationBase.AutoSparseZygote(),
optf = OptimizationFunction(sparse_objective,
AutoSparse(DifferentiationInterface.SecondOrder(
ADTypes.AutoForwardDiff(), ADTypes.AutoZygote())),
cons = sparse_constraints)

# Instantiate the optimization problem
optprob = OptimizationBase.instantiate_function(optf, x0,
OptimizationBase.AutoSparseZygote(),
AutoSparse(DifferentiationInterface.SecondOrder(
ADTypes.AutoForwardDiff(), ADTypes.AutoZygote())),
nothing, 2, g = true, h = true, cons_j = true, cons_h = true, lag_h = true)
# Test gradient
G = zeros(3)
Expand Down Expand Up @@ -1065,10 +1080,10 @@ end

cons = (x, p) -> [x[1]^2 + x[2]^2]
optf = OptimizationFunction{false}(rosenbrock,
OptimizationBase.AutoZygote(),
SecondOrder(AutoForwardDiff(), AutoZygote()),
cons = cons)
optprob = OptimizationBase.instantiate_function(
optf, x0, OptimizationBase.AutoZygote(),
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
nothing, 1, g = true, h = true, cons_j = true, cons_h = true)

@test optprob.grad(x0) == G1
Expand All @@ -1081,10 +1096,10 @@ end

cons = (x, p) -> [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]
optf = OptimizationFunction{false}(rosenbrock,
OptimizationBase.AutoZygote(),
SecondOrder(AutoForwardDiff(), AutoZygote()),
cons = cons)
optprob = OptimizationBase.instantiate_function(
optf, x0, OptimizationBase.AutoZygote(),
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
nothing, 2, g = true, h = true, cons_j = true, cons_h = true)

@test optprob.grad(x0) == G1
Expand Down
24 changes: 9 additions & 15 deletions test/matrixvalued.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ using OptimizationBase, LinearAlgebra, ForwardDiff, Zygote, FiniteDiff,
using Test, ReverseDiff

@testset "Matrix Valued" begin
for adtype in [AutoForwardDiff(), AutoZygote(), AutoFiniteDiff(),
for adtype in [AutoForwardDiff(), SecondOrder(AutoForwardDiff(), AutoZygote()),
SecondOrder(AutoForwardDiff(), AutoFiniteDiff()),
AutoSparse(AutoForwardDiff(), sparsity_detector = TracerLocalSparsityDetector()),
AutoSparse(AutoZygote(), sparsity_detector = TracerLocalSparsityDetector()),
AutoSparse(AutoFiniteDiff(), sparsity_detector = TracerLocalSparsityDetector())]
AutoSparse(SecondOrder(AutoForwardDiff(), AutoZygote()),
sparsity_detector = TracerLocalSparsityDetector()),
AutoSparse(SecondOrder(AutoForwardDiff(), AutoFiniteDiff()),
sparsity_detector = TracerLocalSparsityDetector())]
# 1. Matrix Factorization
@show adtype
function matrix_factorization_objective(X, A)
U, V = @view(X[1:size(A, 1), 1:Int(size(A, 2) / 2)]),
@view(X[1:size(A, 1), (Int(size(A, 2) / 2) + 1):size(A, 2)])
Expand All @@ -28,12 +32,7 @@ using Test, ReverseDiff
cons_j = true, cons_h = true)
optf.grad(hcat(U_mf, V_mf))
optf.hess(hcat(U_mf, V_mf))
if adtype != AutoSparse(
AutoForwardDiff(), sparsity_detector = TracerLocalSparsityDetector()) &&
adtype !=
AutoSparse(AutoZygote(), sparsity_detector = TracerLocalSparsityDetector()) &&
adtype !=
AutoSparse(AutoFiniteDiff(), sparsity_detector = TracerLocalSparsityDetector())
if !(adtype isa ADTypes.AutoSparse)
optf.cons_j(hcat(U_mf, V_mf))
optf.cons_h(hcat(U_mf, V_mf))
end
Expand All @@ -55,12 +54,7 @@ using Test, ReverseDiff
cons_j = true, cons_h = true)
optf.grad(X_pca)
optf.hess(X_pca)
if adtype != AutoSparse(
AutoForwardDiff(), sparsity_detector = TracerLocalSparsityDetector()) &&
adtype !=
AutoSparse(AutoZygote(), sparsity_detector = TracerLocalSparsityDetector()) &&
adtype !=
AutoSparse(AutoFiniteDiff(), sparsity_detector = TracerLocalSparsityDetector())
if !(adtype isa ADTypes.AutoSparse)
optf.cons_j(X_pca)
optf.cons_h(X_pca)
end
Expand Down

0 comments on commit e10bed6

Please sign in to comment.