Skip to content

Commit

Permalink
[ITensors] Fix broken broadcast operation on GPU (#1532)
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT authored Oct 13, 2024
1 parent eba5e17 commit 98a7724
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
14 changes: 14 additions & 0 deletions NDTensors/test/test_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,20 @@ NDTensors.dim(i::MyInd) = i.dim
@test A[2, 2] == Aview[1, 1]
end

## Testing A .= α .* B .+ β .* A
# Case 1: B is filled with zeros, so with β = 2 and α = 1 the update must
# leave A equal to twice its previous value (saved in C).
C = copy(A)
@allowscalar fill!(B, zero(elt))
β = elt(2)
α = elt(1)
# Rebind A exactly as the second call below does, so the assertion checks the
# value returned by `permutedims!!`.
# NOTE(review): presumably `permutedims!!` may return its result instead of
# mutating the destination in place -- confirm against NDTensors.
A = permutedims!!(A, B, (1, 2), (a, b) -> +(*(β, a), *(α, b)))
# This comparison was previously evaluated without `@test`, so it could never fail.
@allowscalar @test 2 .* C == A
# Case 2: random B, checked element by element against the pre-update copy C.
randn!(B)
C = copy(A)
A = permutedims!!(A, B, (1, 2), (a, b) -> +(*(β, a), *(α, b)))
@allowscalar for i in 1:3, j in 1:4
  @test A[i, j] == α * B[i, j] + β * C[i, j]
end

## Wrap the scalar in `elt` so that scaling preserves the eltype of A;
## multiplying from the left and from the right must yield the same data.
let scalar = elt(2.0)
  @test data(A * scalar) == data(scalar * A)
end

Expand Down
11 changes: 10 additions & 1 deletion src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,13 @@ end
# C .= β .* C .+ α .* A .* B
#

"""
    axpby(alpha, beta)

Callable object holding two coefficients. Calling it as `f(y, x)` returns
`x * f.alpha + y * f.beta`.

It is passed to `map!` in the `copyto!` method below in place of an anonymous
function that captured the coefficients, because that closure fails to compile
on some GPU backends.
"""
struct axpby{Alpha,Beta} <: Function
  alpha::Alpha
  beta::Beta
end

# `y` is the first argument and is scaled by `beta`; `x` is the second argument
# and is scaled by `alpha`.
function (f::axpby)(y, x)
  scaled_x = x * f.alpha
  scaled_y = y * f.beta
  return scaled_x + scaled_y
end

## TODO this code doesn't actually get called
function Base.copyto!(
T::ITensor,
Expand All @@ -414,7 +421,9 @@ function Base.copyto!(
A, C = C, A
end
if !isnothing(A) && !isnothing(C) && !isnothing(α) && !isnothing(β)
map!((r, t) -> β * r + α * t, T, T, A)
# The following fails to compile on some GPU backends.
# map!((r, t) -> β * r + α * t, T, T, A)
map!(axpby(α, β), T, T, A)
else
bc_bc_α = find_type(Broadcasted, bc_α.args)
if isnothing(α)
Expand Down

0 comments on commit 98a7724

Please sign in to comment.