Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

This PR makes Gradient Descent parallelized using Threads.@spawn #179

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
82 changes: 65 additions & 17 deletions src/algorithms/grassmann.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
module GrassmannMPS

using ..MPSKit
using ..Defaults
using TensorKit
import TensorKitManifolds.Grassmann

Expand Down Expand Up @@ -68,24 +69,44 @@
end

function ManifoldPoint(state::Union{InfiniteMPS,FiniteMPS}, envs)
al_d = similar(state.AL)
for i in 1:length(state)
al_d[i] = MPSKit.∂∂AC(i, state, envs.opp, envs) * state.AC[i]
@static if Defaults.parallelize_sites
al_d = similar(state.AL)
@sync for i in 1:length(state)
Threads.@spawn begin
al_d[i] = MPSKit.∂∂AC(i, state, envs.opp, envs) * state.AC[i]
end
end

Check warning on line 78 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L78

Added line #L78 was not covered by tests
g = fetch.(map(CartesianIndices(state.AL)) do I
return Threads.@spawn Grassmann.project(al_d[I], state.AL[I])
end)
else
al_d = similar(state.AL)
for i in 1:length(state)
al_d[i] = MPSKit.∂∂AC(i, state, envs.opp, envs) * state.AC[i]
end
g = Grassmann.project.(al_d, state.AL)
end

Gertian marked this conversation as resolved.
Show resolved Hide resolved
g = Grassmann.project.(al_d, state.AL)

Rhoreg = Vector{eltype(state.CR)}(undef, length(state))
δmin = sqrt(eps(real(scalartype(state))))
for i in 1:length(state)
Rhoreg[i] = regularize(state.CR[i], max(norm(g[i]) / 10, δmin))
@static if Defaults.parallelize_sites

Check warning on line 92 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L92

Added line #L92 was not covered by tests
Gertian marked this conversation as resolved.
Show resolved Hide resolved
@sync for i in 1:length(state)
Threads.@spawn begin
Rhoreg[i] = regularize(state.CR[i], max(norm(g[i]) / 10, δmin))
end
end

Check warning on line 97 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L97

Added line #L97 was not covered by tests
else
for i in 1:length(state)
Rhoreg[i] = regularize(state.CR[i], max(norm(g[i]) / 10, δmin))
end
end

return ManifoldPoint(state, envs, g, Rhoreg)
end

function ManifoldPoint(state::MPSMultiline, envs)
# FIXME: add support for unitcells
# TODO : then support parallelize_sites
@assert length(state.AL) == 1 "GradientGrassmann only supports MPSMultiline without unitcells for now"

# TODO: this really should not use the operator from the environment
Expand Down Expand Up @@ -115,9 +136,16 @@
function fg(x::ManifoldPoint{T}) where {T<:Union{InfiniteMPS,FiniteMPS}}
# the gradient I want to return is the preconditioned gradient!
g_prec = Vector{PrecGrad{eltype(x.g),eltype(x.Rhoreg)}}(undef, length(x.g))

for i in 1:length(x.state)
g_prec[i] = PrecGrad(rmul!(copy(x.g[i]), x.state.CR[i]'), x.Rhoreg[i])
@static if Defaults.parallelize_sites

Check warning on line 139 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L139

Added line #L139 was not covered by tests
@sync for i in 1:length(x.state)
Threads.@spawn begin
g_prec[i] = PrecGrad(rmul!(copy(x.g[i]), x.state.CR[i]'), x.Rhoreg[i])
end
end

Check warning on line 144 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L144

Added line #L144 was not covered by tests
else
for i in 1:length(x.state)
g_prec[i] = PrecGrad(rmul!(copy(x.g[i]), x.state.CR[i]'), x.Rhoreg[i])
end
end

# TODO: the operator really should not be part of the environments, and this should
Expand All @@ -128,6 +156,7 @@
return real(f), g_prec
end
function fg(x::ManifoldPoint{<:MPSMultiline})
#TODO : support parallelize_sites
@assert length(x.state) == 1 "GradientGrassmann only supports MPSMultiline without unitcells for now"
# the gradient I want to return is the preconditioned gradient!
g_prec = map(enumerate(x.g)) do (i, cg)
Expand All @@ -147,6 +176,7 @@
Retract a left-canonical MPSMultiline along Grassmann tangent `g` by distance `alpha`.
"""
function retract(x::ManifoldPoint{<:MPSMultiline}, tg, alpha)
#TODO : support parallelize_sites
Gertian marked this conversation as resolved.
Show resolved Hide resolved
g = reshape(tg, size(x.state))

nal = similar(x.state.AL)
Expand All @@ -170,11 +200,19 @@
envs = x.envs
nal = similar(state.AL)
h = similar(g) # The tangent at the end-point
for i in 1:length(g)
nal[i], th = Grassmann.retract(state.AL[i], g[i].Pg, alpha)
h[i] = PrecGrad(th)
@static if Defaults.parallelize_sites

Check warning on line 203 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L203

Added line #L203 was not covered by tests
@sync for i in 1:length(g)
Threads.@spawn begin
nal[i], th = Grassmann.retract(state.AL[i], g[i].Pg, alpha)
h[i] = PrecGrad(th)
end
end

Check warning on line 209 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L209

Added line #L209 was not covered by tests
else
for i in 1:length(g)
nal[i], th = Grassmann.retract(state.AL[i], g[i].Pg, alpha)
h[i] = PrecGrad(th)
end
end

nstate = InfiniteMPS(nal, state.CR[end])

newpoint = ManifoldPoint(nstate, envs)
Expand All @@ -186,6 +224,7 @@
Retract a left-canonical finite MPS along Grassmann tangent `g` by distance `alpha`.
"""
function retract(x::ManifoldPoint{<:FiniteMPS}, g, alpha)
#TODO : support parallelize_sites.
state = x.state
envs = x.envs

Expand All @@ -208,9 +247,18 @@
`alpha`. `xp` is the end-point of the retraction.
"""
function transport!(h, x, g, alpha, xp)
for i in 1:length(h)
h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg, alpha,
xp.state.AL[i]))
@static if Defaults.parallelize_sites

Check warning on line 250 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L250

Added line #L250 was not covered by tests
@sync for i in 1:length(h)
Threads.@spawn begin
h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg, alpha,
xp.state.AL[i]))
end
end

Check warning on line 256 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L256

Added line #L256 was not covered by tests
else
for i in 1:length(h)
h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg, alpha,
xp.state.AL[i]))
end
end
return h
end
Expand Down
Loading