Skip to content

Commit

Permalink
[BlockSparseArrays] Initial support for more general blocks, such as …
Browse files Browse the repository at this point in the history
…GPU blocks (#1560)
  • Loading branch information
mtfishman authored Nov 8, 2024
1 parent 4299ab4 commit d1547b4
Show file tree
Hide file tree
Showing 13 changed files with 257 additions and 120 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module BlockSparseArraysAdaptExt
using Adapt: Adapt, adapt
using ..BlockSparseArrays: AbstractBlockSparseArray, map_stored_blocks
Adapt.adapt_structure(to, x::AbstractBlockSparseArray) = map_stored_blocks(adapt(to), x)
end
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using ArrayLayouts: ArrayLayouts, MemoryLayout, sub_materialize
using BlockArrays:
BlockArrays,
AbstractBlockArray,
Expand Down Expand Up @@ -537,6 +538,20 @@ function SparseArrayInterface.nstored(a::BlockView)
return 0
end

## # Allow more fine-grained control:
## function ArrayLayouts.sub_materialize(layout, a::BlockView, ax)
## return blocks(a.array)[Int.(a.block)...]
## end
## function ArrayLayouts.sub_materialize(layout, a::BlockView)
## return sub_materialize(layout, a, axes(a))
## end
## function ArrayLayouts.sub_materialize(a::BlockView)
## return sub_materialize(MemoryLayout(a), a)
## end
function ArrayLayouts.sub_materialize(a::BlockView)
return blocks(a.array)[Int.(a.block)...]
end

function view!(a::AbstractArray{<:Any,N}, index::Block{N}) where {N}
return view!(a, Tuple(index)...)
end
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include("blocksparsearrayinterface/blocksparsearrayinterface.jl")
include("blocksparsearrayinterface/linearalgebra.jl")
include("blocksparsearrayinterface/blockzero.jl")
include("blocksparsearrayinterface/broadcast.jl")
include("blocksparsearrayinterface/map.jl")
include("blocksparsearrayinterface/arraylayouts.jl")
include("blocksparsearrayinterface/views.jl")
include("abstractblocksparsearray/abstractblocksparsearray.jl")
Expand All @@ -20,4 +21,5 @@ include("blocksparsearray/blocksparsearray.jl")
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl")
include("../ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl")
include("../ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl")
end
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@ abstract type AbstractBlockSparseArray{T,N} <: AbstractBlockArray{T,N} end

Base.axes(::AbstractBlockSparseArray) = error("Not implemented")

blockstype(::Type{<:AbstractBlockSparseArray}) = error("Not implemented")
# TODO: Add some logic to unwrapping wrapped arrays.
# TODO: Decide what a good default is.
blockstype(arraytype::Type{<:AbstractBlockSparseArray}) = SparseArrayDOK{AbstractArray}
function blockstype(arraytype::Type{<:AbstractBlockSparseArray{T}}) where {T}
return SparseArrayDOK{AbstractArray{T}}
end
function blockstype(arraytype::Type{<:AbstractBlockSparseArray{T,N}}) where {T,N}
return SparseArrayDOK{AbstractArray{T,N},N}
end

## # Specialized in order to fix ambiguity error with `BlockArrays`.
function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ArrayLayouts: ArrayLayouts, MemoryLayout, MulAdd
using BlockArrays: BlockLayout
using ..SparseArrayInterface: SparseLayout
using ..TypeParameterAccessors: similartype

function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike})
outer_layout = typeof(MemoryLayout(blockstype(arraytype)))
Expand All @@ -9,15 +10,22 @@ function ArrayLayouts.MemoryLayout(arraytype::Type{<:BlockSparseArrayLike})
end

function Base.similar(
::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout}}, elt::Type, axes
)
return similar(BlockSparseArray{elt}, axes)
mul::MulAdd{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout},<:Any,<:Any,A,B},
elt::Type,
axes,
) where {A,B}
# TODO: Check that this equals `similartype(blocktype(B), elt, axes)`,
# or maybe promote them?
output_blocktype = similartype(blocktype(A), elt, axes)
return similar(BlockSparseArray{elt,length(axes),output_blocktype}, axes)
end

# Materialize a SubArray view.
function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
# TODO: Make more generic for GPU.
a_dest = BlockSparseArray{eltype(a)}(axes)
# TODO: Define `blocktype`/`blockstype` for `SubArray` wrapping `BlockSparseArray`.
# TODO: Use `similar`?
blocktype_a = blocktype(parent(a))
a_dest = BlockSparseArray{eltype(a),length(axes),blocktype_a}(axes)
a_dest .= a
return a_dest
end
Expand All @@ -26,8 +34,7 @@ end
function ArrayLayouts.sub_materialize(
layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
)
# TODO: Make more generic for GPU.
a_dest = Array{eltype(a)}(undef, length.(axes))
a_dest = blocktype(a)(undef, length.(axes))
a_dest .= a
return a_dest
end
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using BlockArrays:
mortar,
unblock
using SplitApplyCombine: groupcount
using ..TypeParameterAccessors: similartype

const WrappedAbstractBlockSparseArray{T,N} = WrappedArray{
T,N,AbstractBlockSparseArray,AbstractBlockSparseArray{T,N}
Expand Down Expand Up @@ -187,28 +188,29 @@ function Base.similar(
return similar(arraytype, eltype(arraytype), axes)
end

function blocksparse_similar(a, elt::Type, axes::Tuple)
return BlockSparseArray{elt,length(axes),similartype(blocktype(a), elt, axes)}(
undef, axes
)
end

# Needed by `BlockArrays` matrix multiplication interface
# TODO: Define a `blocksparse_similar` function.
function Base.similar(
arraytype::Type{<:BlockSparseArrayLike},
elt::Type,
axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}},
)
# TODO: Make generic for GPU, maybe using `blocktype`.
# TODO: For non-block axes this should output `Array`.
return BlockSparseArray{elt}(undef, axes)
return blocksparse_similar(arraytype, elt, axes)
end

# TODO: Define a `blocksparse_similar` function.
function Base.similar(
a::BlockSparseArrayLike, elt::Type, axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}}
)
# TODO: Make generic for GPU, maybe using `blocktype`.
# TODO: For non-block axes this should output `Array`.
return BlockSparseArray{elt}(undef, axes)
return blocksparse_similar(a, elt, axes)
end

# TODO: Define a `blocksparse_similar` function.
# Fixes ambiguity error with `BlockArrays`.
function Base.similar(
a::BlockSparseArrayLike,
Expand All @@ -217,21 +219,16 @@ function Base.similar(
AbstractBlockedUnitRange{<:Integer},Vararg{AbstractBlockedUnitRange{<:Integer}}
},
)
# TODO: Make generic for GPU, maybe using `blocktype`.
# TODO: For non-block axes this should output `Array`.
return BlockSparseArray{elt}(undef, axes)
return blocksparse_similar(a, elt, axes)
end

# TODO: Define a `blocksparse_similar` function.
# Fixes ambiguity error with `OffsetArrays`.
function Base.similar(
a::BlockSparseArrayLike,
elt::Type,
axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
# TODO: Make generic for GPU, maybe using `blocktype`.
# TODO: For non-block axes this should output `Array`.
return BlockSparseArray{elt}(undef, axes)
return blocksparse_similar(a, elt, axes)
end

# Fixes ambiguity error with `BlockArrays`.
Expand All @@ -240,9 +237,7 @@ function Base.similar(
elt::Type,
axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}},
)
# TODO: Make generic for GPU, maybe using `blocktype`.
# TODO: For non-block axes this should output `Array`.
return BlockSparseArray{elt}(undef, axes)
return blocksparse_similar(a, elt, axes)
end

# Fixes ambiguity errors with BlockArrays.
Expand All @@ -255,15 +250,12 @@ function Base.similar(
Vararg{AbstractUnitRange{<:Integer}},
},
)
return BlockSparseArray{elt}(undef, axes)
return blocksparse_similar(a, elt, axes)
end

# TODO: Define a `blocksparse_similar` function.
# Fixes ambiguity error with `StaticArrays`.
function Base.similar(
a::BlockSparseArrayLike, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}}
)
# TODO: Make generic for GPU, maybe using `blocktype`.
# TODO: For non-block axes this should output `Array`.
return BlockSparseArray{elt}(undef, axes)
return blocksparse_similar(a, elt, axes)
end
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ function BlockSparseArray(
return BlockSparseArray(Dictionary(block_indices, block_data), axes)
end

function BlockSparseArray{T,N,A,Blocks}(
blocks::AbstractArray{<:AbstractArray{T,N},N}, axes::Tuple{Vararg{AbstractUnitRange,N}}
) where {T,N,A<:AbstractArray{T,N},Blocks<:AbstractArray{A,N}}
return BlockSparseArray{T,N,A,Blocks,typeof(axes)}(blocks, axes)
end

function BlockSparseArray{T,N,A}(
blocks::AbstractArray{<:AbstractArray{T,N},N}, axes::Tuple{Vararg{AbstractUnitRange,N}}
) where {T,N,A<:AbstractArray{T,N}}
return BlockSparseArray{T,N,A,typeof(blocks)}(blocks, axes)
end

function BlockSparseArray{T,N}(
blocks::AbstractArray{<:AbstractArray{T,N},N}, axes::Tuple{Vararg{AbstractUnitRange,N}}
) where {T,N}
Expand All @@ -49,9 +61,15 @@ function BlockSparseArray{T,N}(
return BlockSparseArray{T,N}(blocks, axes)
end

function BlockSparseArray{T,N,A}(
axes::Tuple{Vararg{AbstractUnitRange,N}}
) where {T,N,A<:AbstractArray{T,N}}
blocks = default_blocks(A, axes)
return BlockSparseArray{T,N,A}(blocks, axes)
end

function BlockSparseArray{T,N}(axes::Tuple{Vararg{AbstractUnitRange,N}}) where {T,N}
blocks = default_blocks(T, axes)
return BlockSparseArray{T,N}(blocks, axes)
return BlockSparseArray{T,N,default_arraytype(T, axes)}(axes)
end

function BlockSparseArray{T,N}(dims::Tuple{Vararg{Vector{Int},N}}) where {T,N}
Expand All @@ -74,6 +92,12 @@ function BlockSparseArray{T}(axes::Vararg{AbstractUnitRange}) where {T}
return BlockSparseArray{T}(axes)
end

function BlockSparseArray{T,N,A}(
::UndefInitializer, dims::Tuple
) where {T,N,A<:AbstractArray{T,N}}
return BlockSparseArray{T,N,A}(dims)
end

# undef
function BlockSparseArray{T,N}(
::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange,N}}
Expand Down Expand Up @@ -109,7 +133,23 @@ Base.axes(a::BlockSparseArray) = a.axes
blocksparse_blocks(a::BlockSparseArray) = a.blocks

# TODO: Use `TypeParameterAccessors`.
blockstype(::Type{<:BlockSparseArray{<:Any,<:Any,<:Any,B}}) where {B} = B
function blockstype(
arraytype::Type{<:BlockSparseArray{T,N,A,Blocks}}
) where {T,N,A<:AbstractArray{T,N},Blocks<:AbstractArray{A,N}}
return Blocks
end
function blockstype(
arraytype::Type{<:BlockSparseArray{T,N,A}}
) where {T,N,A<:AbstractArray{T,N}}
return SparseArrayDOK{A,N}
end
function blockstype(arraytype::Type{<:BlockSparseArray{T,N}}) where {T,N}
return SparseArrayDOK{AbstractArray{T,N},N}
end
function blockstype(arraytype::Type{<:BlockSparseArray{T}}) where {T}
return SparseArrayDOK{AbstractArray{T}}
end
blockstype(arraytype::Type{<:BlockSparseArray}) = SparseArrayDOK{AbstractArray}

## # Base interface
## function Base.similar(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ function default_arraytype(elt::Type, axes::Tuple{Vararg{AbstractUnitRange}})
return Array{elt,length(axes)}
end

function default_blocks(elt::Type, axes::Tuple{Vararg{AbstractUnitRange}})
block_data = Dictionary{Block{length(axes),Int},default_arraytype(elt, axes)}()
function default_blocks(blocktype::Type, axes::Tuple{Vararg{AbstractUnitRange}})
block_data = Dictionary{Block{length(axes),Int},blocktype}()
return default_blocks(block_data, axes)
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,13 @@ _getindices(i::CartesianIndex, indices) = CartesianIndex(_getindices(Tuple(i), i

# Represents the array of arrays of a `PermutedDimsArray`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `PermutedDimsArray`.
struct SparsePermutedDimsArrayBlocks{T,N,Array<:PermutedDimsArray{T,N}} <:
AbstractSparseArray{T,N}
struct SparsePermutedDimsArrayBlocks{
T,N,BlockType<:AbstractArray{T,N},Array<:PermutedDimsArray{T,N}
} <: AbstractSparseArray{BlockType,N}
array::Array
end
function blocksparse_blocks(a::PermutedDimsArray)
return SparsePermutedDimsArrayBlocks(a)
return SparsePermutedDimsArrayBlocks{eltype(a),ndims(a),blocktype(parent(a)),typeof(a)}(a)
end
function Base.size(a::SparsePermutedDimsArrayBlocks)
return _getindices(size(blocks(parent(a.array))), _perm(a.array))
Expand Down Expand Up @@ -158,11 +159,12 @@ reverse_index(index::CartesianIndex) = CartesianIndex(reverse(Tuple(index)))

# Represents the array of arrays of a `Transpose`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Transpose`.
struct SparseTransposeBlocks{T,Array<:Transpose{T}} <: AbstractSparseMatrix{T}
struct SparseTransposeBlocks{T,BlockType<:AbstractMatrix{T},Array<:Transpose{T}} <:
AbstractSparseMatrix{BlockType}
array::Array
end
function blocksparse_blocks(a::Transpose)
return SparseTransposeBlocks(a)
return SparseTransposeBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
end
function Base.size(a::SparseTransposeBlocks)
return reverse(size(blocks(parent(a.array))))
Expand Down Expand Up @@ -192,11 +194,12 @@ end

# Represents the array of arrays of a `Adjoint`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `Adjoint`.
struct SparseAdjointBlocks{T,Array<:Adjoint{T}} <: AbstractSparseMatrix{T}
struct SparseAdjointBlocks{T,BlockType<:AbstractMatrix{T},Array<:Adjoint{T}} <:
AbstractSparseMatrix{BlockType}
array::Array
end
function blocksparse_blocks(a::Adjoint)
return SparseAdjointBlocks(a)
return SparseAdjointBlocks{eltype(a),blocktype(parent(a)),typeof(a)}(a)
end
function Base.size(a::SparseAdjointBlocks)
return reverse(size(blocks(parent(a.array))))
Expand Down Expand Up @@ -230,9 +233,13 @@ end

# Represents the array of arrays of a `SubArray`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`.
struct SparseSubArrayBlocks{T,N,Array<:SubArray{T,N}} <: AbstractSparseArray{T,N}
struct SparseSubArrayBlocks{T,N,BlockType<:AbstractArray{T,N},Array<:SubArray{T,N}} <:
AbstractSparseArray{BlockType,N}
array::Array
end
function blocksparse_blocks(a::SubArray)
return SparseSubArrayBlocks{eltype(a),ndims(a),blocktype(parent(a)),typeof(a)}(a)
end
# TODO: Define this as `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`.
function blockrange(a::SparseSubArrayBlocks)
blockranges = blockrange.(axes(parent(a.array)), a.array.indices)
Expand Down Expand Up @@ -291,8 +298,10 @@ function SparseArrayInterface.sparse_storage(a::SparseSubArrayBlocks)
return map(I -> a[I], stored_indices(a))
end

function blocksparse_blocks(a::SubArray)
return SparseSubArrayBlocks(a)
function SparseArrayInterface.getindex_zero_function(a::SparseSubArrayBlocks)
# TODO: Base it off of `getindex_zero_function(blocks(parent(a.array))`, but replace the
# axes with `axes(a.array)`.
return BlockZero(axes(a.array))
end

to_blocks_indices(I::BlockSlice{<:BlockRange{1}}) = Int.(I.block)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ end
function (f::BlockZero)(arraytype::Type{<:AbstractArray}, I)
# TODO: Make sure this works for sparse or block sparse blocks, immutable
# blocks, diagonal blocks, etc.!
return fill!(arraytype(undef, block_size(f.axes, Block(Tuple(I)))), false)
blck_size = block_size(f.axes, Block(Tuple(I)))
blck_type = similartype(arraytype, blck_size)
return fill!(blck_type(undef, blck_size), false)
end

# Fallback so that `SparseArray` with scalar elements works.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
function map_stored_blocks(f, a::AbstractArray)
# TODO: Implement this as:
# ```julia
# mapped_blocks = SparseArraysInterface.map_stored(f, blocks(a))
# BlockSparseArray(mapped_blocks, axes(a))
# ```
# TODO: `block_stored_indices` should output `Indices` storing
# the stored Blocks, not a `Dictionary` from cartesian indices
# to Blocks.
bs = collect(block_stored_indices(a))
ds = map(b -> f(@view(a[b])), bs)
# We manually specify the block type using `Base.promote_op`
# since `a[b]` may not be inferrable. For example, if `blocktype(a)`
# is `Diagonal{Float64,Vector{Float64}}`, the non-stored blocks are `Matrix{Float64}`
# since they can't necessarily by `Diagonal` if there are rectangular blocks.
mapped_blocks = Dictionary{eltype(bs),eltype(ds)}(bs, ds)
# TODO: Use `similartype(typeof(a), eltype(eltype(mapped_blocks)))(...)`.
return BlockSparseArray(mapped_blocks, axes(a))
end
Loading

0 comments on commit d1547b4

Please sign in to comment.