Skip to content

Commit

Permalink
[NDTensors] JLArrays Extension (#1508)
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT authored Sep 24, 2024
1 parent 1ef12d6 commit 29cce1c
Show file tree
Hide file tree
Showing 15 changed files with 201 additions and 58 deletions.
10 changes: 7 additions & 3 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Expand All @@ -47,6 +48,7 @@ NDTensorsAMDGPUExt = ["AMDGPU", "GPUArraysCore"]
NDTensorsCUDAExt = ["CUDA", "GPUArraysCore"]
NDTensorsGPUArraysCoreExt = "GPUArraysCore"
NDTensorsHDF5Ext = "HDF5"
NDTensorsJLArraysExt = ["GPUArraysCore", "JLArrays"]
NDTensorsMappedArraysExt = ["MappedArrays"]
NDTensorsMetalExt = ["GPUArraysCore", "Metal"]
NDTensorsOctavianExt = "Octavian"
Expand All @@ -70,15 +72,16 @@ GPUArraysCore = "0.1"
HDF5 = "0.14, 0.15, 0.16, 0.17"
HalfIntegers = "1"
InlineStrings = "1"
LinearAlgebra = "1.6"
JLArrays = "0.1"
LinearAlgebra = "<0.0.1, 1.6"
MacroTools = "0.5"
MappedArrays = "0.4"
Metal = "1"
Octavian = "0.3"
PackageExtensionCompat = "1"
Random = "1.6"
Random = "<0.0.1, 1.6"
SimpleTraits = "0.9.4"
SparseArrays = "1.6"
SparseArrays = "<0.0.1, 1.6"
SplitApplyCombine = "1.2.2"
StaticArrays = "0.12, 1.0"
Strided = "2"
Expand All @@ -95,6 +98,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
TBLIS = "48530278-0828-4a49-9772-0f3830dfa1e9"
Expand Down
8 changes: 4 additions & 4 deletions NDTensors/ext/NDTensorsAMDGPUExt/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# TypeParameterAccessors definitions
using NDTensors.TypeParameterAccessors: TypeParameterAccessors, Position
using NDTensors.TypeParameterAccessors:
TypeParameterAccessors, Position, default_type_parameters
using NDTensors.GPUArraysCoreExtensions: storagemode
using AMDGPU: AMDGPU, ROCArray

function TypeParameterAccessors.default_type_parameters(::Type{<:ROCArray})
return (Float64, 1, AMDGPU.Mem.HIPBuffer)
return (default_type_parameters(AbstractArray)..., AMDGPU.Mem.HIPBuffer)
end
TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(eltype)) = Position(1)
TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(ndims)) = Position(2)

TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(storagemode)) = Position(3)
11 changes: 3 additions & 8 deletions NDTensors/ext/NDTensorsCUDAExt/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
# TypeParameterAccessors definitions
using CUDA: CUDA, CuArray
using NDTensors.TypeParameterAccessors: TypeParameterAccessors, Position
using NDTensors.TypeParameterAccessors:
TypeParameterAccessors, Position, default_type_parameters
using NDTensors.GPUArraysCoreExtensions: storagemode

function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(eltype))
return Position(1)
end
function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(ndims))
return Position(2)
end
function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(storagemode))
return Position(3)
end

function TypeParameterAccessors.default_type_parameters(::Type{<:CuArray})
return (Float64, 1, CUDA.Mem.DeviceBuffer)
return (default_type_parameters(AbstractArray)..., CUDA.Mem.DeviceBuffer)
end
7 changes: 7 additions & 0 deletions NDTensors/ext/NDTensorsJLArraysExt/NDTensorsJLArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module NDTensorsJLArraysExt
include("copyto.jl")
include("indexing.jl")
include("linearalgebra.jl")
include("mul.jl")
include("permutedims.jl")
end
30 changes: 30 additions & 0 deletions NDTensors/ext/NDTensorsJLArraysExt/copyto.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using JLArrays: JLArray
using NDTensors.Expose: Exposed, expose, unexpose
using LinearAlgebra: Adjoint

# Same definition as `CuArray`.
function Base.copy(src::Exposed{<:JLArray,<:Base.ReshapedArray})
return reshape(copy(parent(src)), size(unexpose(src)))
end

function Base.copy(
src::Exposed{
<:JLArray,<:SubArray{<:Any,<:Any,<:Base.ReshapedArray{<:Any,<:Any,<:Adjoint}}
},
)
return copy(@view copy(expose(parent(src)))[parentindices(unexpose(src))...])
end

# Catches a bug in `copyto!` in CUDA backend.
function Base.copyto!(dest::Exposed{<:JLArray}, src::Exposed{<:JLArray,<:SubArray})
copyto!(dest, expose(copy(src)))
return unexpose(dest)
end

# Catches a bug in `copyto!` in JLArray backend.
function Base.copyto!(
dest::Exposed{<:JLArray}, src::Exposed{<:JLArray,<:Base.ReshapedArray}
)
copyto!(dest, expose(parent(src)))
return unexpose(dest)
end
19 changes: 19 additions & 0 deletions NDTensors/ext/NDTensorsJLArraysExt/indexing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using JLArrays: JLArray
using GPUArraysCore: @allowscalar
using NDTensors: NDTensors
using NDTensors.Expose: Exposed, expose, unexpose

function Base.getindex(E::Exposed{<:JLArray})
return @allowscalar unexpose(E)[]
end

function Base.setindex!(E::Exposed{<:JLArray}, x::Number)
@allowscalar unexpose(E)[] = x
return unexpose(E)
end

function Base.getindex(E::Exposed{<:JLArray,<:Adjoint}, i, j)
return (expose(parent(E))[j, i])'
end

Base.any(f, E::Exposed{<:JLArray,<:NDTensors.Tensor}) = any(f, data(unexpose(E)))
40 changes: 40 additions & 0 deletions NDTensors/ext/NDTensorsJLArraysExt/linearalgebra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using Adapt: adapt
using JLArrays: JLArray, JLMatrix
using LinearAlgebra: LinearAlgebra, Hermitian, Symmetric, qr, eigen
using NDTensors: NDTensors
using NDTensors.Expose: Expose, expose, qr, qr_positive, ql, ql_positive
using NDTensors.GPUArraysCoreExtensions: cpu
using NDTensors.TypeParameterAccessors: unwrap_array_type

## TODO this function exists because of the same issue below. when
## that issue is resolved we can rely on the abstractarray version of
## this operation.
function Expose.qr(A::Exposed{<:JLArray})
Q, L = qr(unexpose(A))
return adapt(unwrap_array_type(A), Matrix(Q)), adapt(unwrap_array_type(A), L)
end
## TODO this should work using a JLArray but there is an error converting the Q from its packed QR from
## back into a JLArray see https://github.com/JuliaGPU/GPUArrays.jl/issues/545. To fix call cpu for now
function Expose.qr_positive(A::Exposed{<:JLArray})
Q, L = qr_positive(expose(cpu(A)))
return adapt(unwrap_array_type(A), copy(Q)), adapt(unwrap_array_type(A), L)
end

function Expose.ql(A::Exposed{<:JLMatrix})
Q, L = ql(expose(cpu(A)))
return adapt(unwrap_array_type(A), copy(Q)), adapt(unwrap_array_type(A), L)
end
function Expose.ql_positive(A::Exposed{<:JLMatrix})
Q, L = ql_positive(expose(cpu(A)))
return adapt(unwrap_array_type(A), copy(Q)), adapt(unwrap_array_type(A), L)
end

function LinearAlgebra.eigen(A::Exposed{<:JLMatrix,<:Symmetric})
q, l = (eigen(expose(cpu(A))))
return adapt.(unwrap_array_type(A), (q, l))
end

function LinearAlgebra.eigen(A::Exposed{<:JLMatrix,<:Hermitian})
q, l = (eigen(expose(Hermitian(cpu(unexpose(A).data)))))
return adapt.(JLArray, (q, l))
end
43 changes: 43 additions & 0 deletions NDTensors/ext/NDTensorsJLArraysExt/mul.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
using JLArrays: JLArray
using LinearAlgebra: LinearAlgebra, mul!, transpose
using NDTensors.Expose: Exposed, expose, unexpose

function LinearAlgebra.mul!(
CM::Exposed{<:JLArray,<:LinearAlgebra.Transpose},
AM::Exposed{<:JLArray},
BM::Exposed{<:JLArray},
α,
β,
)
mul!(transpose(CM), transpose(BM), transpose(AM), α, β)
return unexpose(CM)
end

function LinearAlgebra.mul!(
CM::Exposed{<:JLArray,<:LinearAlgebra.Adjoint},
AM::Exposed{<:JLArray},
BM::Exposed{<:JLArray},
α,
β,
)
mul!(CM', BM', AM', α, β)
return unexpose(CM)
end

## Fix issue in JLArrays.jl where it cannot distinguish Transpose{Reshape{Adjoint{JLArray}}}
## as a JLArray and calls generic matmul
function LinearAlgebra.mul!(
CM::Exposed{<:JLArray},
AM::Exposed{<:JLArray},
BM::Exposed{
<:JLArray,
<:LinearAlgebra.Transpose{
<:Any,<:Base.ReshapedArray{<:Any,<:Any,<:LinearAlgebra.Adjoint}
},
},
α,
β,
)
mul!(CM, AM, expose(transpose(copy(expose(parent(BM))))), α, β)
return unexpose(CM)
end
24 changes: 24 additions & 0 deletions NDTensors/ext/NDTensorsJLArraysExt/permutedims.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using JLArrays: JLArray
using LinearAlgebra: Adjoint
using NDTensors.Expose: Exposed, expose, unexpose

function Base.permutedims!(
Edest::Exposed{<:JLArray,<:Base.ReshapedArray}, Esrc::Exposed{<:JLArray}, perm
)
Aperm = permutedims(Esrc, perm)
copyto!(expose(parent(Edest)), expose(Aperm))
return unexpose(Edest)
end

## Found an issue in CUDA where if Edest is a reshaped{<:Adjoint}
## .= can fail. So instead force Esrc into the shape of parent(Edest)
function Base.permutedims!(
Edest::Exposed{<:JLArray,<:Base.ReshapedArray{<:Any,<:Any,<:Adjoint}},
Esrc::Exposed{<:JLArray},
perm,
f,
)
Aperm = reshape(permutedims(Esrc, perm), size(parent(Edest)))
parent(Edest) .= f.(parent(Edest), Aperm)
return unexpose(Edest)
end
7 changes: 0 additions & 7 deletions NDTensors/ext/NDTensorsMetalExt/set_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,6 @@ using Metal: Metal, MtlArray
using NDTensors.TypeParameterAccessors: TypeParameterAccessors, Position
using NDTensors.GPUArraysCoreExtensions: storagemode

## TODO remove TypeParameterAccessors when SetParameters is removed
function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(eltype))
return Position(1)
end
function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(ndims))
return Position(2)
end
function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(storagemode))
return Position(3)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ end
position(type::Type{<:AbstractArray}, ::typeof(eltype)) = Position(1)
position(type::Type{<:AbstractArray}, ::typeof(ndims)) = Position(2)

default_type_parameters(::Type{<:AbstractArray}) = (Float64, 1)

for wrapper in [:PermutedDimsArray, :(Base.ReshapedArray), :SubArray]
@eval begin
position(type::Type{<:$wrapper}, ::typeof(eltype)) = Position(1)
Expand Down
2 changes: 0 additions & 2 deletions NDTensors/src/lib/TypeParameterAccessors/src/base/array.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
position(::Type{<:Array}, ::typeof(eltype)) = Position(1)
position(::Type{<:Array}, ::typeof(ndims)) = Position(2)

default_type_parameters(::Type{<:Array}) = (Float64, 1)
23 changes: 17 additions & 6 deletions NDTensors/test/NDTensorsTestUtils/device_list.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,40 @@
using NDTensors: NDTensors
using Pkg: Pkg
using NDTensors: NDTensors

if "cuda" in ARGS || "all" in ARGS
## Right now adding CUDA during Pkg.test results in a
## compat issues. I am adding it back to test/Project.toml
# Pkg.add("CUDA")
using CUDA
Pkg.add("CUDA")
using CUDA: CUDA
end
if "rocm" in ARGS || "all" in ARGS
## Warning AMDGPU does not work in Julia versions below 1.8
Pkg.add("AMDGPU")
using AMDGPU
using AMDGPU: AMDGPU
end
if "metal" in ARGS || "all" in ARGS
## Warning Metal does not work in Julia versions below 1.8
Pkg.add("Metal")
using Metal
using Metal: Metal
end
if "cutensor" in ARGS || "all" in ARGS
Pkg.add("CUDA")
Pkg.add("cuTENSOR")
using CUDA, cuTENSOR
using CUDA: CUDA
using cuTENSOR: cuTENSOR
end

using JLArrays: JLArrays, jl

function devices_list(test_args)
devs = Vector{Function}(undef, 0)
if isempty(test_args) || "base" in test_args
push!(devs, NDTensors.cpu)
## Skip jl on lower versions of Julia for now
## all linear algebra is failing on Julia 1.6 with JLArrays
if VERSION > v"1.7"
push!(devs, jl)
end
end

if "cuda" in test_args || "cutensor" in test_args || "all" in test_args
Expand All @@ -44,5 +54,6 @@ function devices_list(test_args)
if "metal" in test_args || "all" in test_args
push!(devs, NDTensors.MetalExtensions.mtl)
end

return devs
end
5 changes: 3 additions & 2 deletions NDTensors/test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Expand All @@ -24,10 +24,11 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Metal = "1.1.0"
cuTENSOR = "2.0"
Metal = "1.1.0"

[extras]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Loading

0 comments on commit 29cce1c

Please sign in to comment.