From 163ba6ecee6ede13198aa77376cae7e0764debbc Mon Sep 17 00:00:00 2001 From: Karl Pierce Date: Tue, 25 Jul 2023 11:23:05 -0400 Subject: [PATCH] [NDTensors] [BUG] Fix bug in in-place contract (#1158) --- NDTensors/src/dense/tensoralgebra/contract.jl | 14 +++++--------- test/base/test_contract.jl | 4 ++-- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/NDTensors/src/dense/tensoralgebra/contract.jl b/NDTensors/src/dense/tensoralgebra/contract.jl index 31e601ced9..66e7ab7c97 100644 --- a/NDTensors/src/dense/tensoralgebra/contract.jl +++ b/NDTensors/src/dense/tensoralgebra/contract.jl @@ -344,9 +344,8 @@ function _contract!( ) where {El,NC,NA,NB} tA = 'N' if props.permuteA - pA = NTuple{NA,Int}(props.PA) #@timeit_debug timer "_contract!: permutedims A" begin - @strided Ap = permutedims(AT, pA) + @strided Ap = permutedims(AT, props.PA) #end # @timeit AM = transpose(reshape(Ap, (props.dmid, props.dleft))) else @@ -360,9 +359,8 @@ function _contract!( tB = 'N' if props.permuteB - pB = NTuple{NB,Int}(props.PB) #@timeit_debug timer "_contract!: permutedims B" begin - @strided Bp = permutedims(BT, pB) + @strided Bp = permutedims(BT, props.PB) #end # @timeit BM = reshape(Bp, (props.dmid, props.dright)) else @@ -377,10 +375,9 @@ function _contract!( if props.permuteC # if we are computing C = α * A B + β * C # we need to make sure C is permuted to the same - # ordering as A B + # ordering as A B which is the inverse of props.PC if β ≠ 0 - pC = NTuple{NB,Int}(props.PC) - CM = reshape(permutedims(CT, pC), (props.dleft, props.dright)) + CM = reshape(permutedims(CT, invperm(props.PC)), (props.dleft, props.dright)) else # Need to copy here since we will be permuting # into C later @@ -399,11 +396,10 @@ function _contract!( mul!(CM, AM, BM, El(α), El(β)) if props.permuteC - pC = NTuple{NC,Int}(props.PC) Cr = reshape(CM, props.newCrange) # TODO: use invperm(pC) here? #@timeit_debug timer "_contract!: permutedims C" begin - @strided CT .= permutedims(Cr, pC) + @strided CT .= permutedims(Cr, props.PC) #end # @timeit end diff --git a/test/base/test_contract.jl b/test/base/test_contract.jl index a68da19f1d..a0f5ec1df9 100644 --- a/test/base/test_contract.jl +++ b/test/base/test_contract.jl @@ -245,8 +245,8 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k) A = randomITensor(T, (j, i)) B = randomITensor(T, (j, k, l, α)) C = ITensor(zero(T), (i, k, α, l)) - ITensors.contract!(C, A, B, 1.0, 0.0) - ITensors.contract!(C, A, B, 1.0, 1.0) + ITensors.contract!(C, B, A, 1.0, 0.0) + ITensors.contract!(C, B, A, 1.0, 1.0) D = A * B D .+= A * B @test C ≈ D