Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Implement truncate! function #16

Merged
merged 10 commits into from
Feb 23, 2024
10 changes: 10 additions & 0 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,13 @@ function Tenet.select(tn::Ansatz, ::Val{:between}, site1::Site, site2::Site)
tensor === tensor2
end |> only
end

struct MissingSchmidtCoefficientsException <: Base.Exception
bond::NTuple{2,Site}
end

MissingSchmidtCoefficientsException(bond::Vector{Site}) = MissingSchmidtCoefficientsException(tuple(bond...))

function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException)
print(io, "Can't access the spectrum on bond $(e.bond)")
end
39 changes: 39 additions & 0 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,45 @@ function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, method
return tn
end

jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
truncate(tn::Chain, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...)

"""
truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing)

Truncate the dimension of the virtual `bond`` of the [`Chain`](@ref) Tensor Network by keeping only the `maxdim` largest Schmidt coefficients or those larger than`threshold`.

# Notes

- Either `threshold` or `maxdim` must be provided. If both are provided, `maxdim` is used.
- The bond must contain the Schmidt coefficients, i.e. a site canonization must be performed before calling `truncate!`.
"""
function truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing)
mofeing marked this conversation as resolved.
Show resolved Hide resolved
# TODO replace for select(:between)
vind = rightindex(qtn, bond[1])
if vind != leftindex(qtn, bond[2])
throw(ArgumentError("Invalid bond $bond"))
end

if vind ∉ inds(TensorNetwork(qtn), :hyper)
throw(MissingSchmidtCoefficientsException(bond))
end

tensor = TensorNetwork(qtn)[vind]
spectrum = parent(tensor)

extent = if !isnothing(maxdim)
1:maxdim
elseif !isnothing(threshold)
findall(>(threshold) ∘ abs, spectrum)
else
throw(ArgumentError("Either `threshold` or `maxdim` must be provided"))
end

slice!(TensorNetwork(qtn), vind, extent)

return qtn
end

function isleftcanonical(qtn::Chain, site; atol::Real = 1e-12)
right_ind = rightindex(qtn, site)

Expand Down
4 changes: 2 additions & 2 deletions src/Qrochet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ export Product
include("Ansatz/Chain.jl")
export Chain
export MPS, pMPS, MPO, pMPO
export leftindex, rightindex, canonize_site, canonize_site!, truncate!
export mixed_canonize, mixed_canonize!, isleftcanonical, isrightcanonical
export leftindex, rightindex, isleftcanonical, isrightcanonical
export canonize_site, canonize_site!, truncate!, mixed_canonize, mixed_canonize!

# reexports from Tenet
using Tenet
Expand Down
18 changes: 18 additions & 0 deletions test/Ansatz/Chain_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@
@test rightsite(qtn, Site(1)) == Site(2)
end

@testset "truncate" begin
qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)])
canonize_site!(qtn, Site(2); direction = :right, method = :svd)

@test_throws Qrochet.MissingSchmidtCoefficientsException truncate!(qtn, [Site(1), Site(2)]; maxdim = 1)
@test_throws ArgumentError truncate!(qtn, [Site(2), Site(3)])

truncated = Qrochet.truncate(qtn, [Site(2), Site(3)]; maxdim = 1)
@test size(TensorNetwork(truncated), rightindex(truncated, Site(2))) == 1
@test size(TensorNetwork(truncated), leftindex(truncated, Site(3))) == 1

# TODO: Uncomment when `select(:between)` is working
# singular_values = select(qtn, :between, Site(2), Site(3))
# truncated = Qrochet.truncate(qtn, [Site(2), Site(3)]; threshold = singular_values[2]+0.1)
# @test size(TensorNetwork(truncated), rightindex(truncated, Site(2))) == 1
# @test size(TensorNetwork(truncated), leftindex(truncated, Site(3))) == 1
end

@testset "Canonization" begin
using Tenet

Expand Down
Loading