Skip to content

Commit

Permalink
[BlockSparseArrays] Redesign nested views (#1504)
Browse files Browse the repository at this point in the history
* [BlockSparseArrays] Redesign nested views

* [NDTensors] Bump to v0.3.36
  • Loading branch information
mtfishman authored Jun 25, 2024
1 parent d3afdb7 commit ab8a59e
Show file tree
Hide file tree
Showing 20 changed files with 279 additions and 324 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.35"
version = "0.3.36"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,29 +85,6 @@ function Base.axes(a::Adjoint{<:Any,<:AbstractBlockSparseMatrix})
return dual.(reverse(axes(a')))
end

# TODO: Delete this definition in favor of the one in
# GradedAxes once https://github.com/JuliaArrays/BlockArrays.jl/pull/405 is merged.
# TODO: Make a special definition for `BlockedVector{<:Block{1}}` in order
# to merge blocks.
function GradedAxes.blockedunitrange_getindices(
a::AbstractBlockedUnitRange, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
)
# Without converting `indices` to `Vector`,
# mapping `indices` outputs a `BlockVector`
# which is harder to reason about.
blocks = map(index -> a[index], Vector(indices))
# We pass `length.(blocks)` to `mortar` in order
# to pass block labels to the axes of the output,
# if they exist. This makes it so that
# `only(axes(a[indices])) isa `GradedUnitRange`
# if `a isa `GradedUnitRange`, for example.
# TODO: Remove `unlabel` once `BlockArray` axes
# type is generalized in BlockArrays.jl.
# TODO: Support using `BlockSparseVector`, need
# to make more `BlockSparseArray` constructors.
return BlockSparseArray(blocks, (blockedrange(length.(blocks)),))
end

# This definition is only needed since calls like
# `a[[Block(1), Block(2)]]` where `a isa AbstractGradedUnitRange`
# returns a `BlockSparseVector` instead of a `BlockVector`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test size(b) == (4, 4, 4, 4)
@test blocksize(b) == (2, 2, 2, 2)
@test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
# TODO: Fix this for `BlockedArray`.
@test_broken nstored(b) == 256
@test nstored(b) == 256
# TODO: Fix this for `BlockedArray`.
@test_broken block_nstored(b) == 16
for i in 1:ndims(a)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using BlockArrays:
AbstractBlockVector,
Block,
BlockRange,
BlockedOneTo,
BlockedUnitRange,
BlockVector,
BlockSlice,
Expand All @@ -19,19 +20,6 @@ using Dictionaries: Dictionary, Indices
using ..GradedAxes: blockedunitrange_getindices
using ..SparseArrayInterface: stored_indices

# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
struct GenericBlockSlice{B,T<:Integer,I<:AbstractUnitRange{T}} <: AbstractUnitRange{T}
block::B
indices::I
end
BlockArrays.Block(bs::GenericBlockSlice{<:Block}) = bs.block
for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe_length)
@eval Base.$f(S::GenericBlockSlice) = Base.$f(S.indices)
end
Base.getindex(S::GenericBlockSlice, i::Integer) = getindex(S.indices, i)

# BlockIndices works around an issue that the indices of BlockSlice
# are restricted to AbstractUnitRange{Int}.
struct BlockIndices{B,T<:Integer,I<:AbstractVector{T}} <: AbstractVector{T}
Expand All @@ -42,6 +30,63 @@ for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe
@eval Base.$f(S::BlockIndices) = Base.$f(S.indices)
end
Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i)
function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
# TODO: Check that `i.indices` is consistent with `S.indices`.
# It seems like this isn't handling the case where `i` is a
# subslice of a block correctly (i.e. it ignores `i.indices`).
@assert length(S.indices[Block(i)]) == length(i.indices)
return BlockSlice(S.blocks[Int(Block(i))], S.indices[Block(i)])
end
function Base.getindex(S::BlockIndices, i::BlockSlice{<:BlockRange{1}})
# TODO: Check that `i.indices` is consistent with `S.indices`.
# TODO: Turn this into a `blockedunitrange_getindices` definition.
subblocks = S.blocks[Int.(i.block)]
subindices = mortar(
map(1:length(i.block)) do I
r = blocks(i.indices)[I]
return S.indices[first(r)]:S.indices[last(r)]
end,
)
return BlockIndices(subblocks, subindices)
end

# TODO: This is type piracy. This is used in `reindex` when making
# views of blocks of sliced block arrays, for example:
# ```julia
# a = BlockSparseArray{elt}(undef, ([2, 3], [2, 3]))
# b = @view a[[Block(1)[1:1], Block(2)[1:2]], [Block(1)[1:1], Block(2)[1:2]]]
# b[Block(1, 1)]
# ```
# Without this change, BlockArrays has the slicing behavior:
# ```julia
# julia> mortar([Block(1)[1:1], Block(2)[1:2]])[BlockSlice(Block(2), 2:3)]
# 2-element Vector{BlockIndex{1, Tuple{Int64}, Tuple{Int64}}}:
# Block(2)[1]
# Block(2)[2]
# ```
# while with this change it has the slicing behavior:
# ```julia
# julia> mortar([Block(1)[1:1], Block(2)[1:2]])[BlockSlice(Block(2), 2:3)]
# Block(2)[1:2]
# ```
# i.e. it preserves the types of the blocks better. Upstream this fix to
# BlockArrays.jl. Also consider overloading `reindex` so that it calls
# a custom `getindex` function to avoid type piracy in the meantime.
# Also fix this in BlockArrays:
# ```julia
# julia> mortar([Block(1)[1:1], Block(2)[1:2]])[Block(2)]
# 2-element Vector{BlockIndex{1, Tuple{Int64}, Tuple{Int64}}}:
# Block(2)[1]
# Block(2)[2]
# ```
function Base.getindex(
a::BlockVector{<:BlockIndex{1},<:AbstractVector{<:BlockIndexRange{1}}},
I::BlockSlice{<:Block{1}},
)
# Check that the block slice corresponds to the correct block.
@assert I.indices == only(axes(a))[Block(I)]
return blocks(a)[Int(Block(I))]
end

# Outputs a `BlockUnitRange`.
function sub_axis(a::AbstractUnitRange, indices)
Expand Down Expand Up @@ -185,15 +230,12 @@ function blockrange(axis::AbstractUnitRange, r::AbstractVector{<:Block{1}})
return r
end

using BlockArrays: BlockSlice
function blockrange(axis::AbstractUnitRange, r::BlockSlice)
return blockrange(axis, r.block)
function blockrange(axis::BlockedOneTo{<:Integer}, r::BlockVector{<:Integer})
return error("Slicing not implemented for range of type `$(typeof(r))`.")
end

# GenericBlockSlice works around an issue that the indices of BlockSlice
# are restricted to Int element type.
# TODO: Raise an issue/make a pull request in BlockArrays.jl.
function blockrange(axis::AbstractUnitRange, r::GenericBlockSlice)
using BlockArrays: BlockSlice
function blockrange(axis::AbstractUnitRange, r::BlockSlice)
return blockrange(axis, r.block)
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ include("abstractblocksparsearray/abstractblocksparsearray.jl")
include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl")
include("abstractblocksparsearray/abstractblocksparsematrix.jl")
include("abstractblocksparsearray/abstractblocksparsevector.jl")
include("abstractblocksparsearray/view.jl")
include("abstractblocksparsearray/views.jl")
include("abstractblocksparsearray/arraylayouts.jl")
include("abstractblocksparsearray/sparsearrayinterface.jl")
include("abstractblocksparsearray/broadcast.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function Broadcast.BroadcastStyle(
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{BlockSlice{<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
<:Tuple{BlockSlice{<:Any,<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
},
},
)
Expand All @@ -25,8 +25,8 @@ function Broadcast.BroadcastStyle(
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{
BlockSlice{<:Any,<:AbstractBlockedUnitRange},
BlockSlice{<:Any,<:AbstractBlockedUnitRange},
BlockSlice{<:Any,<:Any,<:AbstractBlockedUnitRange},
BlockSlice{<:Any,<:Any,<:AbstractBlockedUnitRange},
Vararg{Any},
},
},
Expand All @@ -40,7 +40,7 @@ function Broadcast.BroadcastStyle(
<:Any,
<:Any,
<:AbstractBlockSparseArray,
<:Tuple{Any,BlockSlice{<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
<:Tuple{Any,BlockSlice{<:Any,<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
},
},
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ end

# TODO: Why isn't this calling `mapreduce` already?
function Base.iszero(a::BlockSparseArrayLike)
return sparse_iszero(a)
return sparse_iszero(blocks(a))
end

# TODO: Why isn't this calling `mapreduce` already?
function Base.isreal(a::BlockSparseArrayLike)
return sparse_isreal(a)
return sparse_isreal(blocks(a))
end
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@ end
function SparseArrayInterface.sparse_storage(a::AbstractBlockSparseArray)
return BlockSparseStorage(a)
end

function SparseArrayInterface.nstored(a::BlockSparseArrayLike)
return sum(nstored, sparse_storage(blocks(a)); init=zero(Int))
end

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using BlockArrays: Block, BlockSlices

function blocksparse_view(a, I...)
return Base.invoke(view, Tuple{AbstractArray,Vararg{Any}}, a, I...)
end

# These definitions circumvent some generic definitions in BlockArrays.jl:
# https://github.com/JuliaArrays/BlockArrays.jl/blob/master/src/views.jl
# which don't handle subslices of blocks properly.
function Base.view(
a::SubArray{<:Any,N,<:BlockSparseArrayLike,<:NTuple{N,BlockSlices}}, I::Block{N}
) where {N}
return blocksparse_view(a, I)
end
function Base.view(
a::SubArray{<:Any,N,<:BlockSparseArrayLike,<:NTuple{N,BlockSlices}}, I::Vararg{Block{1},N}
) where {N}
return blocksparse_view(a, I...)
end
function Base.view(
V::SubArray{<:Any,1,<:BlockSparseArrayLike,<:Tuple{BlockSlices}}, I::Block{1}
)
return blocksparse_view(a, I)
end
Loading

2 comments on commit ab8a59e

@mtfishman
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir=NDTensors

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/109769

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a NDTensors-v0.3.36 -m "<description of version>" ab8a59e20e4f8ac009a8234aced2e805b9f253fd
git push origin NDTensors-v0.3.36

Please sign in to comment.