Skip to content

Commit

Permalink
Refactor eigsolve to fixedpoint (#134)
Browse files Browse the repository at this point in the history
* define `fixedpoint` as `eigsolve` wrapper

* Replace `eigsolve` with `fixedpoint`

* Refactor multiline environments

* Fix typo

* Formatter
  • Loading branch information
lkdvos authored Mar 21, 2024
1 parent 11016b9 commit 91af3a2
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 101 deletions.
1 change: 1 addition & 0 deletions src/MPSKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ include("environments/multipleenv.jl")
include("environments/idmrgenv.jl")
include("environments/lazylincocache.jl")

include("algorithms/fixedpoint.jl")
include("algorithms/derivatives.jl")
include("algorithms/expval.jl")
include("algorithms/toolbox.jl")
Expand Down
30 changes: 30 additions & 0 deletions src/algorithms/fixedpoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# wrapper around KrylovKit.jl's eigsolve function

"""
fixedpoint(A, x₀, which::Symbol, alg) -> val, vec
Compute the fixedpoint of a linear operator `A` using the specified eigensolver `alg`. The
fixedpoint is assumed to be unique.
"""
function fixedpoint(A, x₀, which::Symbol, alg::Lanczos)
vals, vecs, info = eigsolve(A, x₀, 1, which, alg)

if info.converged == 0
@warn "fixedpoint not converged after $(info.numiter) iterations: normres = $(info.normres[1])"
end

return vals[1], vecs[1]
end

function fixedpoint(A, x₀, which::Symbol, alg::Arnoldi)
TT, vecs, vals, info = schursolve(A, x₀, 1, which, alg)

if info.converged == 0
@warn "fixedpoint not converged after $(info.numiter) iterations: normres = $(info.normres[1])"
end
if size(TT, 2) > 1 && TT[2, 1] != 0
@warn "non-unique fixedpoint detected"
end

return vals[1], vecs[1]
end
10 changes: 4 additions & 6 deletions src/algorithms/groundstate/dmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ function find_groundstate!(ψ::AbstractFiniteMPS, H, alg::DMRG, envs=environment
zerovector!(ϵs)
for pos in [1:(length(ψ) - 1); length(ψ):-1:2]
h = ∂∂AC(pos, ψ, H, envs)
_, vecs = eigsolve(h, ψ.AC[pos], 1, :SR, alg_eigsolve)
_, vec = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
ϵs[pos] = max(ϵs[pos], calc_galerkin(ψ, pos, envs))
ψ.AC[pos] = vecs[1]
ψ.AC[pos] = vec
end
ϵ = maximum(ϵs)

Expand Down Expand Up @@ -91,8 +91,7 @@ function find_groundstate!(ψ::AbstractFiniteMPS, H, alg::DMRG2, envs=environmen
for pos in 1:(length(ψ) - 1)
@plansor ac2[-1 -2; -3 -4] := ψ.AC[pos][-1 -2; 1] * ψ.AR[pos + 1][1 -4; -3]

_, vecs = eigsolve(∂∂AC2(pos, ψ, H, envs), ac2, 1, :SR, alg_eigsolve)
newA2center = first(vecs)
_, newA2center = fixedpoint(∂∂AC2(pos, ψ, H, envs), ac2, :SR, alg_eigsolve)

al, c, ar, = tsvd!(newA2center; trunc=alg.trscheme, alg=TensorKit.SVD())
normalize!(c)
Expand All @@ -108,8 +107,7 @@ function find_groundstate!(ψ::AbstractFiniteMPS, H, alg::DMRG2, envs=environmen
for pos in (length(ψ) - 2):-1:1
@plansor ac2[-1 -2; -3 -4] := ψ.AL[pos][-1 -2; 1] * ψ.AC[pos + 1][1 -4; -3]

_, vecs = eigsolve(∂∂AC2(pos, ψ, H, envs), ac2, 1, :SR, alg_eigsolve)
newA2center = first(vecs)
_, newA2center = fixedpoint(∂∂AC2(pos, ψ, H, envs), ac2, :SR, alg_eigsolve)

al, c, ar, = tsvd!(newA2center; trunc=alg.trscheme, alg=TensorKit.SVD())
normalize!(c)
Expand Down
34 changes: 15 additions & 19 deletions src/algorithms/groundstate/idmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,17 @@ function find_groundstate(ost::InfiniteMPS, H, alg::IDMRG1, oenvs=environments(o
# left to right sweep
for pos in 1:length(ψ)
h = ∂∂AC(pos, ψ, H, envs)
_, vecs = eigsolve(h, ψ.AC[pos], 1, :SR, alg_eigsolve)

ψ.AC[pos] = vecs[1]
ψ.AL[pos], ψ.CR[pos] = leftorth!(vecs[1])

_, ψ.AC[pos] = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
ψ.AL[pos], ψ.CR[pos] = leftorth!.AC[pos])
update_leftenv!(envs, ψ, H, pos + 1)
end

# right to left sweep
for pos in length(ψ):-1:1
h = ∂∂AC(pos, ψ, H, envs)
_, vecs = eigsolve(h, ψ.AC[pos], 1, :SR, alg_eigsolve)
_, ψ.AC[pos] = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)

ψ.AC[pos] = vecs[1]
ψ.CR[pos - 1], temp = rightorth!(_transpose_tail(vecs[1]))
ψ.CR[pos - 1], temp = rightorth!(_transpose_tail.AC[pos]))
ψ.AR[pos] = _transpose_front(temp)

update_rightenv!(envs, ψ, H, pos - 1)
Expand Down Expand Up @@ -112,9 +108,9 @@ function find_groundstate(ost::InfiniteMPS, H, alg::IDMRG2, oenvs=environments(o
for pos in 1:(length(ψ) - 1)
ac2 = ψ.AC[pos] * _transpose_tail.AR[pos + 1])
h_ac2 = ∂∂AC2(pos, ψ, H, envs)
_, vecs, _ = eigsolve(h_ac2, ac2, 1, :SR, alg_eigsolve)
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)

al, c, ar, = tsvd!(vecs[1]; trunc=alg.trscheme, alg=TensorKit.SVD())
al, c, ar, = tsvd!(ac2′; trunc=alg.trscheme, alg=TensorKit.SVD())
normalize!(c)

ψ.AL[pos] = al
Expand All @@ -130,9 +126,9 @@ function find_groundstate(ost::InfiniteMPS, H, alg::IDMRG2, oenvs=environments(o
@plansor ac2[-1 -2; -3 -4] := ψ.AC[end][-1 -2; 1] * inv.CR[0])[1; 2] *
ψ.AL[1][2 -4; 3] * ψ.CR[1][3; -3]
h_ac2 = ∂∂AC2(0, ψ, H, envs)
_, vecs, _ = eigsolve(h_ac2, ac2, 1, :SR, alg_eigsolve)
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)

al, c, ar, = tsvd!(vecs[1]; trunc=alg.trscheme, alg=TensorKit.SVD())
al, c, ar, = tsvd!(ac2′; trunc=alg.trscheme, alg=TensorKit.SVD())
normalize!(c)

ψ.AC[end] = al * c
Expand All @@ -152,9 +148,9 @@ function find_groundstate(ost::InfiniteMPS, H, alg::IDMRG2, oenvs=environments(o
for pos in (length(ψ) - 1):-1:1
ac2 = ψ.AL[pos] * _transpose_tail.AC[pos + 1])
h_ac2 = ∂∂AC2(pos, ψ, H, envs)
_, vecs, _ = eigsolve(h_ac2, ac2, 1, :SR, alg_eigsolve)
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)

al, c, ar, = tsvd!(vecs[1]; trunc=alg.trscheme, alg=TensorKit.SVD())
al, c, ar, = tsvd!(ac2′; trunc=alg.trscheme, alg=TensorKit.SVD())
normalize!(c)

ψ.AL[pos] = al
Expand All @@ -171,8 +167,8 @@ function find_groundstate(ost::InfiniteMPS, H, alg::IDMRG2, oenvs=environments(o
@plansor ac2[-1 -2; -3 -4] := ψ.CR[end - 1][-1; 1] * ψ.AR[end][1 -2; 2] *
inv.CR[end])[2; 3] * ψ.AC[1][3 -4; -3]
h_ac2 = ∂∂AC2(0, ψ, H, envs)
_, vecs, = eigsolve(h_ac2, ac2, 1, :SR, alg_eigsolve)
al, c, ar, = tsvd!(vecs[1]; trunc=alg.trscheme, alg=TensorKit.SVD())
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
al, c, ar, = tsvd!(ac2′; trunc=alg.trscheme, alg=TensorKit.SVD())
normalize!(c)

ψ.AR[end] = _transpose_front(inv.CR[end - 1]) * _transpose_tail(al * c))
Expand Down Expand Up @@ -202,7 +198,7 @@ function find_groundstate(ost::InfiniteMPS, H, alg::IDMRG2, oenvs=environments(o
end
end

nst = InfiniteMPS.AR[1:end]; tol=alg.tol_gauge)
nenvs = environments(nst, H; solver=oenvs.solver)
return nst, nenvs, ϵ
ψ′ = InfiniteMPS.AR[1:end]; tol=alg.tol_gauge)
nenvs = environments(ψ′, H; solver=oenvs.solver)
return ψ′, nenvs, ϵ
end
12 changes: 5 additions & 7 deletions src/algorithms/groundstate/vumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,18 @@ function find_groundstate(ψ::InfiniteMPS, H, alg::VUMPS, envs=environments(ψ,
end

function _vumps_localupdate(loc, ψ, H, envs, eigalg, factalg=QRpos())
AC′, C′ = @static if Defaults.parallelize_sites
@static if Defaults.parallelize_sites
@sync begin
Threads.@spawn begin
_, acvecs = eigsolve(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], 1, :SR, eigalg)
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
end
Threads.@spawn begin
_, crvecs = eigsolve(∂∂C(loc, ψ, H, envs), ψ.CR[loc], 1, :SR, eigalg)
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.CR[loc], :SR, eigalg)
end
end
acvecs[1], crvecs[1]
else
_, acvecs = eigsolve(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], 1, :SR, eigalg)
_, crvecs = eigsolve(∂∂C(loc, ψ, H, envs), ψ.CR[loc], 1, :SR, eigalg)
acvecs[1], crvecs[1]
_, AC′ = fixedpoint(∂∂AC(loc, ψ, H, envs), ψ.AC[loc], :SR, eigalg)
_, C′ = fixedpoint(∂∂C(loc, ψ, H, envs), ψ.CR[loc], :SR, eigalg)
end
return regauge!(AC′, C′; alg=factalg)
end
16 changes: 8 additions & 8 deletions src/algorithms/statmech/vumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,28 @@ function leading_boundary(ψ::MPSMultiline, H, alg::VUMPS, envs=environments(ψ,
Threads.@spawn begin
H_AC = ∂∂AC($col, $ψ, $H, $envs)
ac = RecursiveVec($ψ.AC[:, col])
_, acvecs = eigsolve(H_AC, ac, 1, :LM, alg_eigsolve)
$temp_ACs[:, col] = acvecs[1].vecs[:]
_, ac′ = fixedpoint(H_AC, ac, :LM, alg_eigsolve)
$temp_ACs[:, col] = ac′.vecs[:]
end

Threads.@spawn begin
H_C = ∂∂C($col, $ψ, $H, $envs)
c = RecursiveVec($ψ.CR[:, col])
_, cvecs = eigsolve(H_C, c, 1, :LM, alg_eigsolve)
$temp_Cs[:, col] = cvecs[1].vecs[:]
_, c′ = fixedpoint(H_C, c, :LM, alg_eigsolve)
$temp_Cs[:, col] = c′.vecs[:]
end
end
else
for col in 1:size(ψ, 2)
H_AC = ∂∂AC(col, ψ, H, envs)
ac = RecursiveVec.AC[:, col])
_, acvecs = eigsolve(H_AC, ac, 1, :LM, alg_eigsolve)
temp_ACs[:, col] = acvecs[1].vecs[:]
_, ac′ = fixedpoint(H_AC, ac, :LM, alg_eigsolve)
temp_ACs[:, col] = ac′.vecs[:]

H_C = ∂∂C(col, ψ, H, envs)
c = RecursiveVec.CR[:, col])
_, cvecs = eigsolve(H_C, c, 1, :LM, alg_eigsolve)
temp_Cs[:, col] = cvecs[1].vecs[:]
_, c′ = fixedpoint(H_C, c, :LM, alg_eigsolve)
temp_Cs[:, col] = c′.vecs[:]
end
end

Expand Down
84 changes: 35 additions & 49 deletions src/environments/permpoinfenv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,67 +137,53 @@ end

function mixed_fixpoints(above::MPSMultiline, mpo::MPOMultiline, below::MPSMultiline,
init=gen_init_fps(above, mpo, below); solver=Defaults.eigsolver)
T = eltype(above)

#sanity check
(numrows, numcols) = size(above)
# sanity check
numrows, numcols = size(above)
@assert size(above) == size(mpo)
@assert size(below) == size(mpo)

envtype = eltype(init[1])
lefties = PeriodicArray{envtype,2}(undef, numrows, numcols)
righties = PeriodicArray{envtype,2}(undef, numrows, numcols)

@threads for cr in 1:numrows
c_above = above[cr]
c_below = below[cr + 1]

(L0, R0) = init[cr]

GLs = PeriodicArray{envtype,2}(undef, numrows, numcols)
GRs = PeriodicArray{envtype,2}(undef, numrows, numcols)

@threads for row in 1:numrows
Os = mpo[row, :]
ALs_top, ALs_bot = above[row].AL, below[row + 1].AL
ARs_top, ARs_bot = above[row].AR, below[row + 1].AR
L0, R0 = init[row]
@sync begin
Threads.@spawn begin
E_LL = TransferMatrix($c_above.AL, $mpo[cr, :], $c_below.AL)

packed_init = $L0 isa Vector ? RecursiveVec($L0) : $L0
(_, Ls, convhist) = eigsolve(flip(E_LL), packed_init, 1, :LM, $solver)
convhist.converged < 1 &&
@warn "GL failed to converge: normres = $(convhist.normres)"
L0 = $L0 isa Vector ? Ls[1].vecs : Ls[1]
E_LL = TransferMatrix(ALs_top, Os, ALs_bot)
_, GLs[row, 1] = fixedpoint(flip(E_LL), L0, :LM, solver)
# compute rest of unitcell
for col in 2:numcols
GLs[row, col] = GLs[row, col - 1] *
TransferMatrix(ALs_top[col - 1], Os[col - 1],
ALs_bot[col - 1])
end
end

Threads.@spawn begin
packed_init = $R0 isa Vector ? RecursiveVec($R0) : $R0
E_RR = TransferMatrix($c_above.AR, $mpo[cr, :], $c_below.AR)
(_, Rs, convhist) = eigsolve(E_RR, packed_init, 1, :LM, $solver)
convhist.converged < 1 &&
@warn "GR failed to converge: normres = $(convhist.normres)"
R0 = $R0 isa Vector ? Rs[1].vecs : Rs[1]
E_RR = TransferMatrix(ARs_top, Os, ARs_bot)
_, GRs[row, end] = fixedpoint(E_RR, R0, :LM, solver)
# compute rest of unitcell
for col in (numcols - 1):-1:1
GRs[row, col] = TransferMatrix(ARs_top[col + 1], Os[col + 1],
ARs_bot[col + 1]) *
GRs[row, col + 1]
end
end
end

lefties[cr, 1] = L0
for loc in 2:numcols
lefties[cr, loc] = lefties[cr, loc - 1] *
TransferMatrix(c_above.AL[loc - 1], mpo[cr, loc - 1],
c_below.AL[loc - 1])
end

renormfact::scalartype(T) = dot(c_below.CR[0], MPO_∂∂C(L0, R0) * c_above.CR[0])

righties[cr, end] = R0 / sqrt(renormfact)
lefties[cr, 1] /= sqrt(renormfact)

for loc in (numcols - 1):-1:1
righties[cr, loc] = TransferMatrix(c_above.AR[loc + 1], mpo[cr, loc + 1],
c_below.AR[loc + 1]) *
righties[cr, loc + 1]

renormfact = dot(c_below.CR[loc],
MPO_∂∂C(lefties[cr, loc + 1], righties[cr, loc]) *
c_above.CR[loc])
righties[cr, loc] /= sqrt(renormfact)
lefties[cr, loc + 1] /= sqrt(renormfact)
# fix normalization
CRs_top, CRs_bot = above[row].CR, below[row + 1].CR
for col in 1:numcols
λ = dot(CRs_bot[col],
MPO_∂∂C(GLs[row, col + 1], GRs[row, col]) * CRs_top[col])
scale!(GLs[row, col + 1], 1 / sqrt(λ))
scale!(GRs[row, col], 1 / sqrt(λ))
end
end

return (lefties, righties)
return GLs, GRs
end
7 changes: 3 additions & 4 deletions src/operators/densempo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ function TensorKit.dot(a::InfiniteMPS, mpo::DenseMPO, b::InfiniteMPS; krylovdim=
_firstspace(b.AL[1]) * _firstspace(mpo.opp[1]) _firstspace(a.AL[1]))
randomize!(init)

(vals, vecs, convhist) = eigsolve(TransferMatrix(b.AL, mpo.opp, a.AL), init, 1, :LM,
Arnoldi(; krylovdim=krylovdim))
convhist.converged == 0 && @warn "dot failed to converge: normres = $(convhist.normres)"
return vals[1]
val, = fixedpoint(TransferMatrix(b.AL, mpo.opp, a.AL), init, :LM,
Arnoldi(; krylovdim=krylovdim))
return val
end
7 changes: 3 additions & 4 deletions src/states/infinitemps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,9 @@ end
function TensorKit.dot(ψ₁::InfiniteMPS, ψ₂::InfiniteMPS; krylovdim=30)
init = similar(ψ₁.AL[1], _firstspace(ψ₂.AL[1]) _firstspace(ψ₁.AL[1]))
randomize!(init)
vals, _, convhist = eigsolve(TransferMatrix(ψ₂.AL, ψ₁.AL), init, 1, :LM,
Arnoldi(; krylovdim=krylovdim))
convhist.converged == 0 && @warn "dot failed to converge: normres = $(convhist.normres)"
return vals[1]
val, = fixedpoint(TransferMatrix(ψ₂.AL, ψ₁.AL), init, :LM,
Arnoldi(; krylovdim=krylovdim))
return val
end

function Base.show(io::IO, ::MIME"text/plain", ψ::InfiniteMPS)
Expand Down
8 changes: 4 additions & 4 deletions src/states/ortho.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ function gauge_eigsolve_step!(it::IterativeSolver{LeftCanonical}, state)
(; AL, CR, A, iter, ϵ) = state
if iter it.eig_miniter
alg_eigsolve = updatetol(it.alg_eigsolve, 1, ϵ^2)
_, vecs = eigsolve(flip(TransferMatrix(A, AL)), CR[end], 1, :LM, alg_eigsolve)
_, CR[end] = leftorth!(vecs[1]; alg=it.alg_orth)
_, vec = fixedpoint(flip(TransferMatrix(A, AL)), CR[end], :LM, alg_eigsolve)
_, CR[end] = leftorth!(vec; alg=it.alg_orth)
end
return CR[end]
end
Expand Down Expand Up @@ -238,8 +238,8 @@ function gauge_eigsolve_step!(it::IterativeSolver{RightCanonical}, state)
(; AR, CR, A, iter, ϵ) = state
if iter it.eig_miniter
alg_eigsolve = updatetol(it.alg_eigsolve, 1, ϵ^2)
_, vecs = eigsolve(TransferMatrix(A, AR), CR[end], 1, :LM, alg_eigsolve)
CR[end], _ = rightorth!(vecs[1]; alg=it.alg_orth)
_, vec = fixedpoint(TransferMatrix(A, AR), CR[end], :LM, alg_eigsolve)
CR[end], _ = rightorth!(vec; alg=it.alg_orth)
end
return CR[end]
end
Expand Down

0 comments on commit 91af3a2

Please sign in to comment.