Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Implement truncate! function #16

Merged
merged 10 commits into from
Feb 23, 2024
39 changes: 39 additions & 0 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,45 @@ function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, method
return tn
end

jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
truncate(tn::Chain, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...)

"""
truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing)

Truncate the dimension of the virtual `bond`` of the [`Chain`](@ref) Tensor Network by keeping only the `maxdim` largest Schmidt coefficients or those larger than`threshold`.

# Notes

- Either `threshold` or `maxdim` must be provided. If both are provided, `maxdim` is used.
- The bond must contain the Schmidt coefficients, i.e. a site canonization must be performed before calling `truncate!`.
"""
function truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing)
mofeing marked this conversation as resolved.
Show resolved Hide resolved
# TODO replace for select(:between)
vind = rightindex(qtn, bond[1])
if vind != leftindex(qtn, bond[2])
throw(ArgumentError("Invalid bond $bond"))
end

if vind ∉ inds(TensorNetwork(qtn), :hyper)
throw(ArgumentError("Can't access the spectrum on bond $bond"))
end

tensor = TensorNetwork(qtn)[vind]
spectrum = parent(tensor)

extent = if !isnothing(maxdim)
1:maxdim
elseif !isnothing(threshold)
findall(>(threshold) ∘ abs, spectrum)
else
throw(ArgumentError("Either `threshold` or `maxdim` must be provided"))
end

slice!(TensorNetwork(qtn), vind, extent)

return qtn
end

function isleftcanonical(qtn::Chain, site; atol::Real = 1e-12)
right_ind = rightindex(qtn, site)

Expand Down
2 changes: 1 addition & 1 deletion src/Qrochet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export Product
include("Ansatz/Chain.jl")
export Chain
export MPS, pMPS, MPO, pMPO
export leftindex, rightindex, canonize_site, canonize_site!, truncate!
export leftindex, rightindex, canonize_site, canonize_site!, truncate, truncate!
export mixed_canonize, mixed_canonize!, isleftcanonical, isrightcanonical

# reexports from Tenet
Expand Down
18 changes: 18 additions & 0 deletions test/Ansatz/Chain_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@
@test rightsite(qtn, Site(1)) == Site(2)
end

@testset "truncate" begin
qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)])
canonize_site!(qtn, Site(2); direction = :right, method = :svd)

@test_throws ArgumentError truncate!(qtn, [Site(1), Site(2)]; maxdim = 1)
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
@test_throws ArgumentError truncate!(qtn, [Site(2), Site(3)])

truncated = truncate(qtn, [Site(2), Site(3)]; maxdim = 1)
@test issetequal(size(select(truncated, :tensor, Site(2))), (2, 2, 1))
@test issetequal(size(select(truncated, :tensor, Site(3))), (1, 2, 2))
mofeing marked this conversation as resolved.
Show resolved Hide resolved

# TODO: Uncomment when `select(:between)` is working
# singular_values = select(qtn, :between, Site(2), Site(3))
# truncated = truncate(qtn, [Site(2), Site(3)]; threshold = singular_values[2]+0.1)
# @test issetequal(size(select(truncated, :tensor, Site(2))), (2, 2, 1))
# @test issetequal(size(select(truncated, :tensor, Site(3))), (1, 2, 2))
end

@testset "Canonization" begin
using Tenet

Expand Down
Loading