From 0ef437db27dea41c50bf63f7e5ad3ee9b10689e7 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 5 Nov 2024 16:17:41 -0500 Subject: [PATCH] Add adapt support to BlockSparseArrays --- .../src/BlockSparseArraysAdaptExt.jl | 5 +++++ .../lib/BlockSparseArrays/src/BlockSparseArrays.jl | 2 ++ .../src/blocksparsearrayinterface/map.jl | 14 ++++++++++++++ 3 files changed, 21 insertions(+) create mode 100644 NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/map.jl diff --git a/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl new file mode 100644 index 0000000000..68cbf05e35 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl @@ -0,0 +1,5 @@ +module BlockSparseArraysAdaptExt +using Adapt: Adapt, adapt +using ..BlockSparseArrays: AbstractBlockSparseArray, map_stored_blocks +Adapt.adapt_structure(to, x::AbstractBlockSparseArray) = map_stored_blocks(adapt(to), x) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index 3542c3e10b..af59ae7f51 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -4,6 +4,7 @@ include("blocksparsearrayinterface/blocksparsearrayinterface.jl") include("blocksparsearrayinterface/linearalgebra.jl") include("blocksparsearrayinterface/blockzero.jl") include("blocksparsearrayinterface/broadcast.jl") +include("blocksparsearrayinterface/map.jl") include("blocksparsearrayinterface/arraylayouts.jl") include("blocksparsearrayinterface/views.jl") include("abstractblocksparsearray/abstractblocksparsearray.jl") @@ -20,4 +21,5 @@ include("blocksparsearray/blocksparsearray.jl") include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl") include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl") include("../ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl") +include("../ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl") end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/map.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/map.jl new file mode 100644 index 0000000000..7ecad61a41 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/map.jl @@ -0,0 +1,14 @@ +function map_stored_blocks(f, a::AbstractArray) + # TODO: Implement this as: + # ```julia + # mapped_blocks = SparseArraysInterface.map_stored(f, blocks(a)) + # BlockSparseArray(mapped_blocks, axes(a)) + # ``` + # TODO: `block_stored_indices` should output `Indices` storing + # the stored Blocks, not a `Dictionary` from cartesian indices + # to Blocks. + bs = block_stored_indices(a) + mapped_blocks = Dictionary(bs, map(b -> f(@view(a[b])), bs)) + # TODO: Use `similartype(typeof(a), eltype(eltype(mapped_blocks)))(...)`. + return BlockSparseArray(mapped_blocks, axes(a)) +end