Update gradient interface, support AbstractDifferentiation (#90)
lostella authored Jan 21, 2024
1 parent 4d07b28 commit 9067e90
Showing 37 changed files with 382 additions and 617 deletions.
6 changes: 3 additions & 3 deletions Project.toml
@@ -1,16 +1,16 @@
name = "ProximalAlgorithms"
uuid = "140ffc9f-1907-541a-a177-7475e0a401e9"
version = "0.5.5"
version = "0.6.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractDifferentiation = "0.6"
LinearAlgebra = "1.2"
Printf = "1.2"
ProximalCore = "0.1"
Zygote = "0.6"
julia = "1.2"
23 changes: 15 additions & 8 deletions README.md
@@ -11,14 +11,21 @@ A Julia package for non-smooth optimization algorithms.
This package provides algorithms for the minimization of objective functions
that include non-smooth terms, such as constraints or non-differentiable penalties.
Implemented algorithms include:
* (Fast) Proximal gradient methods
* Douglas-Rachford splitting
* Three-term splitting
* Primal-dual splitting algorithms
* Newton-type methods

This package works well in combination with [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15),
which contains a wide range of functions that can be used to express cost terms.
- (Fast) Proximal gradient methods
- Douglas-Rachford splitting
- Three-term splitting
- Primal-dual splitting algorithms
- Newton-type methods

Check out [this section](https://juliafirstorder.github.io/ProximalAlgorithms.jl/stable/guide/implemented_algorithms/) for an overview of the available algorithms.

Algorithms rely on:
- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation
(but you can easily bring your own gradients); see the sketch below
- the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc.,
to handle non-differentiable terms
(see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl)
for an extensive collection of functions).
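
For a quick feel of how these pieces fit together, here is a minimal sketch (adapted from the getting-started guide): a smooth quadratic cost wrapped for automatic differentiation, a box constraint handled through its proximal mapping, and a fast proximal gradient solver combining the two.

```julia
using LinearAlgebra
using Zygote
using AbstractDifferentiation: ZygoteBackend
using ProximalOperators
using ProximalAlgorithms

# Smooth term: gradients are obtained through the chosen AD backend.
quadratic_cost = ProximalAlgorithms.AutoDifferentiable(
    x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x),
    ZygoteBackend(),
)

# Non-smooth term: the indicator of a box, handled through its proximal mapping.
box_indicator = ProximalOperators.IndBox(0, 1)

# Fast proximal gradient (fast forward-backward) iterations.
ffb = ProximalAlgorithms.FastForwardBackward(maxit=1000, tol=1e-5)
solution, iterations = ffb(x0=ones(2), f=quadratic_cost, g=box_indicator)
```

Since solvers only interact with the smooth term through `ProximalAlgorithms.value_and_gradient_closure`, you can also bring your own gradients by defining that method for your own function type.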

## Documentation

22 changes: 19 additions & 3 deletions benchmark/benchmarks.jl
@@ -8,6 +8,22 @@ using FileIO

const SUITE = BenchmarkGroup()
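
# Hand-written `value_and_gradient_closure` methods for the smooth terms used in the
# benchmarks below: each returns the objective value together with a closure yielding
# its gradient, keeping automatic differentiation out of the benchmarked solver runs.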

function ProximalAlgorithms.value_and_gradient_closure(f::ProximalOperators.LeastSquaresDirect, x)
res = f.A*x - f.b
norm(res)^2, () -> 2 * (f.A' * res) # gradient of norm(A*x - b)^2 is 2*A'*(A*x - b), consistent with the value above
end

struct SquaredDistance{Tb}
b::Tb
end

(f::SquaredDistance)(x) = norm(x - f.b)^2

function ProximalAlgorithms.value_and_gradient_closure(f::SquaredDistance, x)
diff = x - f.b
norm(diff)^2, () -> 2 .* diff # gradient of norm(x - b)^2 is 2*(x - b), consistent with the value above
end

for (benchmark_name, file_name) in [
("Lasso tiny", joinpath(@__DIR__, "data", "lasso_tiny.jld2")),
("Lasso small", joinpath(@__DIR__, "data", "lasso_small.jld2")),
@@ -42,21 +58,21 @@ for (benchmark_name, file_name) in [
SUITE[k]["ZeroFPR"] = @benchmarkable solver(x0=x0, f=f, A=$A, g=g) setup=begin
solver = ProximalAlgorithms.ZeroFPR(tol=1e-6)
x0 = zeros($T, size($A, 2))
f = Translate(SqrNormL2(), -$b)
f = SquaredDistance($b)
g = NormL1($lam)
end

SUITE[k]["PANOC"] = @benchmarkable solver(x0=x0, f=f, A=$A, g=g) setup=begin
solver = ProximalAlgorithms.PANOC(tol=1e-6)
x0 = zeros($T, size($A, 2))
f = Translate(SqrNormL2(), -$b)
f = SquaredDistance($b)
g = NormL1($lam)
end

SUITE[k]["PANOCplus"] = @benchmarkable solver(x0=x0, f=f, A=$A, g=g) setup=begin
solver = ProximalAlgorithms.PANOCplus(tol=1e-6)
x0 = zeros($T, size($A, 2))
f = Translate(SqrNormL2(), -$b)
f = SquaredDistance($b)
g = NormL1($lam)
end

2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -1,4 +1,5 @@
[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
Expand All @@ -7,6 +8,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProximalAlgorithms = "140ffc9f-1907-541a-a177-7475e0a401e9"
ProximalCore = "dc4f5ac2-75d1-4f31-931e-60435d74994b"
ProximalOperators = "a725b495-10eb-56fe-b38b-717eba820537"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Documenter = "1"
7 changes: 5 additions & 2 deletions docs/src/examples/sparse_linear_regression.jl
@@ -51,10 +51,13 @@ end

mean_squared_error(label, output) = mean((output .- label) .^ 2) / 2

using Zygote
using AbstractDifferentiation: ZygoteBackend
using ProximalAlgorithms

training_loss = ProximalAlgorithms.ZygoteFunction(
wb -> mean_squared_error(training_label, standardized_linear_model(wb, training_input))
training_loss = ProximalAlgorithms.AutoDifferentiable(
wb -> mean_squared_error(training_label, standardized_linear_model(wb, training_input)),
ZygoteBackend()
)

# As regularization we will use the L1 norm, implemented in [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl):
44 changes: 27 additions & 17 deletions docs/src/guide/custom_objectives.jl
@@ -12,29 +12,32 @@
#
# Defining the proximal mapping for a custom function type requires adding a method for [`ProximalCore.prox!`](@ref).
#
# To compute gradients, ProximalAlgorithms provides a fallback definition for [`ProximalCore.gradient!`](@ref),
# relying on [Zygote](https://github.com/FluxML/Zygote.jl) to use automatic differentiation.
# Therefore, you can provide any (differentiable) Julia function wherever gradients need to be taken,
# and everything will work out of the box.
# To compute gradients, algorithms use [`ProximalAlgorithms.value_and_gradient_closure`](@ref):
# this relies on [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation
# with any of its supported backends, when functions are wrapped in [`ProximalAlgorithms.AutoDifferentiable`](@ref),
# as the examples below show.
#
# If however one would like to provide their own gradient implementation (e.g. for efficiency reasons),
# they can simply implement a method for [`ProximalCore.gradient!`](@ref).
# If however you would like to provide your own gradient implementation (e.g. for efficiency reasons),
# you can simply implement a method for [`ProximalAlgorithms.value_and_gradient_closure`](@ref) on your own function type, as sketched below.
#
# ```@docs
# ProximalCore.prox
# ProximalCore.prox!
# ProximalCore.gradient
# ProximalCore.gradient!
# ProximalAlgorithms.value_and_gradient_closure
# ProximalAlgorithms.AutoDifferentiable
# ```
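#
# For instance, here is a minimal sketch of a custom type carrying its own hand-written
# gradient (the `Quadratic` type below is purely illustrative and not part of the package):
#
# ```julia
# using LinearAlgebra: dot
#
# struct Quadratic{TQ,Tq}  # f(x) = x'Qx/2 + q'x, with Q symmetric
#     Q::TQ
#     q::Tq
# end
#
# (f::Quadratic)(x) = dot(x, f.Q * x) / 2 + dot(f.q, x)
#
# function ProximalAlgorithms.value_and_gradient_closure(f::Quadratic, x)
#     Qx = f.Q * x
#     return dot(x, Qx) / 2 + dot(f.q, x), () -> Qx + f.q
# end
# ```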
#
# ## Example: constrained Rosenbrock
#
# Let's try to minimize the celebrated Rosenbrock function, but constrained to the unit norm ball. The cost function is

using Zygote
using AbstractDifferentiation: ZygoteBackend
using ProximalAlgorithms

rosenbrock2D = ProximalAlgorithms.ZygoteFunction(
x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2
rosenbrock2D = ProximalAlgorithms.AutoDifferentiable(
x -> 100 * (x[2] - x[1]^2)^2 + (1 - x[1])^2,
ZygoteBackend()
)

# To enforce the constraint, we define the indicator of the unit ball, together with its proximal mapping:
@@ -82,17 +85,23 @@ scatter!([solution[1]], [solution[2]], color=:red, markershape=:star5, label="co

mutable struct Counting{T}
f::T
eval_count::Int
gradient_count::Int
prox_count::Int
end

Counting(f::T) where T = Counting{T}(f, 0, 0)
Counting(f::T) where T = Counting{T}(f, 0, 0, 0)

# Now we only need to intercept any call to `gradient!` and `prox!` and increase counters there:
# Now we only need to intercept any call to `value_and_gradient_closure` and `prox!` and increase counters there:

function ProximalCore.gradient!(y, f::Counting, x)
f.gradient_count += 1
return ProximalCore.gradient!(y, f.f, x)
function ProximalAlgorithms.value_and_gradient_closure(f::Counting, x)
f.eval_count += 1
fx, pb = ProximalAlgorithms.value_and_gradient_closure(f.f, x)
function counting_pullback()
f.gradient_count += 1
return pb()
end
return fx, counting_pullback
end

function ProximalCore.prox!(y, f::Counting, x, gamma)
Expand All @@ -109,5 +118,6 @@ solution, iterations = panoc(x0=-ones(2), f=f, g=g)

# and check how many operations were actually performed:

println(f.gradient_count)
println(g.prox_count)
println("function evals: $(f.eval_count)")
println("gradient evals: $(f.gradient_count)")
println(" prox evals: $(g.prox_count)")
13 changes: 9 additions & 4 deletions docs/src/guide/getting_started.jl
@@ -20,7 +20,7 @@
# The literature on proximal operators and algorithms is vast: for an overview, one can refer to [Parikh2014](@cite), [Beck2017](@cite).
#
# To evaluate these first-order primitives, in ProximalAlgorithms:
# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [Zygote](https://github.com/FluxML/Zygote.jl)).
# * ``\nabla f_i`` falls back to using automatic differentiation (as provided by [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl) and all of its backends).
# * ``\operatorname{prox}_{f_i}`` relies on the interface of [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15).
# Both of the above can be implemented for custom function types, as [documented here](@ref custom_terms).
#
@@ -51,11 +51,14 @@
# which we will solve using the fast proximal gradient method (also known as fast forward-backward splitting):

using LinearAlgebra
using Zygote
using AbstractDifferentiation: ZygoteBackend
using ProximalOperators
using ProximalAlgorithms

quadratic_cost = ProximalAlgorithms.ZygoteFunction(
x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x)
quadratic_cost = ProximalAlgorithms.AutoDifferentiable(
x -> dot([3.4 1.2; 1.2 4.5] * x, x) / 2 + dot([-2.3, 9.9], x),
ZygoteBackend()
)
box_indicator = ProximalOperators.IndBox(0, 1)

Expand All @@ -69,8 +72,10 @@ ffb = ProximalAlgorithms.FastForwardBackward(maxit=1000, tol=1e-5, verbose=true)
solution, iterations = ffb(x0=ones(2), f=quadratic_cost, g=box_indicator)

# We can verify the correctness of the solution by checking that the negative gradient is orthogonal to the constraints, pointing outwards:
# for this, we just evaluate the closure `cl` returned as the second output of [`value_and_gradient_closure`](@ref).

-ProximalAlgorithms.gradient(quadratic_cost, solution)[1]
v, cl = ProximalAlgorithms.value_and_gradient_closure(quadratic_cost, solution)
-cl()

# Or by plotting the solution against the cost function and constraint:

29 changes: 13 additions & 16 deletions docs/src/index.md
@@ -5,16 +5,21 @@ A Julia package for non-smooth optimization algorithms. [Link to GitHub reposito
This package provides algorithms for the minimization of objective functions
that include non-smooth terms, such as constraints or non-differentiable penalties.
Implemented algorithms include:
* (Fast) Proximal gradient methods
* Douglas-Rachford splitting
* Three-term splitting
* Primal-dual splitting algorithms
* Newton-type methods
- (Fast) Proximal gradient methods
- Douglas-Rachford splitting
- Three-term splitting
- Primal-dual splitting algorithms
- Newton-type methods

Check out [this section](@ref problems_algorithms) for an overview of the available algorithms.

This package works well in combination with [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl) (>= 0.15),
which contains a wide range of functions that can be used to express cost terms.
Algorithms rely on:
- [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) for automatic differentiation
(but you can easily bring your own gradients)
- the [ProximalCore API](https://github.com/JuliaFirstOrder/ProximalCore.jl) for proximal mappings, projections, etc.,
to handle non-differentiable terms
(see for example [ProximalOperators](https://github.com/JuliaFirstOrder/ProximalOperators.jl)
for an extensive collection of functions).

!!! note

Expand All @@ -23,20 +28,11 @@ which contains a wide range of functions that can be used to express cost terms.

## Installation

Install the latest stable release with

```julia
julia> ]
pkg> add ProximalAlgorithms
```

To install the development version instead (`master` branch), do

```julia
julia> ]
pkg> add ProximalAlgorithms#master
```

## Citing

If you use any of the algorithms from ProximalAlgorithms in your research, you are kindly asked to cite the relevant bibliography.
Expand All @@ -45,3 +41,4 @@ Please check [this section of the manual](@ref problems_algorithms) for algorith
## Contributing

Contributions are welcome in the form of [issue notifications](https://github.com/JuliaFirstOrder/ProximalAlgorithms.jl/issues) or [pull requests](https://github.com/JuliaFirstOrder/ProximalAlgorithms.jl/pulls). When contributing new algorithms, we highly recommend looking at already implemented ones to get inspiration on how to structure the code.

19 changes: 19 additions & 0 deletions justfile
@@ -0,0 +1,19 @@
julia:
julia --project=.

instantiate:
julia --project=. -e 'using Pkg; Pkg.instantiate()'

test:
julia --project=. -e 'using Pkg; Pkg.test()'

format:
julia --project=. -e 'using JuliaFormatter: format; format(".")'

docs:
julia --project=./docs docs/make.jl

benchmark:
julia --project=benchmark -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
julia --project=benchmark benchmark/runbenchmarks.jl

38 changes: 36 additions & 2 deletions src/ProximalAlgorithms.jl
@@ -1,14 +1,48 @@
module ProximalAlgorithms

using AbstractDifferentiation
using ProximalCore
using ProximalCore: prox, prox!, gradient, gradient!
using ProximalCore: prox, prox!

const RealOrComplex{R} = Union{R,Complex{R}}
const Maybe{T} = Union{T,Nothing}

"""
AutoDifferentiable(f, backend)
Callable struct wrapping function `f` to be auto-differentiated using `backend`.
When called, it evaluates the same as `f`, while [`ProximalAlgorithms.value_and_gradient_closure`](@ref)
is implemented using `backend` for automatic differentiation.
The backend can be any from [AbstractDifferentiation](https://github.com/JuliaDiff/AbstractDifferentiation.jl).
"""
struct AutoDifferentiable{F, B}
f::F
backend::B
end

(f::AutoDifferentiable)(x) = f.f(x)

"""
value_and_gradient_closure(f, x)
Return a tuple containing the value of `f` at `x`, and a closure `cl`.
Function `cl`, once called, yields the gradient of `f` at `x`.
"""
value_and_gradient_closure

function value_and_gradient_closure(f::AutoDifferentiable, x)
fx, pb = AbstractDifferentiation.value_and_pullback_function(f.backend, f.f, x)
return fx, () -> pb(one(fx))[1]
end

function value_and_gradient_closure(f::ProximalCore.Zero, x)
f(x), () -> zero(x)
end
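
# Illustrative consumption pattern (a sketch, not executed here): algorithms call
#
#     fx, cl = value_and_gradient_closure(f, x)   # value of f at x
#     grad_fx = cl()                               # gradient, computed only when requested
#
# so any expensive backward pass is deferred until the closure is actually called.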

# various utilities

include("utilities/ad.jl")
include("utilities/fb_tools.jl")
include("utilities/iteration_tools.jl")

6 changes: 4 additions & 2 deletions src/algorithms/davis_yin.jl
@@ -55,7 +55,8 @@ end
function Base.iterate(iter::DavisYinIteration)
z = copy(iter.x0)
xg, = prox(iter.g, z, iter.gamma)
grad_f_xg, = gradient(iter.f, xg)
f_xg, cl = value_and_gradient_closure(iter.f, xg)
grad_f_xg = cl()
z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg
xh, = prox(iter.h, z_half, iter.gamma)
res = xh - xg
Expand All @@ -66,7 +67,8 @@ end

function Base.iterate(iter::DavisYinIteration, state::DavisYinState)
prox!(state.xg, iter.g, state.z, iter.gamma)
gradient!(state.grad_f_xg, iter.f, state.xg)
f_xg, cl = value_and_gradient_closure(iter.f, state.xg)
state.grad_f_xg .= cl()
state.z_half .= 2 .* state.xg .- state.z .- iter.gamma .* state.grad_f_xg
prox!(state.xh, iter.h, state.z_half, iter.gamma)
state.res .= state.xh .- state.xg