From 995b6dce277b9304254c3ba7e0585b5e459651d4 Mon Sep 17 00:00:00 2001 From: Karl Pierce Date: Tue, 3 Sep 2024 20:25:50 -0400 Subject: [PATCH] [NDTensors] Use TensorOperations.jl v5 in tests (#1483) --- .../src/lib/TensorAlgebra/test/test_basics.jl | 141 +++++++++--------- .../test/NDTensorsTestUtils/device_list.jl | 3 - 2 files changed, 68 insertions(+), 76 deletions(-) diff --git a/NDTensors/src/lib/TensorAlgebra/test/test_basics.jl b/NDTensors/src/lib/TensorAlgebra/test/test_basics.jl index bafabdf0b9..95576a8bf5 100644 --- a/NDTensors/src/lib/TensorAlgebra/test/test_basics.jl +++ b/NDTensors/src/lib/TensorAlgebra/test/test_basics.jl @@ -122,83 +122,78 @@ end @test eltype(a_split) === elt @test a_split ≈ reshape(a, (2, 3, 20)) end - ## Right now TensorOperations version is downgraded when using cuTENSOR to `v0.7` we - ## are waiting for TensorOperations to support the breaking changes in cuTENSOR 2.x - if !("cutensor" ∈ ARGS) - using TensorOperations: TensorOperations - @testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts - dims = (2, 3, 4, 5, 6, 7, 8, 9, 10) - labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i) - for (d1s, d2s, d_dests) in ( - ((1, 2), (1, 2), ()), - ((1, 2), (2, 1), ()), - ((1, 2), (2, 1, 3), (3,)), - ((1, 2, 3), (2, 1), (3,)), - ((1, 2), (2, 3), (1, 3)), - ((1, 2), (2, 3), (3, 1)), - ((2, 1), (2, 3), (3, 1)), - ((1, 2, 3), (2, 3, 4), (1, 4)), - ((1, 2, 3), (2, 3, 4), (4, 1)), - ((3, 2, 1), (4, 2, 3), (4, 1)), - ((1, 2, 3), (3, 4), (1, 2, 4)), - ((1, 2, 3), (3, 4), (4, 1, 2)), - ((1, 2, 3), (3, 4), (2, 4, 1)), - ((3, 1, 2), (3, 4), (2, 4, 1)), - ((3, 2, 1), (4, 3), (2, 4, 1)), - ((1, 2, 3, 4, 5, 6), (4, 5, 6, 7, 8, 9), (1, 2, 3, 7, 8, 9)), - ((2, 4, 5, 1, 6, 3), (6, 4, 9, 8, 5, 7), (1, 7, 2, 8, 3, 9)), - ) - a1 = randn(elt1, map(i -> dims[i], d1s)) - labels1 = map(i -> labels[i], d1s) - a2 = randn(elt2, map(i -> dims[i], d2s)) - labels2 = map(i -> labels[i], d2s) - labels_dest = map(i -> labels[i], d_dests) + using TensorOperations: TensorOperations + @testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts + dims = (2, 3, 4, 5, 6, 7, 8, 9, 10) + labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i) + for (d1s, d2s, d_dests) in ( + ((1, 2), (1, 2), ()), + ((1, 2), (2, 1), ()), + ((1, 2), (2, 1, 3), (3,)), + ((1, 2, 3), (2, 1), (3,)), + ((1, 2), (2, 3), (1, 3)), + ((1, 2), (2, 3), (3, 1)), + ((2, 1), (2, 3), (3, 1)), + ((1, 2, 3), (2, 3, 4), (1, 4)), + ((1, 2, 3), (2, 3, 4), (4, 1)), + ((3, 2, 1), (4, 2, 3), (4, 1)), + ((1, 2, 3), (3, 4), (1, 2, 4)), + ((1, 2, 3), (3, 4), (4, 1, 2)), + ((1, 2, 3), (3, 4), (2, 4, 1)), + ((3, 1, 2), (3, 4), (2, 4, 1)), + ((3, 2, 1), (4, 3), (2, 4, 1)), + ((1, 2, 3, 4, 5, 6), (4, 5, 6, 7, 8, 9), (1, 2, 3, 7, 8, 9)), + ((2, 4, 5, 1, 6, 3), (6, 4, 9, 8, 5, 7), (1, 7, 2, 8, 3, 9)), + ) + a1 = randn(elt1, map(i -> dims[i], d1s)) + labels1 = map(i -> labels[i], d1s) + a2 = randn(elt2, map(i -> dims[i], d2s)) + labels2 = map(i -> labels[i], d2s) + labels_dest = map(i -> labels[i], d_dests) - # Don't specify destination labels - a_dest, labels_dest′ = TensorAlgebra.contract(a1, labels1, a2, labels2) - a_dest_tensoroperations = TensorOperations.tensorcontract( - labels_dest′, a1, labels1, a2, labels2 - ) - @test a_dest ≈ a_dest_tensoroperations + # Don't specify destination labels + a_dest, labels_dest′ = TensorAlgebra.contract(a1, labels1, a2, labels2) + a_dest_tensoroperations = TensorOperations.tensorcontract( + labels_dest′, a1, labels1, a2, labels2 + ) + @test a_dest ≈ a_dest_tensoroperations - # Specify destination labels - a_dest = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2) - a_dest_tensoroperations = TensorOperations.tensorcontract( - labels_dest, a1, labels1, a2, labels2 - ) - @test a_dest ≈ a_dest_tensoroperations + # Specify destination labels + a_dest = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2) + a_dest_tensoroperations = TensorOperations.tensorcontract( + labels_dest, a1, labels1, a2, labels2 + ) + @test a_dest ≈ a_dest_tensoroperations - # Specify α and β - elt_dest = promote_type(elt1, elt2) - # TODO: Using random `α`, `β` causing - # random test failures, investigate why. - α = elt_dest(1.2) # randn(elt_dest) - β = elt_dest(2.4) # randn(elt_dest) - a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests)) - a_dest = copy(a_dest_init) - TensorAlgebra.contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β) - a_dest_tensoroperations = TensorOperations.tensorcontract( - labels_dest, a1, labels1, a2, labels2 - ) - ## Here we loosened the tolerance because of some floating point roundoff issue. - ## with Float32 numbers - @test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol = - 50 * default_rtol(elt_dest) - end + # Specify α and β + elt_dest = promote_type(elt1, elt2) + # TODO: Using random `α`, `β` causing + # random test failures, investigate why. + α = elt_dest(1.2) # randn(elt_dest) + β = elt_dest(2.4) # randn(elt_dest) + a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests)) + a_dest = copy(a_dest_init) + TensorAlgebra.contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β) + a_dest_tensoroperations = TensorOperations.tensorcontract( + labels_dest, a1, labels1, a2, labels2 + ) + ## Here we loosened the tolerance because of some floating point roundoff issue. + ## with Float32 numbers + @test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol = + 50 * default_rtol(elt_dest) end end - @testset "qr (eltype=$elt)" for elt in elts - a = randn(elt, 5, 4, 3, 2) - labels_a = (:a, :b, :c, :d) - labels_q = (:b, :a) - labels_r = (:d, :c) - q, r = qr(a, labels_a, labels_q, labels_r) - label_qr = :qr - a′ = TensorAlgebra.contract( - labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...) - ) - @test a ≈ a′ - end end - +@testset "qr (eltype=$elt)" for elt in elts + a = randn(elt, 5, 4, 3, 2) + labels_a = (:a, :b, :c, :d) + labels_q = (:b, :a) + labels_r = (:d, :c) + q, r = qr(a, labels_a, labels_q, labels_r) + label_qr = :qr + a′ = TensorAlgebra.contract( + labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...) + ) + @test a ≈ a′ +end end diff --git a/NDTensors/test/NDTensorsTestUtils/device_list.jl b/NDTensors/test/NDTensorsTestUtils/device_list.jl index 56c44d5fc4..1220ec42a9 100644 --- a/NDTensors/test/NDTensorsTestUtils/device_list.jl +++ b/NDTensors/test/NDTensorsTestUtils/device_list.jl @@ -17,9 +17,6 @@ if "metal" in ARGS || "all" in ARGS using Metal end if "cutensor" in ARGS || "all" in ARGS - if in("TensorOperations", map(v -> v.name, values(Pkg.dependencies()))) - Pkg.rm("TensorOperations") - end Pkg.add("cuTENSOR") using CUDA, cuTENSOR end