Skip to content

Commit

Permalink
Add adapt support to BlockSparseArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Nov 5, 2024
1 parent 0fbac75 commit 0ef437d
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module BlockSparseArraysAdaptExt
using Adapt: Adapt, adapt
using ..BlockSparseArrays: AbstractBlockSparseArray, map_stored_blocks
Adapt.adapt_structure(to, x::AbstractBlockSparseArray) = map_stored_blocks(adapt(to), x)
end
2 changes: 2 additions & 0 deletions NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ include("blocksparsearrayinterface/blocksparsearrayinterface.jl")
include("blocksparsearrayinterface/linearalgebra.jl")
include("blocksparsearrayinterface/blockzero.jl")
include("blocksparsearrayinterface/broadcast.jl")
include("blocksparsearrayinterface/map.jl")
include("blocksparsearrayinterface/arraylayouts.jl")
include("blocksparsearrayinterface/views.jl")
include("abstractblocksparsearray/abstractblocksparsearray.jl")
Expand All @@ -20,4 +21,5 @@ include("blocksparsearray/blocksparsearray.jl")
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl")
include("../ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl")
include("../ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl")
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
function map_stored_blocks(f, a::AbstractArray)
# TODO: Implement this as:
# ```julia
# mapped_blocks = SparseArraysInterface.map_stored(f, blocks(a))
# BlockSparseArray(mapped_blocks, axes(a))
# ```
# TODO: `block_stored_indices` should output `Indices` storing
# the stored Blocks, not a `Dictionary` from cartesian indices
# to Blocks.
bs = block_stored_indices(a)
mapped_blocks = Dictionary(bs, map(b -> f(@view(a[b])), bs))
# TODO: Use `similartype(typeof(a), eltype(eltype(mapped_blocks)))(...)`.
return BlockSparseArray(mapped_blocks, axes(a))
end

0 comments on commit 0ef437d

Please sign in to comment.