
Manifold optimization #435
Merged — 26 commits, merged Sep 22, 2017
Changes from 11 commits
2b9b7f9
Manifold optimization
antoine-levitt Jun 22, 2017
9aa58c1
add tests
antoine-levitt Jun 22, 2017
0ea8ff5
add gradient(obj,int) to ManifoldObjective
antoine-levitt Jun 22, 2017
765e574
do not modify initial x, and fix LBFGS initialization
antoine-levitt Jun 22, 2017
7e64c2a
product and power manifolds
antoine-levitt Jun 22, 2017
dd106e2
support complex optimization
antoine-levitt Jun 23, 2017
74cd17b
Merge branch 'complex' into manifold
antoine-levitt Jun 24, 2017
02f05f6
fix tests, and import iscomplex the correct way
antoine-levitt Jun 24, 2017
7b7002e
simplify slightly complex API, and support complex manifolds
antoine-levitt Jun 25, 2017
1f8315e
nicer power and product manifolds
antoine-levitt Jun 25, 2017
2aadf79
implement gradient! in manifoldobjective
antoine-levitt Jun 25, 2017
cff44ec
Merge branch 'master' of https://github.com/JuliaNLSolvers/Optim.jl i…
antoine-levitt Jul 22, 2017
583997f
doc
antoine-levitt Jul 23, 2017
1feb7b6
whitespace
antoine-levitt Jul 23, 2017
1bb787b
convert back to complex for preconditioners
antoine-levitt Jul 23, 2017
018cc81
document complex and manifolds
antoine-levitt Jul 23, 2017
0cba7ee
whitespace & overloading P for complex optimization
antoine-levitt Jul 23, 2017
2878095
LBFGS's twoloop! expects vectorized arguments => clashes with precond…
antoine-levitt Jul 25, 2017
376178b
simplify discussion of second-order complex, and remove discussion of…
antoine-levitt Jul 26, 2017
e71a6ed
Merge branch 'manifold' of github.com:antoine-levitt/Optim.jl into ma…
antoine-levitt Jul 26, 2017
2ae30ee
change "does nothing"
antoine-levitt Jul 26, 2017
4cafc1a
clarify complex docs
antoine-levitt Jul 26, 2017
e9c8566
document manifolds
antoine-levitt Jul 26, 2017
c1f1bc5
add "in general"
antoine-levitt Jul 26, 2017
0d8a1ac
remove extra copy
antoine-levitt Jul 27, 2017
44a186b
Merge remote-tracking branch 'origin/master' into manifold
antoine-levitt Sep 20, 2017
157 changes: 157 additions & 0 deletions src/Manifolds.jl
@@ -0,0 +1,157 @@
# Manifold interface: every manifold (subtype of Manifold) defines the functions
# project_tangent!(m, g, x): project g on the tangent space to m at x
# retract!(m, x): map x back to a point on the manifold m

## To add:
## * Second order algorithms
## * Vector transport
## * Arbitrary inner product
## * More retractions
## * More manifolds from ROPTLIB
## * {x, Ax = b}
## * Intersection manifold (just do the projection on both manifolds iteratively and hope it converges)
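The contract above is deliberately small: a new manifold only needs `retract!` and `project_tangent!`. As an illustrative sketch (the RSphere type below is hypothetical, not part of this PR), a sphere of radius r could be added like this:

```julia
# Hypothetical manifold {x : norm(x) = r}, written against this PR's interface.
struct RSphere <: Manifold
    r::Float64
end

# Map an arbitrary nonzero x back onto the sphere of radius r (in place).
retract!(S::RSphere, x) = (x .*= S.r / norm(x); x)

# Remove the radial component of g, leaving the part tangent to the sphere at x.
project_tangent!(S::RSphere, g, x) = (g .-= (vecdot(x, g) / vecdot(x, x)) .* x; g)
```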

abstract type Manifold
Contributor: I suggest adding docstrings at least to all the new types.

end


type ManifoldObjective{T<:NLSolversBase.AbstractObjective} <: NLSolversBase.AbstractObjective
manifold :: Manifold
Member: the space around :: seems quite unusual

inner_obj :: T
end
iscomplex(obj::ManifoldObjective) = iscomplex(obj.inner_obj)
# TODO is it safe here to call retract! and change x?
function NLSolversBase.value!(obj::ManifoldObjective, x)
xin = complex_to_real(obj, retract(obj.manifold, real_to_complex(obj,x)))
value!(obj.inner_obj, xin)
end
function NLSolversBase.value(obj::ManifoldObjective)
value(obj.inner_obj)
end
function NLSolversBase.gradient(obj::ManifoldObjective)
gradient(obj.inner_obj)
end
function NLSolversBase.gradient(obj::ManifoldObjective,i::Int)
gradient(obj.inner_obj,i)
end
function NLSolversBase.gradient!(obj::ManifoldObjective,x)
xin = complex_to_real(obj, retract(obj.manifold, real_to_complex(obj,x)))
gradient!(obj.inner_obj,xin)
project_tangent!(obj.manifold,real_to_complex(obj,gradient(obj.inner_obj)),real_to_complex(obj,xin))
return gradient(obj.inner_obj)
end
function NLSolversBase.value_gradient!(obj::ManifoldObjective,x)
xin = complex_to_real(obj, retract(obj.manifold, real_to_complex(obj,x)))
value_gradient!(obj.inner_obj,xin)
project_tangent!(obj.manifold,real_to_complex(obj,gradient(obj.inner_obj)),real_to_complex(obj,xin))
return value(obj.inner_obj)
end

# fallback for out-of-place ops
project_tangent(M::Manifold, g, x) = project_tangent!(M, copy(g), x)
retract(M::Manifold,x) = retract!(M, copy(x))

# Flat manifold = {R,C}^n
# all the functions below are no-ops, and therefore the generated code
# for the flat manifold should be exactly the same as the one with all
# the manifold stuff removed
struct Flat <: Manifold
end
retract(M::Flat, x) = x
retract!(M::Flat,x) = x
project_tangent(M::Flat, g, x) = g
project_tangent!(M::Flat, g, x) = g

# {||x|| = 1}
struct Sphere <: Manifold
end
retract!(S::Sphere, x) = normalize!(x)
project_tangent!(S::Sphere,g,x) = (g .= g .- real(vecdot(x,g)).*x)
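A quick check of the two invariants, using the Sphere type just defined (sketch; `vecdot` is the Julia 0.6 spelling, `dot` in later versions):

```julia
x = randn(5)
retract!(Sphere(), x)             # now norm(x) ≈ 1
g = randn(5)
project_tangent!(Sphere(), g, x)
# g is now tangent to the sphere at x: vecdot(x, g) ≈ 0
```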

# N x n matrices such that X'X = I
# TODO: add more retractions, and support arbitrary inner product
Member: Any chance you would want to open an issue for this? Just so we don't lose track of the TODO.

Contributor Author: Done, see #448

abstract type Stiefel <: Manifold end
Member: I know it's not done all over the code base, but a simple reference or explanation of what the "Stiefel manifold" is would be nice.

Contributor: That's an important but pretty special manifold, no? What is the justification for having it as part of Optim?

Contributor Author: Will document it. The main justification is that it is the one I need in my application ;-) More seriously, it's the basic manifold for this kind of algorithm: it was the original motivation for the theory, many other manifolds (sphere, O(n), U(n)) are special cases, it's probably the most used in applications (at least that I know of) outside of the sphere, and it's a good template for implementing other manifolds.

There could be a Manifolds package living outside Optim, but it's a pretty short file so I would think this is fine, and people implementing other manifolds can just PR on Optim?

Contributor: Good point about the special cases. Ok, maybe leave this for now and discuss moving outside of Optim only when somebody complains.

struct Stiefel_CholQR <: Stiefel end
struct Stiefel_SVD <: Stiefel end
function Stiefel(retraction=:SVD)
if retraction == :CholQR
Stiefel_CholQR()
elseif retraction == :SVD
Stiefel_SVD()
end
end

function retract!(S::Stiefel_SVD, X)
U,S,V = svd(X)
X .= U*V'
end
function retract!(S::Stiefel_CholQR, X)
overlap = X'X
X .= X/chol(overlap)
end
project_tangent!(S::Stiefel, G, X) = (G .= X*(X'G .- G'X)./2 .+ G .- X*(X'G))
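As a sanity check of the SVD retraction (sketch; `Stiefel(:SVD)` uses the constructor defined above):

```julia
X = randn(10, 3)            # arbitrary full-rank N x n matrix
retract!(Stiefel(:SVD), X)  # replaces X by the nearest matrix with orthonormal columns
# X'X ≈ eye(3)
```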



# TODO is there a better way of doing power and product manifolds?

# multiple copies of the same manifold. Points are arrays of arbitrary
# dimensions, and the first (given by inner_dims) are points of the
# inner manifold. E.g. the product of 2x2 Stiefel manifolds of dimension N x n
# would be a N x n x 2 x 2 matrix
struct PowerManifold<:Manifold
inner_manifold::Manifold #type of embedded manifold
inner_dims::Tuple #dimension of the embedded manifolds
outer_dims::Tuple #number of embedded manifolds
end
function retract!(m::PowerManifold, x)
for i=1:prod(m.outer_dims)
retract!(m.inner_manifold,get_inner(m, x, i))
end
x
end
function project_tangent!(m::PowerManifold, g, x)
for i=1:prod(m.outer_dims)
project_tangent!(m.inner_manifold,get_inner(m, g, i),get_inner(m, x, i))
end
g
end
# linear indexing
@inline function get_inner(m::PowerManifold, x, i::Int)
size_inner = prod(m.inner_dims)
size_outer = prod(m.outer_dims)
@assert 1 <= i <= size_outer
return reshape(view(x, (i-1)*size_inner+1:i*size_inner), m.inner_dims)
end
@inline get_inner(m::PowerManifold, x, i::Tuple) = get_inner(m, x, ind2sub(m.outer_dims, i...))
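For instance (sketch), a 2 x 2 grid of unit 3-vectors is stored as a 3 x 2 x 2 array, and each inner point is handled independently:

```julia
m = PowerManifold(Sphere(), (3,), (2, 2))
x = randn(3, 2, 2)
retract!(m, x)              # each slice x[:, i, j] is normalized independently
g = randn(3, 2, 2)
project_tangent!(m, g, x)   # each gradient slice projected tangent to its own sphere
```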

#Product of two manifolds {P = (x1,x2), x1 ∈ m1, x2 ∈ m2}.
#P is assumed to be a flat array, and x1 is before x2 in memory
struct ProductManifold<:Manifold
m1::Manifold
m2::Manifold
dims1::Tuple
dims2::Tuple
end
function retract!(m::ProductManifold, x)
retract!(m.m1, get_inner(m,x,1))
retract!(m.m2, get_inner(m,x,2))
x
end
function project_tangent!(m::ProductManifold, g, x)
project_tangent!(m.m1, get_inner(m, g, 1), get_inner(m, x, 1))
project_tangent!(m.m2, get_inner(m, g, 2), get_inner(m, x, 2))
g
end
function get_inner(m::ProductManifold, x, i)
N1 = prod(m.dims1)
N2 = prod(m.dims2)
@assert length(x) == N1+N2
if i == 1
return reshape(view(x, 1:N1),m.dims1)
elseif i == 2
return reshape(view(x, N1+1:N1+N2), m.dims2)
else
error("Only two components in a product manifold")
end
end
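A sketch of the product manifold: a unit 3-vector followed by a 4 x 2 Stiefel matrix, stored contiguously in one flat vector of length 3 + 8:

```julia
m = ProductManifold(Sphere(), Stiefel(), (3,), (4, 2))
x = randn(11)
retract!(m, x)   # x[1:3] normalized; x[4:11] reshaped to 4x2 and orthonormalized
```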
12 changes: 11 additions & 1 deletion src/Optim.jl
@@ -18,6 +18,8 @@ module Optim
Base.getindex,
Base.setindex!

import NLSolversBase.iscomplex

export optimize,
NonDifferentiable,
OnceDifferentiable,
@@ -39,12 +41,20 @@ module Optim
Newton,
NewtonTrustRegion,
SimulatedAnnealing,
ParticleSwarm
ParticleSwarm,

Manifold,
Flat,
Sphere,
Stiefel

# Types
include("types.jl")
include("objective_types.jl")

# Manifolds
include("Manifolds.jl")

# Generic stuff
include("utilities/generic.jl")

2 changes: 1 addition & 1 deletion src/api.jl
@@ -1,5 +1,5 @@
Base.summary(r::OptimizationResults) = summary(r.method) # might want to do more here than just return summary of the method used
minimizer(r::OptimizationResults) = r.minimizer
minimizer(r::OptimizationResults) = iscomplex(r) ? real_to_complex(r.minimizer) : r.minimizer
minimum(r::OptimizationResults) = r.minimum
iterations(r::OptimizationResults) = r.iterations
iteration_limit_reached(r::OptimizationResults) = r.iteration_converged
6 changes: 5 additions & 1 deletion src/multivariate/optimize/optimize.jl
@@ -12,10 +12,12 @@ update_h!(d, state, method::SecondOrderSolver) = hessian!(d, state.x)
after_while!(d, state, method, options) = nothing

function optimize{D<:AbstractObjective, M<:Optimizer}(d::D, initial_x::AbstractArray, method::M,
options::Options = Options(), state = initial_state(method, options, d, initial_x))
options::Options = Options(), state = initial_state(method, options, d, complex_to_real(d, initial_x)))

t0 = time() # Initial time stamp used to control early stopping by options.time_limit

initial_x = complex_to_real(d, initial_x)

if length(initial_x) == 1 && typeof(method) <: NelderMead
error("Use optimize(f, scalar, scalar) for 1D problems")
end
@@ -70,6 +72,7 @@ function optimize{D<:AbstractObjective, M<:Optimizer}(d::D, initial_x::AbstractA

try
return MultivariateOptimizationResults(method,
NLSolversBase.iscomplex(d),
initial_x,
f_increased ? state.x_previous : state.x,
f_increased ? state.f_x_previous : value(d),
@@ -91,6 +94,7 @@ function optimize{D<:AbstractObjective, M<:Optimizer}(d::D, initial_x::AbstractA
h_calls(d))
catch
return MultivariateOptimizationResults(method,
NLSolversBase.iscomplex(d),
initial_x,
f_increased ? state.x_previous : state.x,
f_increased ? state.f_x_previous : value(d),
2 changes: 1 addition & 1 deletion src/multivariate/solvers/constrained/fminbox.jl
@@ -232,7 +232,7 @@ function optimize{T<:AbstractFloat,O<:Optimizer}(
results.x_converged, results.f_converged, results.g_converged, converged, f_increased = assess_convergence(x, xold, minimum(results), fval0, g, x_tol, f_tol, g_tol)
f_increased && !allow_f_increases && break
end
return MultivariateOptimizationResults(Fminbox{O}(), initial_x, minimizer(results), df.f(minimizer(results)),
return MultivariateOptimizationResults(Fminbox{O}(), false, initial_x, minimizer(results), df.f(minimizer(results)),
iteration, results.iteration_converged,
results.x_converged, results.x_tol, vecnorm(x - xold),
results.f_converged, results.f_tol, f_residual(minimum(results), fval0, f_tol),
@@ -9,6 +9,7 @@
# L should be function or any other callable
struct AcceleratedGradientDescent{L} <: Optimizer
linesearch!::L
manifold::Manifold
end

Base.summary(::AcceleratedGradientDescent) = "Accelerated Gradient Descent"
@@ -17,8 +18,8 @@ Base.summary(::AcceleratedGradientDescent) = "Accelerated Gradient Descent"
AcceleratedGradientDescent(; linesearch = LineSearches.HagerZhang()) =
AcceleratedGradientDescent(linesearch)
=#
function AcceleratedGradientDescent(; linesearch = LineSearches.HagerZhang())
AcceleratedGradientDescent(linesearch)
function AcceleratedGradientDescent(; linesearch = LineSearches.HagerZhang(), manifold::Manifold=Flat())
AcceleratedGradientDescent(linesearch, manifold)
end

mutable struct AcceleratedGradientDescentState{T,N}
@@ -33,9 +34,12 @@ mutable struct AcceleratedGradientDescentState{T,N}
end

function initial_state{T}(method::AcceleratedGradientDescent, options, d, initial_x::Array{T})
initial_x = copy(initial_x)
retract!(method.manifold, real_to_complex(d,initial_x))
value_gradient!(d, initial_x)
project_tangent!(method.manifold, real_to_complex(d,gradient(d)), real_to_complex(d,initial_x))

AcceleratedGradientDescentState(copy(initial_x), # Maintain current state in state.x
AcceleratedGradientDescentState(initial_x, # Maintain current state in state.x
copy(initial_x), # Maintain previous state in state.x_previous
T(NaN), # Store previous f in state.f_x_previous
0, # Iteration
@@ -48,13 +52,14 @@ end
function update_state!{T}(d, state::AcceleratedGradientDescentState{T}, method::AcceleratedGradientDescent)
n = length(state.x)
state.iteration += 1
project_tangent!(method.manifold, real_to_complex(d,gradient(d)), real_to_complex(d,state.x))
# Search direction is always the negative gradient
@simd for i in 1:n
@inbounds state.s[i] = -gradient(d, i)
end

# Determine the distance of movement along the search line
lssuccess = perform_linesearch!(state, method, d)
lssuccess = perform_linesearch!(state, method, ManifoldObjective(method.manifold, d))

# Record previous state
copy!(state.x_previous, state.x)
@@ -64,12 +69,14 @@ function update_state!{T}(d, state::AcceleratedGradientDescentState{T}, method::
@simd for i in 1:n
@inbounds state.y[i] = state.x[i] + state.alpha * state.s[i]
end
retract!(method.manifold, real_to_complex(d,state.y))

# Update current position with Nesterov correction
scaling = (state.iteration - 1) / (state.iteration + 2)
@simd for i in 1:n
@inbounds state.x[i] = state.y[i] + scaling * (state.y[i] - state.y_previous[i])
end
retract!(method.manifold, real_to_complex(d,state.x))

lssuccess == false # break on linesearch error
end
16 changes: 11 additions & 5 deletions src/multivariate/solvers/first_order/bfgs.jl
@@ -7,6 +7,7 @@ struct BFGS{L, H<:Function} <: Optimizer
linesearch!::L
initial_invH::H
resetalpha::Bool
manifold::Manifold
end

Base.summary(::BFGS) = "BFGS"
@@ -17,8 +18,8 @@ BFGS(; linesearch = LineSearches.HagerZhang(), initial_invH = x -> eye(eltype(x)
=#
function BFGS(; linesearch = LineSearches.HagerZhang(),
initial_invH = x -> eye(eltype(x), length(x)),
resetalpha = true)
BFGS(linesearch, initial_invH, resetalpha)
resetalpha = true, manifold::Manifold=Flat())
BFGS(linesearch, initial_invH, resetalpha, manifold)
end
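At the call site, the new keyword makes manifold-constrained optimization a one-line change. A sketch (assumes the `g!(G, x)` gradient convention of recent Optim versions): minimizing the Rayleigh quotient x'Ax on the unit sphere yields an eigenvector of the smallest eigenvalue:

```julia
A = Diagonal([1.0, 2.0, 3.0])
f(x) = vecdot(x, A*x)
g!(G, x) = (G .= 2 .* (A*x))
x0 = normalize(randn(3))

res = Optim.optimize(f, g!, x0, BFGS(manifold=Sphere()))
# minimizer(res) should converge to ±[1, 0, 0] and minimum(res) to 1.0,
# the smallest eigenvalue of A
```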

mutable struct BFGSState{T,N,G}
@@ -36,10 +37,13 @@ end

function initial_state{T}(method::BFGS, options, d, initial_x::Array{T})
n = length(initial_x)
initial_x = copy(initial_x)
retract!(method.manifold, real_to_complex(d,initial_x))
value_gradient!(d, initial_x)
project_tangent!(method.manifold, real_to_complex(d,gradient(d)), real_to_complex(d,initial_x))
# Maintain a cache for line search results
# Trace the history of states visited
BFGSState(copy(initial_x), # Maintain current state in state.x
BFGSState(initial_x, # Maintain current state in state.x
similar(initial_x), # Maintain previous state in state.x_previous
copy(gradient(d)), # Store previous gradient in state.g_previous
T(NaN), # Store previous f in state.f_x_previous
@@ -59,21 +63,23 @@ function update_state!{T}(d, state::BFGSState{T}, method::BFGS)
# Search direction is the negative gradient divided by the approximate Hessian
A_mul_B!(vec(state.s), state.invH, vec(gradient(d)))
scale!(state.s, -1)
project_tangent!(method.manifold, real_to_complex(d,state.s), real_to_complex(d,state.x))

# Maintain a record of the previous gradient
copy!(state.g_previous, gradient(d))

# Determine the distance of movement along the search line
# This call resets invH to initial_invH if the former is not positive
# semi-definite
lssuccess = perform_linesearch!(state, method, d)
lssuccess = perform_linesearch!(state, method, ManifoldObjective(method.manifold, d))

# Maintain a record of previous position
copy!(state.x_previous, state.x)

# Update current position
state.dx .= state.alpha.*state.s
state.x .= state.x .+ state.dx
retract!(method.manifold, real_to_complex(d,state.x))
#
lssuccess == false # break on linesearch error
end
@@ -90,7 +96,7 @@ function update_h!(d, state, method::BFGS)
if dx_dg == 0.0
return true # force stop
end
A_mul_B!(state.u, state.invH, state.dg)
A_mul_B!(vec(state.u), state.invH, vec(state.dg))

c1 = (dx_dg + vecdot(state.dg, state.u)) / (dx_dg * dx_dg)
c2 = 1 / dx_dg