Skip to content

Commit

Permalink
Merge pull request #121 from SciML/secondorder
Browse files Browse the repository at this point in the history
Make Zygote second-order AD use ForwardDiff over Zygote
  • Loading branch information
Vaibhavdixit02 authored Oct 7, 2024
2 parents e10bed6 + c12dd3b commit 6043b0c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 13 deletions.
2 changes: 1 addition & 1 deletion ext/OptimizationZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
gradient!, hessian!, hvp!, jacobian!, gradient, hessian,
hvp, jacobian, Constant
using ADTypes, SciMLBase
import Zygote
import Zygote, Zygote.ForwardDiff

function OptimizationBase.instantiate_function(
f::OptimizationFunction{true}, x,
Expand Down
13 changes: 11 additions & 2 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,11 @@ Hessian is not defined via Zygote.
AutoZygote

function generate_adtype(adtype)
if !(adtype isa SciMLBase.NoAD || adtype isa DifferentiationInterface.SecondOrder)
if !(adtype isa SciMLBase.NoAD || adtype isa DifferentiationInterface.SecondOrder ||
adtype isa AutoZygote)
soadtype = DifferentiationInterface.SecondOrder(adtype, adtype)
elseif adtype isa AutoZygote
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
elseif adtype isa DifferentiationInterface.SecondOrder
soadtype = adtype
adtype = adtype.inner
Expand All @@ -234,11 +237,17 @@ end

function spadtype_to_spsoadtype(adtype)
if !(adtype.dense_ad isa SciMLBase.NoAD ||
adtype.dense_ad isa DifferentiationInterface.SecondOrder)
adtype.dense_ad isa DifferentiationInterface.SecondOrder ||
adtype.dense_ad isa AutoZygote)
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = adtype.coloring_algorithm)
elseif adtype.dense_ad isa AutoZygote
soadtype = AutoSparse(
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
sparsity_detector = adtype.sparsity_detector,
coloring_algorithm = adtype.coloring_algorithm)
else
soadtype = adtype
end
Expand Down
11 changes: 9 additions & 2 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,18 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;

num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)

if !(prob.f.adtype isa DifferentiationInterface.SecondOrder) &&
if !(prob.f.adtype isa DifferentiationInterface.SecondOrder ||
prob.f.adtype isa AutoZygote) &&
(SciMLBase.requireshessian(opt) || SciMLBase.requiresconshess(opt) ||
SciMLBase.requireslagh(opt))
@warn "The selected optimization algorithm requires second order derivatives, but `SecondOrder` ADtype was not provided.
So a `SecondOrder` with $adtype for both inner and outer will be created, this can be suboptimal and not work in some cases so
So a `SecondOrder` with $(prob.f.adtype) for both inner and outer will be created, this can be suboptimal and not work in some cases so
an explicit `SecondOrder` ADtype is recommended."
elseif prob.f.adtype isa AutoZygote &&
(SciMLBase.requiresconshess(opt) || SciMLBase.requireslagh(opt) ||
SciMLBase.requireshessian(opt))
@warn "The selected optimization algorithm requires second order derivatives, but `AutoZygote` ADtype was provided.
So a `SecondOrder` with `AutoZygote` for inner and `AutoForwardDiff` for outer will be created, for choosing another pair
an explicit `SecondOrder` ADtype is recommended."
end

Expand Down
16 changes: 8 additions & 8 deletions test/adtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ optprob.cons_h(H3, x0)
H2 = Array{Float64}(undef, 2, 2)

optf = OptimizationFunction(
rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()), cons = cons)
rosenbrock, AutoZygote(), cons = cons)
optprob = OptimizationBase.instantiate_function(
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
optf, x0, AutoZygote(),
nothing, 1, g = true, h = true, hv = true,
cons_j = true, cons_h = true, cons_vjp = true,
cons_jvp = true, lag_h = true)
Expand Down Expand Up @@ -456,9 +456,9 @@ end
H2 = Array{Float64}(undef, 2, 2)

optf = OptimizationFunction(
rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()), cons = con2_c)
rosenbrock, AutoZygote(), cons = con2_c)
optprob = OptimizationBase.instantiate_function(
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
optf, x0, AutoZygote(),
nothing, 2, g = true, h = true, hv = true,
cons_j = true, cons_h = true, cons_vjp = true,
cons_jvp = true, lag_h = true)
Expand Down Expand Up @@ -1080,10 +1080,10 @@ end

cons = (x, p) -> [x[1]^2 + x[2]^2]
optf = OptimizationFunction{false}(rosenbrock,
SecondOrder(AutoForwardDiff(), AutoZygote()),
AutoZygote(),
cons = cons)
optprob = OptimizationBase.instantiate_function(
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
optf, x0, AutoZygote(),
nothing, 1, g = true, h = true, cons_j = true, cons_h = true)

@test optprob.grad(x0) == G1
Expand All @@ -1096,10 +1096,10 @@ end

cons = (x, p) -> [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]
optf = OptimizationFunction{false}(rosenbrock,
SecondOrder(AutoForwardDiff(), AutoZygote()),
AutoZygote(),
cons = cons)
optprob = OptimizationBase.instantiate_function(
optf, x0, SecondOrder(AutoForwardDiff(), AutoZygote()),
optf, x0, AutoZygote(),
nothing, 2, g = true, h = true, cons_j = true, cons_h = true)

@test optprob.grad(x0) == G1
Expand Down

0 comments on commit 6043b0c

Please sign in to comment.