Skip to content

Commit

Permalink
[NDTensors] [BUG] Fix bug in in-place contract (#1158)
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT authored Jul 25, 2023
1 parent b752466 commit 163ba6e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
14 changes: 5 additions & 9 deletions NDTensors/src/dense/tensoralgebra/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,8 @@ function _contract!(
) where {El,NC,NA,NB}
tA = 'N'
if props.permuteA
pA = NTuple{NA,Int}(props.PA)
#@timeit_debug timer "_contract!: permutedims A" begin
@strided Ap = permutedims(AT, pA)
@strided Ap = permutedims(AT, props.PA)
#end # @timeit
AM = transpose(reshape(Ap, (props.dmid, props.dleft)))
else
Expand All @@ -360,9 +359,8 @@ function _contract!(

tB = 'N'
if props.permuteB
pB = NTuple{NB,Int}(props.PB)
#@timeit_debug timer "_contract!: permutedims B" begin
@strided Bp = permutedims(BT, pB)
@strided Bp = permutedims(BT, props.PB)
#end # @timeit
BM = reshape(Bp, (props.dmid, props.dright))
else
Expand All @@ -377,10 +375,9 @@ function _contract!(
if props.permuteC
# if we are computing C = α * A B + β * C
# we need to make sure C is permuted to the same
# ordering as A B
# ordering as A B which is the inverse of props.PC
if β 0
pC = NTuple{NB,Int}(props.PC)
CM = reshape(permutedims(CT, pC), (props.dleft, props.dright))
CM = reshape(permutedims(CT, invperm(props.PC)), (props.dleft, props.dright))
else
# Need to copy here since we will be permuting
# into C later
Expand All @@ -399,11 +396,10 @@ function _contract!(
mul!(CM, AM, BM, El(α), El(β))

if props.permuteC
pC = NTuple{NC,Int}(props.PC)
Cr = reshape(CM, props.newCrange)
# TODO: use invperm(pC) here?
#@timeit_debug timer "_contract!: permutedims C" begin
@strided CT .= permutedims(Cr, pC)
@strided CT .= permutedims(Cr, props.PC)
#end # @timeit
end

Expand Down
4 changes: 2 additions & 2 deletions test/base/test_contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
A = randomITensor(T, (j, i))
B = randomITensor(T, (j, k, l, α))
C = ITensor(zero(T), (i, k, α, l))
ITensors.contract!(C, A, B, 1.0, 0.0)
ITensors.contract!(C, A, B, 1.0, 1.0)
ITensors.contract!(C, B, A, 1.0, 0.0)
ITensors.contract!(C, B, A, 1.0, 1.0)
D = A * B
D .+= A * B
@test C D
Expand Down

0 comments on commit 163ba6e

Please sign in to comment.