Skip to content

Commit

Permalink
Merge pull request #46 from SciML/adtypes1
Browse files Browse the repository at this point in the history
ADTypes v1.0 support updates
  • Loading branch information
Vaibhavdixit02 authored May 20, 2024
2 parents 0f2e93c + 3e26da5 commit e5ae416
Show file tree
Hide file tree
Showing 13 changed files with 374 additions and 183 deletions.
46 changes: 46 additions & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: IntegrationTest
on:
push:
branches: [master]
tags: [v*]
pull_request:

jobs:
test:
name: ${{ matrix.package.repo }}/${{ matrix.package.group }}/${{ matrix.julia-version }}
runs-on: ${{ matrix.os }}
env:
GROUP: ${{ matrix.package.group }}
strategy:
fail-fast: false
matrix:
julia-version: [1]
os: [ubuntu-latest]
package:
- {user: SciML, repo: Optimization.jl, group: Optimization}

steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.julia-version }}
arch: x64
- uses: julia-actions/julia-buildpkg@latest
- name: Clone Downstream
uses: actions/checkout@v4
with:
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
path: downstream
- name: Load this and run the downstream tests
shell: julia --color=yes --project=downstream {0}
run: |
using Pkg
# force it to use this PR's version of the package
Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
Pkg.update()
Pkg.test(coverage=true) # resolver may fail with test time deps
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
with:
file: lcov.info
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationBase"
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "1.0.0"
version = "1.0.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -35,10 +35,10 @@ OptimizationTrackerExt = "Tracker"
OptimizationZygoteExt = "Zygote"

[compat]
ADTypes = "0.2.5, 1"
ADTypes = "1"
ArrayInterface = "7.6"
DocStringExtensions = "0.9"
Enzyme = "0.11.11, 0.12"
Enzyme = "0.11.11, =0.12.6"
FiniteDiff = "2.12"
ForwardDiff = "0.10.26"
LinearAlgebra = "1.9, 1.10"
Expand Down
6 changes: 4 additions & 2 deletions ext/OptimizationFiniteDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,

if f.grad === nothing
gradcache = FD.GradientCache(x, x, adtype.fdtype)
grad = (res, θ, args...) -> FD.finite_difference_gradient!(res, x -> _f(x, args...),
grad = (res, θ, args...) -> FD.finite_difference_gradient!(
res, x -> _f(x, args...),
θ, gradcache)
else
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
Expand Down Expand Up @@ -125,7 +126,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},

if f.grad === nothing
gradcache = FD.GradientCache(cache.u0, cache.u0, adtype.fdtype)
grad = (res, θ, args...) -> FD.finite_difference_gradient!(res, x -> _f(x, args...),
grad = (res, θ, args...) -> FD.finite_difference_gradient!(
res, x -> _f(x, args...),
θ, gradcache)
else
grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
Expand Down
8 changes: 4 additions & 4 deletions ext/OptimizationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
if cons !== nothing && f.cons_h === nothing
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
hess_config_cache = [ForwardDiff.HessianConfig(fncs[i], x,
ForwardDiff.Chunk{chunksize}())
ForwardDiff.Chunk{chunksize}())
for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
Expand Down Expand Up @@ -143,7 +143,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
if cons !== nothing && f.cons_h === nothing
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
hess_config_cache = [ForwardDiff.HessianConfig(fncs[i], cache.u0,
ForwardDiff.Chunk{chunksize}())
ForwardDiff.Chunk{chunksize}())
for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
Expand Down Expand Up @@ -224,7 +224,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
if cons !== nothing && f.cons_h === nothing
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
hess_config_cache = [ForwardDiff.HessianConfig(fncs[i], x,
ForwardDiff.Chunk{chunksize}())
ForwardDiff.Chunk{chunksize}())
for i in 1:num_cons]
cons_h = function (θ)
map(1:num_cons) do i
Expand Down Expand Up @@ -306,7 +306,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
if cons !== nothing && f.cons_h === nothing
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
hess_config_cache = [ForwardDiff.HessianConfig(fncs[i], x,
ForwardDiff.Chunk{chunksize}())
ForwardDiff.Chunk{chunksize}())
for i in 1:num_cons]
cons_h = function (θ)
map(1:num_cons) do i
Expand Down
116 changes: 109 additions & 7 deletions ext/OptimizationMTKExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,112 @@ module OptimizationMTKExt
import OptimizationBase, OptimizationBase.ArrayInterface
import OptimizationBase.SciMLBase
import OptimizationBase.SciMLBase: OptimizationFunction
import OptimizationBase.ADTypes: AutoModelingToolkit
import OptimizationBase.ADTypes: AutoModelingToolkit, AutoSymbolics, AutoSparse
isdefined(Base, :get_extension) ? (using ModelingToolkit) : (using ..ModelingToolkit)

function OptimizationBase.instantiate_function(f, x, adtype::AutoModelingToolkit, p,
function OptimizationBase.ADTypes.AutoModelingToolkit(sparse = false, cons_sparse = false)
if sparse || cons_sparse
return AutoSparse(AutoSymbolics())
else
return AutoSymbolics()
end
end

function OptimizationBase.instantiate_function(
f, x, adtype::AutoSparse{<:AutoSymbolics, S, C}, p,
num_cons = 0) where {S, C}
p = isnothing(p) ? SciMLBase.NullParameters() : p

sys = complete(ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, x, p;
lcons = fill(0.0,
num_cons),
ucons = fill(0.0,
num_cons))))
#sys = ModelingToolkit.structural_simplify(sys)
f = OptimizationProblem(sys, x, p, grad = true, hess = true,
sparse = true, cons_j = true, cons_h = true,
cons_sparse = true).f

grad = (G, θ, args...) -> f.grad(G, θ, p, args...)

hess = (H, θ, args...) -> f.hess(H, θ, p, args...)

hv = function (H, θ, v, args...)
res = adtype.obj_sparse ? (eltype(θ)).(f.hess_prototype) :
ArrayInterface.zeromatrix(θ)
hess(res, θ, args...)
H .= res * v
end

if !isnothing(f.cons)
cons = (res, θ) -> f.cons(res, θ, p)
cons_j = (J, θ) -> f.cons_j(J, θ, p)
cons_h = (res, θ) -> f.cons_h(res, θ, p)
else
cons = nothing
cons_j = nothing
cons_h = nothing
end

return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
cons = cons, cons_j = cons_j, cons_h = cons_h,
hess_prototype = f.hess_prototype,
cons_jac_prototype = f.cons_jac_prototype,
cons_hess_prototype = f.cons_hess_prototype,
expr = OptimizationBase.symbolify(f.expr),
cons_expr = OptimizationBase.symbolify.(f.cons_expr),
sys = sys,
observed = f.observed)
end

function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInitCache,
adtype::AutoSparse{<:AutoSymbolics, S, C}, num_cons = 0) where {S, C}
p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p

sys = complete(ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, cache.u0,
cache.p;
lcons = fill(0.0,
num_cons),
ucons = fill(0.0,
num_cons))))
#sys = ModelingToolkit.structural_simplify(sys)
f = OptimizationProblem(sys, cache.u0, cache.p, grad = true, hess = true,
sparse = true, cons_j = true, cons_h = true,
cons_sparse = true).f

grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)

hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)

hv = function (H, θ, v, args...)
res = adtype.obj_sparse ? (eltype(θ)).(f.hess_prototype) :
ArrayInterface.zeromatrix(θ)
hess(res, θ, args...)
H .= res * v
end

if !isnothing(f.cons)
cons = (res, θ) -> f.cons(res, θ, cache.p)
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
else
cons = nothing
cons_j = nothing
cons_h = nothing
end

return OptimizationFunction{true}(f.f, adtype; grad = grad, hess = hess, hv = hv,
cons = cons, cons_j = cons_j, cons_h = cons_h,
hess_prototype = f.hess_prototype,
cons_jac_prototype = f.cons_jac_prototype,
cons_hess_prototype = f.cons_hess_prototype,
expr = OptimizationBase.symbolify(f.expr),
cons_expr = OptimizationBase.symbolify.(f.cons_expr),
sys = sys,
observed = f.observed)
end

function OptimizationBase.instantiate_function(f, x, adtype::AutoSymbolics, p,
num_cons = 0)
p = isnothing(p) ? SciMLBase.NullParameters() : p

Expand All @@ -17,8 +119,8 @@ function OptimizationBase.instantiate_function(f, x, adtype::AutoModelingToolkit
num_cons))))
#sys = ModelingToolkit.structural_simplify(sys)
f = OptimizationProblem(sys, x, p, grad = true, hess = true,
sparse = adtype.obj_sparse, cons_j = true, cons_h = true,
cons_sparse = adtype.cons_sparse).f
sparse = false, cons_j = true, cons_h = true,
cons_sparse = false).f

grad = (G, θ, args...) -> f.grad(G, θ, p, args...)

Expand Down Expand Up @@ -53,7 +155,7 @@ function OptimizationBase.instantiate_function(f, x, adtype::AutoModelingToolkit
end

function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInitCache,
adtype::AutoModelingToolkit, num_cons = 0)
adtype::AutoSymbolics, num_cons = 0)
p = isnothing(cache.p) ? SciMLBase.NullParameters() : cache.p

sys = complete(ModelingToolkit.modelingtoolkitize(OptimizationProblem(f, cache.u0,
Expand All @@ -64,8 +166,8 @@ function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInit
num_cons))))
#sys = ModelingToolkit.structural_simplify(sys)
f = OptimizationProblem(sys, cache.u0, cache.p, grad = true, hess = true,
sparse = adtype.obj_sparse, cons_j = true, cons_h = true,
cons_sparse = adtype.cons_sparse).f
sparse = false, cons_j = true, cons_h = true,
cons_sparse = false).f

grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)

Expand Down
32 changes: 16 additions & 16 deletions ext/OptimizationReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
xdual = ForwardDiff.Dual{
typeof(T),
eltype(x),
chunksize,
chunksize
}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), chunksize)...,))))
h_tape = ReverseDiff.GradientTape(_f, xdual)
htape = ReverseDiff.compile(h_tape)
Expand Down Expand Up @@ -119,9 +119,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
end
gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
jaccfgs = [ForwardDiff.JacobianConfig(gs[i],
x,
ForwardDiff.Chunk{chunksize}(),
T) for i in 1:num_cons]
x,
ForwardDiff.Chunk{chunksize}(),
T) for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}())
Expand Down Expand Up @@ -182,7 +182,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
xdual = ForwardDiff.Dual{
typeof(T),
eltype(cache.u0),
chunksize,
chunksize
}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), chunksize)...,))))
h_tape = ReverseDiff.GradientTape(_f, xdual)
htape = ReverseDiff.compile(h_tape)
Expand Down Expand Up @@ -255,9 +255,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
end
gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
jaccfgs = [ForwardDiff.JacobianConfig(gs[i],
cache.u0,
ForwardDiff.Chunk{chunksize}(),
T) for i in 1:num_cons]
cache.u0,
ForwardDiff.Chunk{chunksize}(),
T) for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}())
Expand Down Expand Up @@ -319,7 +319,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
xdual = ForwardDiff.Dual{
typeof(T),
eltype(x),
chunksize,
chunksize
}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), chunksize)...,))))
h_tape = ReverseDiff.GradientTape(_f, xdual)
htape = ReverseDiff.compile(h_tape)
Expand Down Expand Up @@ -393,9 +393,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
end
gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
jaccfgs = [ForwardDiff.JacobianConfig(gs[i],
x,
ForwardDiff.Chunk{chunksize}(),
T) for i in 1:num_cons]
x,
ForwardDiff.Chunk{chunksize}(),
T) for i in 1:num_cons]
cons_h = function (θ)
map(1:num_cons) do i
ForwardDiff.jacobian(gs[i], θ, jaccfgs[i], Val{false}())
Expand Down Expand Up @@ -456,7 +456,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
xdual = ForwardDiff.Dual{
typeof(T),
eltype(x),
chunksize,
chunksize
}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), chunksize)...,))))
h_tape = ReverseDiff.GradientTape(_f, xdual)
htape = ReverseDiff.compile(h_tape)
Expand Down Expand Up @@ -530,9 +530,9 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
end
gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
jaccfgs = [ForwardDiff.JacobianConfig(gs[i],
x,
ForwardDiff.Chunk{chunksize}(),
T) for i in 1:num_cons]
x,
ForwardDiff.Chunk{chunksize}(),
T) for i in 1:num_cons]
cons_h = function (θ)
map(1:num_cons) do i
ForwardDiff.jacobian(gs[i], θ, jaccfgs[i], Val{false}())
Expand Down
8 changes: 4 additions & 4 deletions ext/OptimizationSparseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ module OptimizationSparseDiffExt

import OptimizationBase, OptimizationBase.ArrayInterface
import OptimizationBase.SciMLBase: OptimizationFunction
import OptimizationBase.ADTypes: AutoSparseForwardDiff,
AutoSparseFiniteDiff, AutoSparseReverseDiff
import OptimizationBase.ADTypes: AutoSparse, AutoFiniteDiff, AutoForwardDiff,
AutoReverseDiff
using OptimizationBase.LinearAlgebra, ReverseDiff
isdefined(Base, :get_extension) ?
(using SparseDiffTools,
SparseDiffTools.ForwardDiff, SparseDiffTools.FiniteDiff, Symbolics) :
SparseDiffTools.ForwardDiff, SparseDiffTools.FiniteDiff, Symbolics) :
(using ..SparseDiffTools,
..SparseDiffTools.ForwardDiff, ..SparseDiffTools.FiniteDiff, ..Symbolics)
..SparseDiffTools.ForwardDiff, ..SparseDiffTools.FiniteDiff, ..Symbolics)

function default_chunk_size(len)
if len < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
Expand Down
Loading

0 comments on commit e5ae416

Please sign in to comment.