diff --git a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl index df65ed6155..a16da982f9 100644 --- a/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl +++ b/NDTensors/src/lib/GradedAxes/src/gradedunitrangedual.jl @@ -57,62 +57,41 @@ end function blockedunitrange_getindices( a::GradedUnitRangeDual, indices::Vector{<:BlockIndexRange{1}} ) - a_indices = getindex(nondual(a), indices) - v = mortar(dual.(blocks(a_indices))) - # flip v to stay consistent with other cases where axes(v) are used - return flip_blockvector(v) + # dual v axes to stay consistent with other cases where axes(v) are used + return dual_axes(blockedunitrange_getindices(nondual(a), indices)) end function blockedunitrange_getindices( a::GradedUnitRangeDual, indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexRange{1}}}, ) - v = mortar(map(b -> a[b], blocks(indices))) - # GradedOneTo appears in mortar - # flip v axis to preserve dual information + # dual v axis to preserve dual information # axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)[1:1]])) - return flip_blockvector(v) + return dual_axes(blockedunitrange_getindices(nondual(a), indices)) end function blockedunitrange_getindices( a::GradedUnitRangeDual, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}} ) - # Without converting `indices` to `Vector`, - # mapping `indices` outputs a `BlockVector` - # which is harder to reason about. - vblocks = map(index -> a[index], Vector(indices)) - # We pass `length.(blocks)` to `mortar` in order - # to pass block labels to the axes of the output, - # if they exist. This makes it so that - # `only(axes(a[indices])) isa `GradedUnitRange` - # if `a isa `GradedUnitRange`, for example. - - v = mortar(vblocks, length.(vblocks)) - # GradedOneTo appears in mortar - # flip v axis to preserve dual information + # dual v axis to preserve dual information # axes(v) will appear in axes(view(::BlockSparseArray, [Block(1)])) - return flip_blockvector(v) + return dual_axes(blockedunitrange_getindices(nondual(a), indices)) end # Fixes ambiguity error. -# TODO: Write this in terms of `blockedunitrange_getindices(dual(a), indices)`. function blockedunitrange_getindices( a::GradedUnitRangeDual, indices::AbstractBlockVector{<:Block{1}} ) - blks = map(bs -> mortar(map(b -> a[b], bs)), blocks(indices)) - # We pass `length.(blks)` to `mortar` in order - # to pass block labels to the axes of the output, - # if they exist. This makes it so that - # `only(axes(a[indices])) isa `GradedUnitRange` - # if `a isa `GradedUnitRange`, for example. - v = mortar(blks, labelled_length.(blks)) - return flip_blockvector(v) -end - -function flip_blockvector(v::BlockVector) - block_axes = flip.(axes(v)) - flipped = mortar(vec.(blocks(v)), block_axes) - return flipped + v = blockedunitrange_getindices(nondual(a), indices) + # v elements are not dualled by dual_axes due to different structure. + # take element dual here. + return dual_axes(dual.(v)) +end + +function dual_axes(v::BlockVector) + # dual both v elements and v axes + block_axes = dual.(axes(v)) + return mortar(dual.(blocks(v)), block_axes) end Base.axes(a::GradedUnitRangeDual) = axes(nondual(a)) diff --git a/NDTensors/src/lib/GradedAxes/test/test_dual.jl b/NDTensors/src/lib/GradedAxes/test/test_dual.jl index 98b8838542..f2b3072dc1 100644 --- a/NDTensors/src/lib/GradedAxes/test/test_dual.jl +++ b/NDTensors/src/lib/GradedAxes/test/test_dual.jl @@ -219,14 +219,35 @@ end @test label(ad[Block(2)]) == U1(-1) @test label(ad[Block(2)[1:1]]) == U1(-1) - I = mortar([Block(2)[1:1]]) - g = ad[I] - @test length(g) == 1 - @test label(first(g)) == U1(-1) - @test isdual(g[Block(1)]) + v = ad[[Block(2)[1:1]]] + @test v isa AbstractVector{LabelledInteger{Int64,U1}} + @test length(v) == 1 + @test label(first(v)) == U1(-1) + @test unlabel(first(v)) == 3 + @test isdual(v[Block(1)]) + @test isdual(axes(v, 1)) + @test blocklabels(axes(v, 1)) == [U1(-1)] - @test isdual(axes(ad[[Block(1)]], 1)) # used in view(::BlockSparseVector, [Block(1)]) - @test isdual(axes(ad[mortar([Block(1)[1:1]])], 1)) # used in view(::BlockSparseVector, [Block(1)[1:1]]) + v = ad[mortar([Block(2)[1:1]])] + @test v isa AbstractVector{LabelledInteger{Int64,U1}} + @test isdual(axes(v, 1)) # used in view(::BlockSparseVector, [Block(1)[1:1]]) + @test label(first(v)) == U1(-1) + @test unlabel(first(v)) == 3 + @test blocklabels(axes(v, 1)) == [U1(-1)] + + v = ad[[Block(2)]] + @test v isa AbstractVector{LabelledInteger{Int64,U1}} + @test isdual(axes(v, 1)) # used in view(::BlockSparseVector, [Block(1)]) + @test label(first(v)) == U1(-1) + @test unlabel(first(v)) == 3 + @test blocklabels(axes(v, 1)) == [U1(-1)] + + v = ad[mortar([[Block(2)], [Block(1)]])] + @test v isa AbstractVector{LabelledInteger{Int64,U1}} + @test isdual(axes(v, 1)) + @test label(first(v)) == U1(-1) + @test unlabel(first(v)) == 3 + @test blocklabels(axes(v, 1)) == [U1(-1), U1(0)] end end