-
Notifications
You must be signed in to change notification settings - Fork 227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Manifold optimization #435
Changes from 11 commits
2b9b7f9
9aa58c1
0ea8ff5
765e574
7e64c2a
dd106e2
74cd17b
02f05f6
7b7002e
1f8315e
2aadf79
cff44ec
583997f
1feb7b6
1bb787b
018cc81
0cba7ee
2878095
376178b
e71a6ed
2ae30ee
4cafc1a
e9c8566
c1f1bc5
0d8a1ac
44a186b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
# Manifold interface: every manifold (subtype of Manifold) defines the functions | ||
# project_tangent!(m, g, x): project g on the tangent space to m at x | ||
# retract!(m, x): map x back to a point on the manifold m | ||
|
||
## To add: | ||
## * Second order algorithms | ||
## * Vector transport | ||
## * Arbitrary inner product | ||
## * More retractions | ||
## * More manifolds from ROPTLIB | ||
## * {x, Ax = b} | ||
## * Intersection manifold (just do the projection on both manifolds iteratively and hope it converges) | ||
|
||
"""
    Manifold

Abstract supertype of every manifold. Per the interface described at the top of
this file, each subtype defines

* `project_tangent!(m, g, x)`: project `g` on the tangent space to `m` at `x`
* `retract!(m, x)`: map `x` back to a point on the manifold `m`
"""
abstract type Manifold end
|
||
|
||
"""
    ManifoldObjective(manifold, inner_obj)

Objective wrapper that couples an `NLSolversBase.AbstractObjective` (`inner_obj`)
with the `Manifold` it is optimized on. The `NLSolversBase` methods defined for
this type retract the evaluation point onto `manifold` before delegating to
`inner_obj`, and project gradients onto the tangent space.

Declared with `mutable struct`, the non-deprecated spelling of the former `type`
keyword, so mutability is unchanged.
"""
mutable struct ManifoldObjective{T<:NLSolversBase.AbstractObjective} <: NLSolversBase.AbstractObjective
    manifold::Manifold
    inner_obj::T
end
# A manifold objective is complex exactly when the objective it wraps is.
function iscomplex(obj::ManifoldObjective)
    return iscomplex(obj.inner_obj)
end
# TODO is it safe here to call retract! and change x? | ||
# Evaluate the wrapped objective at the retraction of `x` onto the manifold.
# The retraction is done out of place (`retract`, not `retract!`), so this
# method does not itself write to `x`.
function NLSolversBase.value!(obj::ManifoldObjective, x)
    on_manifold = retract(obj.manifold, real_to_complex(obj, x))
    return value!(obj.inner_obj, complex_to_real(obj, on_manifold))
end
# Forward to the wrapped objective's `value`.
NLSolversBase.value(obj::ManifoldObjective) = value(obj.inner_obj)
# Forward to the wrapped objective's `gradient`.
NLSolversBase.gradient(obj::ManifoldObjective) = gradient(obj.inner_obj)
# Forward the indexed `gradient(obj, i)` query to the wrapped objective.
NLSolversBase.gradient(obj::ManifoldObjective, i::Int) = gradient(obj.inner_obj, i)
# Evaluate the gradient of the wrapped objective at the retraction of `x`, then
# project it on the tangent space of the manifold at that point. Returns
# `gradient(obj.inner_obj)`.
# NOTE(review): the projection is applied in place to whatever
# `real_to_complex` returns; this presumably aliases the gradient stored in
# `inner_obj` -- confirm against the definition of `real_to_complex`.
function NLSolversBase.gradient!(obj::ManifoldObjective, x)
    M = obj.manifold
    xin = complex_to_real(obj, retract(M, real_to_complex(obj, x)))
    gradient!(obj.inner_obj, xin)
    g = real_to_complex(obj, gradient(obj.inner_obj))
    project_tangent!(M, g, real_to_complex(obj, xin))
    return gradient(obj.inner_obj)
end
# Evaluate value and gradient of the wrapped objective at the retraction of
# `x`, project the gradient on the tangent space at that point, and return
# `value(obj.inner_obj)`.
# NOTE(review): as in `gradient!`, the in-place projection presumably relies on
# `real_to_complex` aliasing the stored gradient -- confirm.
function NLSolversBase.value_gradient!(obj::ManifoldObjective, x)
    M = obj.manifold
    xin = complex_to_real(obj, retract(M, real_to_complex(obj, x)))
    value_gradient!(obj.inner_obj, xin)
    g = real_to_complex(obj, gradient(obj.inner_obj))
    project_tangent!(M, g, real_to_complex(obj, xin))
    return value(obj.inner_obj)
end
|
||
# fallback for out-of-place ops | ||
# Out-of-place fallbacks built on the mutating interface.

# Project a copy of `g` on the tangent space to `M` at `x`; `g` is left untouched.
project_tangent(M::Manifold, g, x) = project_tangent!(M, copy(g), x)

# NOTE(review): legacy two-argument form, kept only so existing callers still
# resolve. It takes no gradient and projects an *uninitialized* `similar(x)`
# buffer, so its result is meaningless; prefer `project_tangent(M, g, x)`.
project_tangent(M::Manifold, x) = project_tangent!(M, similar(x), x)

# Retract a copy of `x` onto `M`; `x` is left untouched.
retract(M::Manifold, x) = retract!(M, copy(x))
|
||
# Flat manifold = {R,C}^n | ||
# all the functions below are no-ops, and therefore the generated code | ||
# for the flat manifold should be exactly the same as the one with all | ||
# the manifold stuff removed | ||
"""
    Flat()

Flat manifold `{R,C}^n`. Retraction and tangent projection are the identity:
each method below returns its input (`x` or `g`) unchanged.
"""
struct Flat <: Manifold end

# Retraction: every point is already on the manifold.
retract(::Flat, x) = x
retract!(::Flat, x) = x

# Tangent projection: the tangent space is the whole space.
project_tangent(::Flat, g, x) = g
project_tangent!(::Flat, g, x) = g
|
||
# {||x|| = 1} | ||
"""
    Sphere()

Unit sphere `{||x|| = 1}`, where the norm is the elementwise (Frobenius) 2-norm
so that points may be arrays of any shape.
"""
struct Sphere <: Manifold end

# Retraction: rescale `x` in place to unit norm and return it.
# `vecnorm` is used instead of `normalize!`, which only accepts vectors; this
# matches the array-agnostic `vecdot` used by the projection below.
retract!(S::Sphere, x) = (x ./= vecnorm(x))

# Tangent projection: remove from `g`, in place, its component along `x`.
project_tangent!(S::Sphere, g, x) = (g .= g .- real(vecdot(x, g)) .* x)
|
||
# N x n matrices such that X'X = I | ||
# TODO: add more retractions, and support arbitrary inner product | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any chance you would want to open an issue for this? Just so we don't lose track of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, see #448 |
||
"""
    Stiefel

Abstract supertype of the Stiefel manifolds: `N x n` matrices `X` such that
`X'X = I`. Concrete subtypes differ in the retraction they use.
"""
abstract type Stiefel <: Manifold end
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know it's not done all over the code base, but a simple reference or explanation of what the "Stiefel manifold" is would be nice. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's an important but pretty special manifold, no? What is the justification for having it as part of Optim? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will document it. The main justification is that it is the one I need in my application ;-) More seriously, it's the basic manifold to do this kind of algorithms on: it was the original motivation for the theory, many other manifolds (sphere, O(n), U(n)) are special cases, it's probably the most used in applications (at least that I know of) outside of the sphere, and it's a good template for implementation of other manifolds. There could be a Manifolds package living outside Optim, but it's a pretty short file so I would think this is fine, and people implementing other manifolds can just PR on Optim? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point about the special cases. Ok maybe leave this for now and discuss moving outside of Optim only when somebody complains |
||
# Stiefel manifold whose retraction right-divides by the Cholesky factor of X'X.
struct Stiefel_CholQR <: Stiefel end

# Stiefel manifold whose retraction is built from the SVD of X.
struct Stiefel_SVD <: Stiefel end

"""
    Stiefel(retraction=:SVD)

Construct a Stiefel manifold using the requested retraction: `:SVD` (default)
gives a `Stiefel_SVD`, `:CholQR` gives a `Stiefel_CholQR`.

Throws an `ArgumentError` for any other value, instead of silently returning
`nothing`.
"""
function Stiefel(retraction=:SVD)
    if retraction == :CholQR
        return Stiefel_CholQR()
    elseif retraction == :SVD
        return Stiefel_SVD()
    end
    throw(ArgumentError(string("unknown Stiefel retraction ", repr(retraction),
                               "; expected :CholQR or :SVD")))
end
|
||
# SVD retraction: overwrite `X` with `U*V'`, where `U` and `V` are the singular
# vectors of `X`, and return `X`.
# The singular values get their own name so that destructuring no longer
# rebinds the manifold argument `S` to a vector.
function retract!(S::Stiefel_SVD, X)
    U, sigma, V = svd(X)
    X .= U*V'
    return X
end
# Cholesky-based retraction: overwrite `X` with `X / R`, where `R` is the
# Cholesky factor of the overlap matrix `X'X`, and return `X`.
function retract!(S::Stiefel_CholQR, X)
    R = chol(X'X)
    X .= X/R
    return X
end
# Tangent projection on the Stiefel manifold: overwrite `G` with
# `X*(X'G - G'X)/2 + G - X*(X'G)` and return `G`.
function project_tangent!(S::Stiefel, G, X)
    XG = X'G
    skew_part = X*(XG .- G'X)
    normal_part = X*XG
    G .= skew_part./2 .+ G .- normal_part
    return G
end
|
||
|
||
|
||
# TODO is there a better way of doing power and product manifolds? | ||
|
||
# multiple copies of the same manifold. Points are arrays of arbitrary | ||
# dimensions, and the leading dimensions (given by inner_dims) index points of | ||
# the inner manifold. E.g. a 2 x 2 product of Stiefel manifolds of dimension | ||
# N x n would be an N x n x 2 x 2 array | ||
"""
    PowerManifold(inner_manifold, inner_dims, outer_dims)

Multiple copies of the same manifold, stored in a single array.

# Fields
* `inner_manifold`: type of embedded manifold
* `inner_dims`: dimension of the embedded manifolds
* `outer_dims`: number of embedded manifolds
"""
struct PowerManifold <: Manifold
    inner_manifold::Manifold
    inner_dims::Tuple
    outer_dims::Tuple
end
# Retract every embedded point of `x` onto the inner manifold, in place, and
# return `x`.
function retract!(m::PowerManifold, x)
    ncopies = prod(m.outer_dims)
    for k in 1:ncopies
        retract!(m.inner_manifold, get_inner(m, x, k))
    end
    return x
end
# Project each embedded block of `g` on the tangent space of the inner manifold
# at the matching block of `x`, in place, and return `g`.
function project_tangent!(m::PowerManifold, g, x)
    ncopies = prod(m.outer_dims)
    for k in 1:ncopies
        gk = get_inner(m, g, k)
        xk = get_inner(m, x, k)
        project_tangent!(m.inner_manifold, gk, xk)
    end
    return g
end
# linear indexing | ||
# Return the `i`-th embedded point of `x` as a view reshaped to `m.inner_dims`,
# where `i` is a linear index over the copies (1 to `prod(m.outer_dims)`).
@inline function get_inner(m::PowerManifold, x, i::Int)
    len = prod(m.inner_dims)
    ncopies = prod(m.outer_dims)
    @assert 1 <= i <= ncopies
    start = (i - 1)*len + 1
    stop = i*len
    return reshape(view(x, start:stop), m.inner_dims)
end
# Subscript indexing: `i` is a tuple of subscripts into `m.outer_dims`.
# It is converted to the matching linear index with `sub2ind`; the previous
# `ind2sub` performs the opposite conversion and could not resolve to the
# linear-index method.
@inline function get_inner(m::PowerManifold, x, i::Tuple)
    return get_inner(m, x, sub2ind(m.outer_dims, i...))
end
|
||
#Product of two manifolds {P = (x1,x2), x1 ∈ m1, x2 ∈ m2}. | ||
#P is assumed to be a flat array, and x1 is before x2 in memory | ||
"""
    ProductManifold(m1, m2, dims1, dims2)

Product of two manifolds `{P = (x1,x2), x1 ∈ m1, x2 ∈ m2}`. `P` is assumed to
be a flat array, with `x1` before `x2` in memory.

# Fields
* `m1`, `m2`: the two component manifolds
* `dims1`, `dims2`: the shapes the two components are reshaped to
"""
struct ProductManifold <: Manifold
    m1::Manifold
    m2::Manifold
    dims1::Tuple
    dims2::Tuple
end
# Retract both components of `x` onto their respective manifolds, in place,
# and return `x`.
function retract!(m::ProductManifold, x)
    x1 = get_inner(m, x, 1)
    retract!(m.m1, x1)
    x2 = get_inner(m, x, 2)
    retract!(m.m2, x2)
    return x
end
# Project each component of `g` on the tangent space of its manifold at the
# matching component of `x`, in place, and return `g`.
function project_tangent!(m::ProductManifold, g, x)
    g1 = get_inner(m, g, 1)
    x1 = get_inner(m, x, 1)
    project_tangent!(m.m1, g1, x1)
    g2 = get_inner(m, g, 2)
    x2 = get_inner(m, x, 2)
    project_tangent!(m.m2, g2, x2)
    return g
end
# Return component `i` (1 or 2) of the flat array `x` as a view reshaped to
# `m.dims1` or `m.dims2`. The first `prod(m.dims1)` entries belong to the first
# component and the remaining `prod(m.dims2)` entries to the second. Any other
# `i` raises an error.
function get_inner(m::ProductManifold, x, i)
    N1 = prod(m.dims1)
    N2 = prod(m.dims2)
    @assert length(x) == N1+N2
    i == 1 && return reshape(view(x, 1:N1), m.dims1)
    i == 2 && return reshape(view(x, (N1 + 1):(N1 + N2)), m.dims2)
    error("Only two components in a product manifold")
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest to add docstrings at least to all the new types