From 972cdf5d70b3a40666351211cd3401309ba257b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 21 Feb 2024 23:58:31 +0100 Subject: [PATCH 1/9] Implement `truncate!` function --- src/Ansatz/Chain.jl | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index f366d04..ff186f0 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -187,6 +187,33 @@ function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, mode = return tn end +function truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing) + # TODO replace for select(:between) + vind = rightindex(qtn, bond[1]) + if vind != leftindex(qtn, bond[2]) + throw(ArgumentError("Invalid bond $bond")) + end + + if vind ∉ inds(TensorNetwork(qtn), :hyper) + throw(ArgumentError("Can't access the spectrum on bond $bond")) + end + + tensor = TensorNetwork(qtn)[vind] + spectrum = parent(tensor) + + extent = if !isnothing(maxdim) + 1:maxdim + elseif !isnothing(threshold) + findall(>(threshold) ∘ abs, spectrum) + else + throw(ArgumentError("Either `threshold` or `maxdim` must be provided")) + end + + slice!(TensorNetwork(qtn), vind, extent) + + return qtn +end + mixed_canonize(tn::Chain, args...; kwargs...) = mixed_canonize!(deepcopy(tn), args...; kwargs...) mixed_canonize!(tn::Chain, args...; kwargs...) = mixed_canonize!(boundary(tn), tn, args...; kwargs...) From 2a4d836cf6d1be283f48ebce6f3b4e7861f1acfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 22 Feb 2024 16:30:51 +0100 Subject: [PATCH 2/9] Document `truncate!` --- src/Ansatz/Chain.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index ff186f0..853512a 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -187,6 +187,16 @@ function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, mode = return tn end +""" + truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing) + +Truncate the dimension of the virtual `bond`` of the [`Chain`](@ref) Tensor Network by keeping only the `maxdim` largest Schmidt coefficients or those larger than`threshold`. + +# Notes + + - Either `threshold` or `maxdim` must be provided. If both are provided, `maxdim` is used. + - The bond must contain the Schmidt coefficients, i.e. a site canonization must be performed before calling `truncate!`. +""" function truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing) # TODO replace for select(:between) vind = rightindex(qtn, bond[1]) From 68364139993d090c7042ffd9c6cca81194243701 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Fri, 23 Feb 2024 10:32:38 +0100 Subject: [PATCH 3/9] Add tests --- src/Ansatz/Chain.jl | 2 ++ src/Qrochet.jl | 2 +- test/Ansatz/Chain_test.jl | 18 ++++++++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 6826153..36a608d 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -187,6 +187,8 @@ function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, method return tn end +truncate(tn::Chain, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...) + """ truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing) diff --git a/src/Qrochet.jl b/src/Qrochet.jl index e80da4a..4f6e914 100644 --- a/src/Qrochet.jl +++ b/src/Qrochet.jl @@ -17,7 +17,7 @@ export Product include("Ansatz/Chain.jl") export Chain export MPS, pMPS, MPO, pMPO -export leftindex, rightindex, canonize_site, canonize_site!, truncate! +export leftindex, rightindex, canonize_site, canonize_site!, truncate, truncate! export mixed_canonize, mixed_canonize!, isleftcanonical, isrightcanonical # reexports from Tenet diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 4f1c121..07203c9 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -51,6 +51,24 @@ @test rightsite(qtn, Site(1)) == Site(2) end + @testset "truncate" begin + qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + canonize_site!(qtn, Site(2); direction = :right, method = :svd) + + @test_throws ArgumentError truncate!(qtn, [Site(1), Site(2)]; maxdim = 1) + @test_throws ArgumentError truncate!(qtn, [Site(2), Site(3)]) + + truncated = truncate(qtn, [Site(2), Site(3)]; maxdim = 1) + @test issetequal(size(select(truncated, :tensor, Site(2))), (2, 2, 1)) + @test issetequal(size(select(truncated, :tensor, Site(3))), (1, 2, 2)) + + # TODO: Uncomment when `select(:between)` is working + # singular_values = select(qtn, :between, Site(2), Site(3)) + # truncated = truncate(qtn, [Site(2), Site(3)]; threshold = singular_values[2]+0.1) + # @test issetequal(size(select(truncated, :tensor, Site(2))), (2, 2, 1)) + # @test issetequal(size(select(truncated, :tensor, Site(3))), (1, 2, 2)) + end + @testset "Canonization" begin using Tenet From bd974b9b45b5c323eea7ec3e784212ec0f493c94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Fri, 23 Feb 2024 11:05:19 +0100 Subject: [PATCH 4/9] Update tests using rightindex and leftindex functions --- test/Ansatz/Chain_test.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 07203c9..ac0ce7d 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -59,14 +59,14 @@ @test_throws ArgumentError truncate!(qtn, [Site(2), Site(3)]) truncated = truncate(qtn, [Site(2), Site(3)]; maxdim = 1) - @test issetequal(size(select(truncated, :tensor, Site(2))), (2, 2, 1)) - @test issetequal(size(select(truncated, :tensor, Site(3))), (1, 2, 2)) + @test size(TensorNetwork(truncated), rightindex(truncated, Site(2))) == 1 + @test size(TensorNetwork(truncated), leftindex(truncated, Site(3))) == 1 # TODO: Uncomment when `select(:between)` is working # singular_values = select(qtn, :between, Site(2), Site(3)) # truncated = truncate(qtn, [Site(2), Site(3)]; threshold = singular_values[2]+0.1) - # @test issetequal(size(select(truncated, :tensor, Site(2))), (2, 2, 1)) - # @test issetequal(size(select(truncated, :tensor, Site(3))), (1, 2, 2)) + # @test size(TensorNetwork(truncated), rightindex(truncated, Site(2))) == 1 + # @test size(TensorNetwork(truncated), leftindex(truncated, Site(3))) == 1 end @testset "Canonization" begin From 187f6d896a565040b04d707a7c5d8c373b100c69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Fri, 23 Feb 2024 11:37:18 +0100 Subject: [PATCH 5/9] Add MissingSchmidtCoefficientsException --- src/Ansatz/Chain.jl | 8 +++++++- test/Ansatz/Chain_test.jl | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 36a608d..df8050d 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -189,6 +189,12 @@ end truncate(tn::Chain, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...) +struct MissingSchmidtCoefficientsException <: Base.Exception + message::String +end + +Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) = print(io, "MissingSchmidtCoefficientsException: $(e.message)") + """ truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing) @@ -207,7 +213,7 @@ function truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, m end if vind ∉ inds(TensorNetwork(qtn), :hyper) - throw(ArgumentError("Can't access the spectrum on bond $bond")) + throw(MissingSchmidtCoefficientsException("Can't access the spectrum on bond $bond")) end tensor = TensorNetwork(qtn)[vind] diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index ac0ce7d..d8736bf 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -55,7 +55,7 @@ qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) canonize_site!(qtn, Site(2); direction = :right, method = :svd) - @test_throws ArgumentError truncate!(qtn, [Site(1), Site(2)]; maxdim = 1) + @test_throws Qrochet.MissingSchmidtCoefficientsException truncate!(qtn, [Site(1), Site(2)]; maxdim = 1) @test_throws ArgumentError truncate!(qtn, [Site(2), Site(3)]) truncated = truncate(qtn, [Site(2), Site(3)]; maxdim = 1) From 8b8e0efbeae8186edd8971dbc8d9d62890a68ebf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 23 Feb 2024 11:16:47 +0100 Subject: [PATCH 6/9] Format exports --- src/Qrochet.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Qrochet.jl b/src/Qrochet.jl index 4f6e914..12b10ae 100644 --- a/src/Qrochet.jl +++ b/src/Qrochet.jl @@ -17,8 +17,8 @@ export Product include("Ansatz/Chain.jl") export Chain export MPS, pMPS, MPO, pMPO -export leftindex, rightindex, canonize_site, canonize_site!, truncate, truncate! -export mixed_canonize, mixed_canonize!, isleftcanonical, isrightcanonical +export leftindex, rightindex, isleftcanonical, isrightcanonical +export canonize_site, canonize_site!, truncate, truncate!, mixed_canonize, mixed_canonize! # reexports from Tenet using Tenet From cf70e7f0e314a357b68657e96da19307d3b5657b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 23 Feb 2024 11:50:07 +0100 Subject: [PATCH 7/9] Refactor `MissingSchmidtCoefficientsException` --- src/Ansatz.jl | 8 ++++++++ src/Ansatz/Chain.jl | 8 +------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/Ansatz.jl b/src/Ansatz.jl index 639b910..fe5f1be 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -47,3 +47,11 @@ function Tenet.select(tn::Ansatz, ::Val{:between}, site1::Site, site2::Site) tensor === tensor2 end |> only end + +struct MissingSchmidtCoefficientsException <: Base.Exception + bond::NTuple{2,Site} +end + +function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) + print(io, "Can't access the spectrum on bond $(e.bond)") +end diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index df8050d..a5cdbbb 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -189,12 +189,6 @@ end truncate(tn::Chain, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...) -struct MissingSchmidtCoefficientsException <: Base.Exception - message::String -end - -Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) = print(io, "MissingSchmidtCoefficientsException: $(e.message)") - """ truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing) @@ -213,7 +207,7 @@ function truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, m end if vind ∉ inds(TensorNetwork(qtn), :hyper) - throw(MissingSchmidtCoefficientsException("Can't access the spectrum on bond $bond")) + throw(MissingSchmidtCoefficientsException(bond)) end tensor = TensorNetwork(qtn)[vind] From b01abad69eecb7817f64965806133c872fd81199 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 23 Feb 2024 14:34:20 +0100 Subject: [PATCH 8/9] Fix clash between `truncate` and `Base.truncate` symbols --- src/Qrochet.jl | 2 +- test/Ansatz/Chain_test.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Qrochet.jl b/src/Qrochet.jl index 12b10ae..31ae144 100644 --- a/src/Qrochet.jl +++ b/src/Qrochet.jl @@ -18,7 +18,7 @@ include("Ansatz/Chain.jl") export Chain export MPS, pMPS, MPO, pMPO export leftindex, rightindex, isleftcanonical, isrightcanonical -export canonize_site, canonize_site!, truncate, truncate!, mixed_canonize, mixed_canonize! +export canonize_site, canonize_site!, truncate!, mixed_canonize, mixed_canonize! # reexports from Tenet using Tenet diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index d8736bf..a4c5bad 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -58,13 +58,13 @@ @test_throws Qrochet.MissingSchmidtCoefficientsException truncate!(qtn, [Site(1), Site(2)]; maxdim = 1) @test_throws ArgumentError truncate!(qtn, [Site(2), Site(3)]) - truncated = truncate(qtn, [Site(2), Site(3)]; maxdim = 1) + truncated = Qrochet.truncate(qtn, [Site(2), Site(3)]; maxdim = 1) @test size(TensorNetwork(truncated), rightindex(truncated, Site(2))) == 1 @test size(TensorNetwork(truncated), leftindex(truncated, Site(3))) == 1 # TODO: Uncomment when `select(:between)` is working # singular_values = select(qtn, :between, Site(2), Site(3)) - # truncated = truncate(qtn, [Site(2), Site(3)]; threshold = singular_values[2]+0.1) + # truncated = Qrochet.truncate(qtn, [Site(2), Site(3)]; threshold = singular_values[2]+0.1) # @test size(TensorNetwork(truncated), rightindex(truncated, Site(2))) == 1 # @test size(TensorNetwork(truncated), leftindex(truncated, Site(3))) == 1 end From c70f4d6ebe23392c3368fdafd6ada0188cb0f66e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 23 Feb 2024 14:40:35 +0100 Subject: [PATCH 9/9] Fix `MissingSchmidtCoefficientsException` constructor with `Vector{Site}` --- src/Ansatz.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Ansatz.jl b/src/Ansatz.jl index fe5f1be..81bce91 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -52,6 +52,8 @@ struct MissingSchmidtCoefficientsException <: Base.Exception bond::NTuple{2,Site} end +MissingSchmidtCoefficientsException(bond::Vector{Site}) = MissingSchmidtCoefficientsException(tuple(bond...)) + function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) print(io, "Can't access the spectrum on bond $(e.bond)") end