diff --git a/Project.toml b/Project.toml index e370422..e1bf1e4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,13 @@ name = "NamedDimsArrays" uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" authors = ["ITensor developers and contributors"] -version = "0.3.5" +version = "0.3.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -25,8 +26,9 @@ Adapt = "4.1.1" ArrayLayouts = "1.11.0" BlockArrays = "1.3.0" DerivableInterfaces = "0.3.7" +FillArrays = "1.13.0" LinearAlgebra = "1.10" -MapBroadcast = "0.1.5" +MapBroadcast = "0.1.6" Random = "1.10" SimpleTraits = "0.9.4" TensorAlgebra = "0.1" diff --git a/ext/NamedDimsArraysBlockArraysExt/NamedDimsArraysBlockArraysExt.jl b/ext/NamedDimsArraysBlockArraysExt/NamedDimsArraysBlockArraysExt.jl index 05d7d56..03234ed 100644 --- a/ext/NamedDimsArraysBlockArraysExt/NamedDimsArraysBlockArraysExt.jl +++ b/ext/NamedDimsArraysBlockArraysExt/NamedDimsArraysBlockArraysExt.jl @@ -2,37 +2,33 @@ module NamedDimsArraysBlockArraysExt using ArrayLayouts: ArrayLayouts using BlockArrays: Block, BlockRange using NamedDimsArrays: - AbstractNamedDimsArray, - AbstractNamedUnitRange, - named_getindex, - nameddims_getindex, - nameddims_view + AbstractNamedDimsArray, AbstractNamedUnitRange, getindex_named, view_nameddims function Base.getindex(r::AbstractNamedUnitRange{<:Integer}, I::Block{1}) # TODO: Use `Derive.@interface NamedArrayInterface() r[I]` instead. - return named_getindex(r, I) + return getindex_named(r, I) end function Base.getindex(r::AbstractNamedUnitRange{<:Integer}, I::BlockRange{1}) # TODO: Use `Derive.@interface NamedArrayInterface() r[I]` instead. - return named_getindex(r, I) + return getindex_named(r, I) end const BlockIndex{N} = Union{Block{N},BlockRange{N},AbstractVector{<:Block{N}}} function Base.view(a::AbstractNamedDimsArray, I1::Block{1}, Irest::BlockIndex{1}...) # TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead. - return nameddims_view(a, I1, Irest...) + return view_nameddims(a, I1, Irest...) end function Base.view(a::AbstractNamedDimsArray, I::Block) # TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead. - return nameddims_view(a, Tuple(I)...) + return view_nameddims(a, Tuple(I)...) end function Base.view(a::AbstractNamedDimsArray, I1::BlockIndex{1}, Irest::BlockIndex{1}...) # TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead. - return nameddims_view(a, I1, Irest...) + return view_nameddims(a, I1, Irest...) end # Fix ambiguity error. diff --git a/src/abstractnamedarray.jl b/src/abstractnamedarray.jl index 7e14d24..1f976ff 100644 --- a/src/abstractnamedarray.jl +++ b/src/abstractnamedarray.jl @@ -38,17 +38,17 @@ function Base.hash(a::AbstractNamedArray, h::UInt) return hash(name(a), h) end -named_getindex(a::AbstractArray, I...) = named(getindex(dename(a), I...), name(a)) +getindex_named(a::AbstractArray, I...) = named(getindex(dename(a), I...), name(a)) # Array funcionality. Base.size(a::AbstractNamedArray) = map(s -> named(s, name(a)), size(dename(a))) Base.axes(a::AbstractNamedArray) = map(s -> named(s, name(a)), axes(dename(a))) Base.eachindex(a::AbstractNamedArray) = eachindex(dename(a)) function Base.getindex(a::AbstractNamedArray{<:Any,N}, I::Vararg{Int,N}) where {N} - return named_getindex(a, I...) + return getindex_named(a, I...) end function Base.getindex(a::AbstractNamedArray, I::Int) - return named_getindex(a, I) + return getindex_named(a, I) end Base.isempty(a::AbstractNamedArray) = isempty(dename(a)) diff --git a/src/abstractnameddimsarray.jl b/src/abstractnameddimsarray.jl index 9b1538e..6430a47 100644 --- a/src/abstractnameddimsarray.jl +++ b/src/abstractnameddimsarray.jl @@ -140,7 +140,7 @@ function checked_indexin(x::AbstractUnitRange, y::AbstractUnitRange) end function Base.copy(a::AbstractNamedDimsArray) - return nameddimsarraytype(a)(copy(dename(a)), nameddimsindices(a)) + return constructorof(typeof(a))(copy(dename(a)), nameddimsindices(a)) end const NamedDimsIndices = Union{ @@ -181,9 +181,11 @@ Base.values(s::NaiveOrderedSet) = s.values Base.Tuple(s::NaiveOrderedSet) = Tuple(values(s)) Base.length(s::NaiveOrderedSet) = length(values(s)) Base.axes(s::NaiveOrderedSet) = axes(values(s)) +Base.keys(s::NaiveOrderedSet) = Base.OneTo(length(s)) Base.:(==)(s1::NaiveOrderedSet, s2::NaiveOrderedSet) = issetequal(values(s1), values(s2)) Base.iterate(s::NaiveOrderedSet, args...) = iterate(values(s), args...) Base.getindex(s::NaiveOrderedSet, I::Int) = values(s)[I] +Base.get(s::NaiveOrderedSet, I::Integer, default) = get(values(s), I, default) Base.invperm(s::NaiveOrderedSet) = NaiveOrderedSet(invperm(values(s))) Base.Broadcast._axes(::Broadcasted, axes::NaiveOrderedSet) = axes Base.Broadcast.BroadcastStyle(::Type{<:NaiveOrderedSet}) = Style{NaiveOrderedSet}() @@ -210,6 +212,10 @@ function Base.size(a::AbstractNamedDimsArray) return NaiveOrderedSet(map(named, size(dename(a)), nameddimsindices(a))) end +function Base.length(a::AbstractNamedDimsArray) + return prod(size(a); init=1) +end + # Circumvent issue when ndims isn't known at compile time. function Base.axes(a::AbstractNamedDimsArray, d) return d <= ndims(a) ? axes(a)[d] : OneTo(1) @@ -233,17 +239,20 @@ to_nameddimsaxes(dims) = map(to_nameddimsaxis, dims) to_nameddimsaxis(ax::NamedDimsAxis) = ax to_nameddimsaxis(I::NamedDimsIndices) = named(dename(only(axes(I))), I) -nameddimsarraytype(a::AbstractNamedDimsArray) = nameddimsarraytype(typeof(a)) -nameddimsarraytype(a::Type{<:AbstractNamedDimsArray}) = unspecify_type_parameters(a) +# Interface inspired by [ConstructionBase.constructorof](https://github.com/JuliaObjects/ConstructionBase.jl). +constructorof(type::Type{<:AbstractArray}) = unspecify_type_parameters(type) + +constructorof_nameddims(type::Type{<:AbstractNamedDimsArray}) = constructorof(type) +constructorof_nameddims(type::Type{<:AbstractArray}) = NamedDimsArray function similar_nameddims(a::AbstractNamedDimsArray, elt::Type, inds) ax = to_nameddimsaxes(inds) - return nameddimsarraytype(a)(similar(dename(a), elt, dename.(Tuple(ax))), name.(ax)) + return constructorof(typeof(a))(similar(dename(a), elt, dename.(Tuple(ax))), name.(ax)) end function similar_nameddims(a::AbstractArray, elt::Type, inds) ax = to_nameddimsaxes(inds) - return nameddims(similar(a, elt, dename.(Tuple(ax))), name.(ax)) + return constructorof_nameddims(typeof(a))(similar(a, elt, dename.(Tuple(ax))), name.(ax)) end # Base.similar gets the eltype at compile time. @@ -262,7 +271,7 @@ function Base.similar(a::AbstractArray, elt::Type, inds::NaiveOrderedSet) end function setnameddimsindices(a::AbstractNamedDimsArray, nameddimsindices) - return nameddimsarraytype(a)(dename(a), nameddimsindices) + return constructorof(typeof(a))(dename(a), nameddimsindices) end function replacenameddimsindices(f, a::AbstractNamedDimsArray) return setnameddimsindices(a, replace(f, nameddimsindices(a))) @@ -419,10 +428,18 @@ function Base.setindex!(a::AbstractNamedDimsArray, value, I::CartesianIndex) setindex!(a, value, to_indices(a, (I,))...) return a end + +function flatten_namedinteger(i::AbstractNamedInteger) + if name(i) isa Union{AbstractNamedUnitRange,AbstractNamedArray} + return name(i)[dename(i)] + end + return i +end + function Base.setindex!( a::AbstractNamedDimsArray, value, I1::AbstractNamedInteger, Irest::AbstractNamedInteger... ) - I = (I1, Irest...) + I = flatten_namedinteger.((I1, Irest...)) # TODO: Check if this permuation should be inverted. perm = getperm(name.(nameddimsindices(a)), name.(I)) # TODO: Throw a `NameMismatch` error. @@ -510,7 +527,9 @@ function Base.view(a::AbstractNamedDimsArray, I1::NamedViewIndex, Irest::NamedVi subinds = map(nameddimsindices(a), I) do dimname, i return checked_indexin(dename(i), dename(dimname)) end - return nameddims(view(dename(a), subinds...), sub_nameddimsindices) + return constructorof_nameddims(typeof(a))( + view(dename(a), subinds...), sub_nameddimsindices + ) end function Base.getindex( @@ -522,22 +541,22 @@ end # Repeated definition of `Base.ViewIndex`. const ViewIndex = Union{Real,AbstractArray} -function nameddims_view(a::AbstractArray, I...) +function view_nameddims(a::AbstractArray, I...) sub_dims = filter(dim -> !(I[dim] isa Real), ntuple(identity, ndims(a))) sub_nameddimsindices = map(dim -> nameddimsindices(a, dim)[I[dim]], sub_dims) - return nameddims(view(dename(a), I...), sub_nameddimsindices) + return constructorof(typeof(a))(view(dename(a), I...), sub_nameddimsindices) end function Base.view(a::AbstractNamedDimsArray, I::ViewIndex...) - return nameddims_view(a, I...) + return view_nameddims(a, I...) end -function nameddims_getindex(a::AbstractArray, I...) +function getindex_nameddims(a::AbstractArray, I...) return copy(view(a, I...)) end function Base.getindex(a::AbstractNamedDimsArray, I::ViewIndex...) - return nameddims_getindex(a, I...) + return getindex_nameddims(a, I...) end function Base.setindex!( @@ -556,7 +575,7 @@ function Base.setindex!( Irest::NamedViewIndex..., ) I = (I1, Irest...) - setindex!(a, nameddimsarraytype(a)(value, I), I...) + setindex!(a, constructorof(typeof(a))(value, I), I...) return a end function Base.setindex!( @@ -580,13 +599,13 @@ end function aligndims(a::AbstractArray, dims) new_nameddimsindices = to_nameddimsindices(a, dims) # TODO: Check this permutation is correct (it may be the inverse of what we want). - perm = getperm(nameddimsindices(a), new_nameddimsindices) + perm = Tuple(getperm(nameddimsindices(a), new_nameddimsindices)) isperm(perm) || throw( NameMismatch( "Dimension name mismatch $(nameddimsindices(a)), $(new_nameddimsindices)." ), ) - return nameddimsarraytype(a)(permutedims(dename(a), perm), new_nameddimsindices) + return constructorof(typeof(a))(permutedims(dename(a), perm), new_nameddimsindices) end function aligneddims(a::AbstractArray, dims) @@ -598,7 +617,9 @@ function aligneddims(a::AbstractArray, dims) "Dimension name mismatch $(nameddimsindices(a)), $(new_nameddimsindices)." ), ) - return nameddimsarraytype(a)(PermutedDimsArray(dename(a), perm), new_nameddimsindices) + return constructorof_nameddims(typeof(a))( + PermutedDimsArray(dename(a), perm), new_nameddimsindices + ) end # Convenient constructors @@ -711,16 +732,17 @@ using Base.Broadcast: broadcasted, check_broadcast_shape, combine_axes -using MapBroadcast: Mapped, mapped +using MapBroadcast: MapBroadcast, Mapped, mapped, tile abstract type AbstractNamedDimsArrayStyle{N} <: AbstractArrayStyle{N} end -struct NamedDimsArrayStyle{N} <: AbstractNamedDimsArrayStyle{N} end -NamedDimsArrayStyle(::Val{N}) where {N} = NamedDimsArrayStyle{N}() -NamedDimsArrayStyle{M}(::Val{N}) where {M,N} = NamedDimsArrayStyle{N}() +struct NamedDimsArrayStyle{N,NDA} <: AbstractNamedDimsArrayStyle{N} end +NamedDimsArrayStyle(::Val{N}) where {N} = NamedDimsArrayStyle{N,NamedDimsArray}() +NamedDimsArrayStyle{M}(::Val{N}) where {M,N} = NamedDimsArrayStyle{N,NamedDimsArray}() +NamedDimsArrayStyle{M,NDA}(::Val{N}) where {M,N,NDA} = NamedDimsArrayStyle{N,NDA}() function Broadcast.BroadcastStyle(arraytype::Type{<:AbstractNamedDimsArray}) - return NamedDimsArrayStyle{ndims(arraytype)}() + return NamedDimsArrayStyle{ndims(arraytype),constructorof(arraytype)}() end function Broadcast.combine_axes( @@ -762,6 +784,24 @@ function set_promote_shape( return named.(ax_promoted, name.(ax1)) end +# Handle operations like `ITensor() + ITensor(i, j)`. +# TODO: Decide if this should be a general definition for `AbstractNamedDimsArray`, +# or just for `AbstractITensor`. +function set_promote_shape( + ax1::Tuple{}, ax2::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange}} +) + return ax2 +end + +# Handle operations like `ITensor(i, j) + ITensor()`. +# TODO: Decide if this should be a general definition for `AbstractNamedDimsArray`, +# or just for `AbstractITensor`. +function set_promote_shape( + ax1::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange}}, ax2::Tuple{} +) + return ax1 +end + function Broadcast.check_broadcast_shape(ax1::NaiveOrderedSet, ax2::NaiveOrderedSet) return set_check_broadcast_shape(Tuple(ax1), Tuple(ax2)) end @@ -775,6 +815,7 @@ function set_check_broadcast_shape( check_broadcast_shape(dename.(ax1), dename.(ax2_aligned)) return nothing end +set_check_broadcast_shape(ax1::Tuple{}, ax2::Tuple{}) = nothing # Dename and lazily permute the arguments using the reference # dimension names. @@ -783,6 +824,20 @@ function denamed(m::Mapped, nameddimsindices) return mapped(m.f, map(arg -> denamed(arg, nameddimsindices), m.args)...) end +function nameddimsarraytype(style::NamedDimsArrayStyle{<:Any,NDA}) where {NDA} + return NDA +end + +using FillArrays: Fill + +function MapBroadcast.tile(a::AbstractNamedDimsArray, ax) + axes(a) == ax && return a + if iszero(ndims(a)) + return constructorof(typeof(a))(Fill(a[], dename.(Tuple(ax))), name.(ax)) + end + return error("Not implemented.") +end + function Base.similar(bc::Broadcasted{<:AbstractNamedDimsArrayStyle}, elt::Type, ax) nameddimsindices = name.(ax) m′ = denamed(Mapped(bc), nameddimsindices) @@ -790,12 +845,12 @@ function Base.similar(bc::Broadcasted{<:AbstractNamedDimsArrayStyle}, elt::Type, # wrapper type rather than the generic `nameddims` constructor, which # can lose information. # Call it as `nameddimsarraytype(bc.style)`. - return nameddims(similar(m′, elt, dename.(Tuple(ax))), nameddimsindices) + return nameddimsarraytype(bc.style)( + similar(m′, elt, dename.(Tuple(ax))), nameddimsindices + ) end -function Base.copyto!( - dest::AbstractArray{<:Any,N}, bc::Broadcasted{<:AbstractNamedDimsArrayStyle{N}} -) where {N} +function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractNamedDimsArrayStyle}) return copyto!(dest, Mapped(bc)) end diff --git a/src/abstractnamedinteger.jl b/src/abstractnamedinteger.jl index 84e4e79..ed4a4dd 100644 --- a/src/abstractnamedinteger.jl +++ b/src/abstractnamedinteger.jl @@ -57,7 +57,9 @@ struct FusedNames{Names} <: AbstractName names::Names end fusednames(name1, name2) = FusedNames((name1, name2)) -fusednames(name1::FusedNames, name2::FusedNames) = FusedNames(generic_vcat(name1, name2)) +function fusednames(name1::FusedNames, name2::FusedNames) + return FusedNames(generic_vcat(name1.names, name2.names)) +end fusednames(name1, name2::FusedNames) = fusednames(FusedNames((name1,)), name2) fusednames(name1::FusedNames, name2) = fusednames(name1, FusedNames((name2,))) @@ -86,6 +88,8 @@ Base.:-(i::AbstractNamedInteger) = setvalue(i, -dename(i)) # TODO: See if we can delete this. Base.:+(i1::Int, i2::AbstractNamedInteger) = i1 + dename(i2) +Base.:*(i1::Int, i2::AbstractNamedInteger) = named(i1 * dename(i2), name(i2)) + Base.zero(i::AbstractNamedInteger) = setvalue(i, zero(dename(i))) Base.one(i::AbstractNamedInteger) = setvalue(i, one(dename(i))) Base.signbit(i::AbstractNamedInteger) = signbit(dename(i)) diff --git a/src/abstractnamedunitrange.jl b/src/abstractnamedunitrange.jl index 4f40e36..519075a 100644 --- a/src/abstractnamedunitrange.jl +++ b/src/abstractnamedunitrange.jl @@ -16,7 +16,7 @@ named(r::AbstractUnitRange, name) = namedunitrange(r, name) # Derived interface. # TODO: Use `Accessors.@set`? -setname(r::AbstractNamedUnitRange, name) = namedunitrange(dename(r), name) +setname(r::AbstractNamedUnitRange, name) = named(dename(r), name) # TODO: Use `TypeParameterAccessors`. denametype(::Type{<:AbstractNamedUnitRange{<:Any,Value}}) where {Value} = Value @@ -43,17 +43,17 @@ Base.length(r::AbstractNamedUnitRange) = named(length(dename(r)), name(r)) Base.size(r::AbstractNamedUnitRange) = (named(length(dename(r)), name(r)),) Base.axes(r::AbstractNamedUnitRange) = (named(only(axes(dename(r))), name(r)),) Base.step(r::AbstractNamedUnitRange) = named(step(dename(r)), name(r)) -Base.getindex(r::AbstractNamedUnitRange, I::Int) = named_getindex(r, I) +Base.getindex(r::AbstractNamedUnitRange, I::Int) = getindex_named(r, I) # Fix ambiguity error. function Base.getindex(r::AbstractNamedUnitRange, I::AbstractUnitRange{<:Integer}) - return named_getindex(r, I) + return getindex_named(r, I) end # Fix ambiguity error. function Base.getindex(r::AbstractNamedUnitRange, I::Colon) - return named_getindex(r, I) + return getindex_named(r, I) end function Base.getindex(r::AbstractNamedUnitRange, I) - return named_getindex(r, I) + return getindex_named(r, I) end Base.isempty(r::AbstractNamedUnitRange) = isempty(dename(r))