Switch from AbstractDifferentiation to DifferentiationInterface #93

Merged: 5 commits, Oct 1, 2024
Changes from 4 commits
6 changes: 4 additions & 2 deletions Project.toml
@@ -3,13 +3,15 @@ uuid = "140ffc9f-1907-541a-a177-7475e0a401e9"
version = "0.6.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"

[compat]
AbstractDifferentiation = "0.6"
ADTypes = "1.5.3"
DifferentiationInterface = "0.5.8"
LinearAlgebra = "1.2"
Printf = "1.2"
ProximalCore = "0.1"
2 changes: 1 addition & 1 deletion README.md
@@ -19,7 +19,7 @@ Implemented algorithms include:
Check out [this section](https://juliafirstorder.github.io/ProximalAlgorithms.jl/stable/guide/implemented_algorithms/) for an overview of the available algorithms.

Algorithms rely on:
- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation (but you can easily bring your own gradients)
- [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) for automatic differentiation (but you can easily bring your own gradients)
- the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc, to handle non-differentiable terms (see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) for an extensive collection of functions).

## Documentation
8 changes: 4 additions & 4 deletions benchmark/benchmarks.jl
@@ -8,12 +8,12 @@ using FileIO

const SUITE = BenchmarkGroup()

function ProximalAlgorithms.value_and_gradient_closure(
function ProximalAlgorithms.value_and_gradient(
f::ProximalOperators.LeastSquaresDirect,
x,
)
res = f.A * x - f.b
norm(res)^2 / 2, () -> f.A' * res
norm(res)^2 / 2, f.A' * res
end

struct SquaredDistance{Tb}
@@ -22,9 +22,9 @@ end

(f::SquaredDistance)(x) = norm(x - f.b)^2 / 2

function ProximalAlgorithms.value_and_gradient_closure(f::SquaredDistance, x)
function ProximalAlgorithms.value_and_gradient(f::SquaredDistance, x)
diff = x - f.b
norm(diff)^2 / 2, () -> diff
norm(diff)^2 / 2, diff
end

for (benchmark_name, file_name) in [
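A quick sanity check of the hand-written gradient above, as a sketch outside the diff (it assumes the `SquaredDistance` type and its `value_and_gradient` method from this benchmark file are defined):

```julia
using LinearAlgebra
using ProximalAlgorithms

# SquaredDistance and its value_and_gradient method as defined in benchmarks.jl above.
f = SquaredDistance([1.0, 2.0, 3.0])
fx, grad_fx = ProximalAlgorithms.value_and_gradient(f, zeros(3))

fx ≈ norm([1.0, 2.0, 3.0])^2 / 2   # value at the origin: ‖0 - b‖² / 2 = 7
grad_fx ≈ -[1.0, 2.0, 3.0]         # gradient x - b evaluated at x = 0
```

Note that the second return value is now the gradient itself, not a closure to be called later.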
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -1,5 +1,5 @@
[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
4 changes: 2 additions & 2 deletions docs/src/examples/sparse_linear_regression.jl
@@ -53,12 +53,12 @@ end
mean_squared_error(label, output) = mean((output .- label) .^ 2) / 2

using Zygote
using AbstractDifferentiation: ZygoteBackend
using DifferentiationInterface: AutoZygote
using ProximalAlgorithms

training_loss = ProximalAlgorithms.AutoDifferentiable(
wb -> mean_squared_error(training_label, standardized_linear_model(wb, training_input)),
ZygoteBackend(),
AutoZygote(),
)

# As regularization we will use the L1 norm, implemented in [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl):
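The only user-facing change in this example is the backend constructor; a minimal sketch of the swap (the `loss` function here is a placeholder, not the regression loss from the example):

```julia
using Zygote
using DifferentiationInterface: AutoZygote   # previously: using AbstractDifferentiation: ZygoteBackend
using ProximalAlgorithms

loss(w) = sum(abs2, w) / 2   # placeholder smooth loss

# Wrap the loss with an ADTypes backend; any backend supported by
# DifferentiationInterface should work the same way.
training_loss = ProximalAlgorithms.AutoDifferentiable(loss, AutoZygote())

fw, grad_w = ProximalAlgorithms.value_and_gradient(training_loss, ones(3))
```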
29 changes: 15 additions & 14 deletions docs/src/guide/custom_objectives.jl
@@ -12,18 +12,18 @@
#
# Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref).
#
# To compute gradients, algorithms use [`value_and_gradient_closure`](@ref):
# this relies on [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl), for automatic differentiation
# To compute gradients, algorithms use [`value_and_gradient`](@ref):
# this relies on [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl), for automatic differentiation
# with any of its supported backends, when functions are wrapped in [`AutoDifferentiable`](@ref),
# as the examples below show.
#
# If however you would like to provide your own gradient implementation (e.g. for efficiency reasons),
# you can simply implement a method for [`value_and_gradient_closure`](@ref) on your own function type.
# you can simply implement a method for [`value_and_gradient`](@ref) on your own function type.
#
# ```@docs
# ProximalCore.prox
# ProximalCore.prox!
# ProximalAlgorithms.value_and_gradient_closure
# ProximalAlgorithms.value_and_gradient
# ProximalAlgorithms.AutoDifferentiable
# ```
#
@@ -32,12 +32,12 @@
# Let's try to minimize the celebrated Rosenbrock function, but constrained to the unit norm ball. The cost function is

using Zygote
using AbstractDifferentiation: ZygoteBackend
using DifferentiationInterface: AutoZygote
using ProximalAlgorithms

rosenbrock2D = ProximalAlgorithms.AutoDifferentiable(
x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2,
ZygoteBackend(),
AutoZygote(),
)

# To enforce the constraint, we define the indicator of the unit ball, together with its proximal mapping:
@@ -105,16 +105,17 @@ end

Counting(f::T) where {T} = Counting{T}(f, 0, 0, 0)

# Now we only need to intercept any call to [`value_and_gradient_closure`](@ref) and [`prox!`](@ref) and increase counters there:
function (f::Counting)(x)
f.eval_count += 1
return f.f(x)
end

function ProximalAlgorithms.value_and_gradient_closure(f::Counting, x)
# Now we only need to intercept any call to [`value_and_gradient`](@ref) and [`prox!`](@ref) and increase counters there:

function ProximalAlgorithms.value_and_gradient(f::Counting, x)
f.eval_count += 1
fx, pb = ProximalAlgorithms.value_and_gradient_closure(f.f, x)
function counting_pullback()
f.gradient_count += 1
return pb()
end
return fx, counting_pullback
f.gradient_count += 1
return ProximalAlgorithms.value_and_gradient(f.f, x)
end

function ProximalCore.prox!(y, f::Counting, x, gamma)
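With the eager interface, the `Counting` wrapper above increments its gradient counter once per `value_and_gradient` call rather than once per pullback invocation. A usage sketch, assuming the definitions from this guide page (`Counting`, `rosenbrock2D`) are in scope:

```julia
counted_cost = Counting(rosenbrock2D)

fx, grad_fx = ProximalAlgorithms.value_and_gradient(counted_cost, [0.1, 0.2])

counted_cost.eval_count       # bumped by the call above
counted_cost.gradient_count   # bumped once per value_and_gradient call
```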
11 changes: 5 additions & 6 deletions docs/src/guide/getting_started.jl
@@ -20,7 +20,7 @@
# The literature on proximal operators and algorithms is vast: for an overview, one can refer to [Parikh2014](@cite), [Beck2017](@cite).
#
# To evaluate these first-order primitives, in ProximalAlgorithms:
# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl) and all of its backends).
# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) and all of its backends).
# * ``\operatorname{prox}_{f_i}`` relies on the interface of [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15).
# Both of the above can be implemented for custom function types, as [documented here](@ref custom_terms).
#
@@ -52,13 +52,13 @@

using LinearAlgebra
using Zygote
using AbstractDifferentiation: ZygoteBackend
using DifferentiationInterface: AutoZygote
using ProximalOperators
using ProximalAlgorithms

quadratic_cost = ProximalAlgorithms.AutoDifferentiable(
x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x),
ZygoteBackend(),
AutoZygote(),
)
box_indicator = ProximalOperators.IndBox(0, 1)

@@ -72,10 +72,9 @@ ffb = ProximalAlgorithms.FastForwardBackward(maxit = 1000, tol = 1e-5, verbose =
solution, iterations = ffb(x0 = ones(2), f = quadratic_cost, g = box_indicator)

# We can verify the correctness of the solution by checking that the negative gradient is orthogonal to the constraints, pointing outwards:
# for this, we just evaluate the closure `cl` returned as second output of [`value_and_gradient_closure`](@ref).
# for this, we just evaluate the second output of [`value_and_gradient`](@ref).

v, cl = ProximalAlgorithms.value_and_gradient_closure(quadratic_cost, solution)
-cl()
last(ProximalAlgorithms.value_and_gradient(quadratic_cost, solution))

# Or by plotting the solution against the cost function and constraint:

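The optimality check mentioned above can be spelled out a bit further; this is a sketch under the guide's setup (`quadratic_cost` and `solution` already defined), not part of the diff:

```julia
_, grad_at_solution = ProximalAlgorithms.value_and_gradient(quadratic_cost, solution)

# For min f(x) subject to 0 <= x <= 1, first-order optimality means each component
# of -∇f(x*) points out of the box or vanishes: nonpositive where x* = 0,
# nonnegative where x* = 1, and zero in the interior (up to the solver tolerance).
tol = 1e-5
all(
    (solution[i] <= tol && -grad_at_solution[i] <= tol) ||
    (solution[i] >= 1 - tol && -grad_at_solution[i] >= -tol) ||
    abs(grad_at_solution[i]) <= tol
    for i in eachindex(solution)
)
```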
2 changes: 1 addition & 1 deletion docs/src/index.md
@@ -14,7 +14,7 @@ Implemented algorithms include:
Check out [this section](@ref problems_algorithms) for an overview of the available algorithms.

Algorithms rely on:
- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation (but you can easily bring your own gradients),
- [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) for automatic differentiation (but you can easily bring your own gradients),
- the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc, to handle non-differentiable terms (see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) for an extensive collection of functions).

!!! note
26 changes: 12 additions & 14 deletions src/ProximalAlgorithms.jl
@@ -1,6 +1,7 @@
module ProximalAlgorithms

using AbstractDifferentiation
using ADTypes: ADTypes
using DifferentiationInterface: DifferentiationInterface
using ProximalCore
using ProximalCore: prox, prox!

@@ -12,33 +13,30 @@ const Maybe{T} = Union{T,Nothing}

Callable struct wrapping function `f` to be auto-differentiated using `backend`.

When called, it evaluates the same as `f`, while [`value_and_gradient_closure`](@ref)
When called, it evaluates the same as `f`, while its gradient
is implemented using `backend` for automatic differentiation.
The backend can be any from [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl).
The backend can be any of those supported by [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl).
"""
struct AutoDifferentiable{F,B}
struct AutoDifferentiable{F,B<:ADTypes.AbstractADType}
f::F
backend::B
end

(f::AutoDifferentiable)(x) = f.f(x)

"""
value_and_gradient_closure(f, x)
value_and_gradient(f, x)

Return a tuple containing the value of `f` at `x`, and a closure `cl`.

Function `cl`, once called, yields the gradient of `f` at `x`.
Return a tuple containing the value of `f` at `x` and the gradient of `f` at `x`.
"""
value_and_gradient_closure
value_and_gradient

function value_and_gradient_closure(f::AutoDifferentiable, x)
fx, pb = AbstractDifferentiation.value_and_pullback_function(f.backend, f.f, x)
return fx, () -> pb(one(fx))[1]
function value_and_gradient(f::AutoDifferentiable, x)
return DifferentiationInterface.value_and_gradient(f.f, f.backend, x)
end

function value_and_gradient_closure(f::ProximalCore.Zero, x)
f(x), () -> zero(x)
function value_and_gradient(f::ProximalCore.Zero, x)
return f(x), zero(x)
end

# various utilities
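For downstream code that had extended the old closure-based API, the migration is mechanical: return the gradient itself instead of a thunk. A before/after sketch for a hypothetical `MyQuadratic` type (not part of the package):

```julia
using ProximalAlgorithms

struct MyQuadratic{TQ,Tq}
    Q::TQ   # assumed symmetric
    q::Tq
end

(f::MyQuadratic)(x) = x' * f.Q * x / 2 + f.q' * x

# Old interface (before this PR): value plus a closure yielding the gradient on demand.
# function ProximalAlgorithms.value_and_gradient_closure(f::MyQuadratic, x)
#     f(x), () -> f.Q * x .+ f.q
# end

# New interface (this PR): value and gradient returned together, eagerly.
function ProximalAlgorithms.value_and_gradient(f::MyQuadratic, x)
    return f(x), f.Q * x .+ f.q
end
```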
7 changes: 3 additions & 4 deletions src/algorithms/davis_yin.jl
@@ -56,8 +56,7 @@ end
function Base.iterate(iter::DavisYinIteration)
z = copy(iter.x0)
xg, = prox(iter.g, z, iter.gamma)
f_xg, cl = value_and_gradient_closure(iter.f, xg)
grad_f_xg = cl()
f_xg, grad_f_xg = value_and_gradient(iter.f, xg)
z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg
xh, = prox(iter.h, z_half, iter.gamma)
res = xh - xg
@@ -68,8 +67,8 @@

function Base.iterate(iter::DavisYinIteration, state::DavisYinState)
prox!(state.xg, iter.g, state.z, iter.gamma)
f_xg, cl = value_and_gradient_closure(iter.f, state.xg)
state.grad_f_xg .= cl()
f_xg, grad_f_xg = value_and_gradient(iter.f, state.xg)
state.grad_f_xg .= grad_f_xg
state.z_half .= 2 .* state.xg .- state.z .- iter.gamma .* state.grad_f_xg
prox!(state.xh, iter.h, state.z_half, iter.gamma)
state.res .= state.xh .- state.xg
7 changes: 3 additions & 4 deletions src/algorithms/fast_forward_backward.jl
@@ -72,8 +72,7 @@

function Base.iterate(iter::FastForwardBackwardIteration)
x = copy(iter.x0)
f_x, cl = value_and_gradient_closure(iter.f, x)
grad_f_x = cl()
f_x, grad_f_x = value_and_gradient(iter.f, x)
gamma =
iter.gamma === nothing ?
1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma
@@ -136,8 +135,8 @@ function Base.iterate(
state.x .= state.z .+ beta .* (state.z .- state.z_prev)
state.z_prev, state.z = state.z, state.z_prev

state.f_x, cl = value_and_gradient_closure(iter.f, state.x)
state.grad_f_x .= cl()
state.f_x, grad_f_x = value_and_gradient(iter.f, state.x)
state.grad_f_x .= grad_f_x
state.y .= state.x .- state.gamma .* state.grad_f_x
state.g_z = prox!(state.z, iter.g, state.y, state.gamma)
state.res .= state.x .- state.z
7 changes: 3 additions & 4 deletions src/algorithms/forward_backward.jl
@@ -64,8 +64,7 @@

function Base.iterate(iter::ForwardBackwardIteration)
x = copy(iter.x0)
f_x, cl = value_and_gradient_closure(iter.f, x)
grad_f_x = cl()
f_x, grad_f_x = value_and_gradient(iter.f, x)
gamma =
iter.gamma === nothing ?
1 / lower_bound_smoothness_constant(iter.f, I, x, grad_f_x) : iter.gamma
@@ -111,8 +110,8 @@ function Base.iterate(
state.grad_f_x, state.grad_f_z = state.grad_f_z, state.grad_f_x
else
state.x, state.z = state.z, state.x
state.f_x, cl = value_and_gradient_closure(iter.f, state.x)
state.grad_f_x .= cl()
state.f_x, grad_f_x = value_and_gradient(iter.f, state.x)
state.grad_f_x .= grad_f_x
end

state.y .= state.x .- state.gamma .* state.grad_f_x
10 changes: 4 additions & 6 deletions src/algorithms/li_lin.jl
@@ -62,8 +62,7 @@

function Base.iterate(iter::LiLinIteration{R}) where {R}
y = copy(iter.x0)
f_y, cl = value_and_gradient_closure(iter.f, y)
grad_f_y = cl()
f_y, grad_f_y = value_and_gradient(iter.f, y)

# TODO: initialize gamma if not provided
# TODO: authors suggest Barzilai-Borwein rule?
@@ -110,8 +109,7 @@ function Base.iterate(iter::LiLinIteration{R}, state::LiLinState{R,Tx}) where {R
else
# TODO: re-use available space in state?
# TODO: backtrack gamma at x
f_x, cl = value_and_gradient_closure(iter.f, x)
grad_f_x = cl()
f_x, grad_f_x = value_and_gradient(iter.f, x)
x_forward = state.x - state.gamma .* grad_f_x
v, g_v = prox(iter.g, x_forward, state.gamma)
Fv = iter.f(v) + g_v
@@ -130,8 +128,8 @@
Fx = Fv
end

state.f_y, cl = value_and_gradient_closure(iter.f, state.y)
state.grad_f_y .= cl()
state.f_y, grad_f_y = value_and_gradient(iter.f, state.y)
state.grad_f_y .= grad_f_y
state.y_forward .= state.y .- state.gamma .* state.grad_f_y
state.g_z = prox!(state.z, iter.g, state.y_forward, state.gamma)

15 changes: 7 additions & 8 deletions src/algorithms/panoc.jl
@@ -87,8 +87,7 @@ f_model(iter::PANOCIteration, state::PANOCState) =
function Base.iterate(iter::PANOCIteration{R}) where {R}
x = copy(iter.x0)
Ax = iter.A * x
f_Ax, cl = value_and_gradient_closure(iter.f, Ax)
grad_f_Ax = cl()
f_Ax, grad_f_Ax = value_and_gradient(iter.f, Ax)
gamma =
iter.gamma === nothing ?
iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) :
@@ -182,8 +181,8 @@ function Base.iterate(iter::PANOCIteration{R,Tx,Tf}, state::PANOCState) where {R

state.x_d .= state.x .+ state.d
state.Ax_d .= state.Ax .+ state.Ad
state.f_Ax_d, cl = value_and_gradient_closure(iter.f, state.Ax_d)
state.grad_f_Ax_d .= cl()
state.f_Ax_d, grad_f_Ax_d = value_and_gradient(iter.f, state.Ax_d)
state.grad_f_Ax_d .= grad_f_Ax_d
mul!(state.At_grad_f_Ax_d, adjoint(iter.A), state.grad_f_Ax_d)

copyto!(state.x, state.x_d)
@@ -220,8 +219,8 @@ function Base.iterate(iter::PANOCIteration{R,Tx,Tf}, state::PANOCState) where {R
# along a line using interpolation and linear combinations
# this allows saving operations
if isinf(f_Az)
f_Az, cl = value_and_gradient_closure(iter.f, state.Az)
state.grad_f_Az .= cl()
f_Az, grad_f_Az = value_and_gradient(iter.f, state.Az)
state.grad_f_Az .= grad_f_Az
end
if isinf(c)
mul!(state.At_grad_f_Az, iter.A', state.grad_f_Az)
@@ -239,8 +238,8 @@ function Base.iterate(iter::PANOCIteration{R,Tx,Tf}, state::PANOCState) where {R
else
# otherwise, in the general case where f is only smooth, we compute
# one gradient and matvec per backtracking step
state.f_Ax, cl = value_and_gradient_closure(iter.f, state.Ax)
state.grad_f_Ax .= cl()
state.f_Ax, grad_f_Ax = value_and_gradient(iter.f, state.Ax)
state.grad_f_Ax .= grad_f_Ax
mul!(state.At_grad_f_Ax, adjoint(iter.A), state.grad_f_Ax)
end
