Skip to content

Commit

Permalink
Move default_typeparameters to AbstractArray
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Sep 22, 2024
1 parent 638377c commit 4208487
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 13 deletions.
6 changes: 4 additions & 2 deletions NDTensors/ext/NDTensorsAMDGPUExt/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# TypeParameterAccessors definitions
using NDTensors.TypeParameterAccessors: TypeParameterAccessors, Position
using NDTensors.TypeParameterAccessors:
TypeParameterAccessors, Position, default_type_parameters
using NDTensors.GPUArraysCoreExtensions: storagemode
using AMDGPU: AMDGPU, ROCArray
using GPUArraysCore: AbstractGPUArray

function TypeParameterAccessors.default_type_parameters(::Type{<:ROCArray})
return (Float64, 1, AMDGPU.Mem.HIPBuffer)
return (default_type_parameters(AbstractGPUArray)..., AMDGPU.Mem.HIPBuffer)
end

TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(storagemode)) = Position(3)
6 changes: 4 additions & 2 deletions NDTensors/ext/NDTensorsCUDAExt/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# TypeParameterAccessors definitions
using CUDA: CUDA, CuArray
using NDTensors.TypeParameterAccessors: TypeParameterAccessors, Position
using NDTensors.TypeParameterAccessors:
TypeParameterAccessors, Position, default_type_parameters
using NDTensors.GPUArraysCoreExtensions: storagemode
using GPUArraysCore: AbstractGPUArray

function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(storagemode))
return Position(3)
end

function TypeParameterAccessors.default_type_parameters(::Type{<:CuArray})
return (Float64, 1, CUDA.Mem.DeviceBuffer)
return (default_type_parameters(AbstractGPUArray)..., CUDA.Mem.DeviceBuffer)
end
7 changes: 0 additions & 7 deletions NDTensors/ext/NDTensorsJLArraysExt/set_types.jl

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ end
position(type::Type{<:AbstractArray}, ::typeof(eltype)) = Position(1)
position(type::Type{<:AbstractArray}, ::typeof(ndims)) = Position(2)

default_type_parameters(::Type{<:AbstractArray}) = (Float64, 1)

for wrapper in [:PermutedDimsArray, :(Base.ReshapedArray), :SubArray]
@eval begin
position(type::Type{<:$wrapper}, ::typeof(eltype)) = Position(1)
Expand Down
2 changes: 0 additions & 2 deletions NDTensors/src/lib/TypeParameterAccessors/src/base/array.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
position(::Type{<:Array}, ::typeof(eltype)) = Position(1)
position(::Type{<:Array}, ::typeof(ndims)) = Position(2)

default_type_parameters(::Type{<:Array}) = (Float64, 1)

0 comments on commit 4208487

Please sign in to comment.