Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TypeParameterAccessors] similartype #1561

Merged
merged 3 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.49"
version = "0.3.50"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
13 changes: 2 additions & 11 deletions NDTensors/src/abstractarray/set_types.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using .TypeParameterAccessors: TypeParameterAccessors, set_ndims
using .TypeParameterAccessors: TypeParameterAccessors

"""
# Do we still want to define things like this?
TODO: Use `Accessors.jl` notation:
Expand All @@ -14,13 +15,3 @@ TODO: Use `Accessors.jl` notation:
function TypeParameterAccessors.set_ndims(numbertype::Type{<:Number}, ndims)
return numbertype
end

"""
`set_indstype` should be overloaded for
types with structured dimensions,
like `OffsetArrays` or named indices
(such as ITensors).
"""
function set_indstype(arraytype::Type{<:AbstractArray}, dims::Tuple)
return set_ndims(arraytype, length(dims))
end
58 changes: 2 additions & 56 deletions NDTensors/src/abstractarray/similar.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using .TypeParameterAccessors: IsWrappedArray, unwrap_array_type, set_eltype
using Base: DimOrInd, Dims, OneTo
using .TypeParameterAccessors: IsWrappedArray, unwrap_array_type, set_eltype, similartype

## Custom `NDTensors.similar` implementation.
## More extensive than `Base.similar`.
Expand Down Expand Up @@ -96,58 +97,3 @@ end
# Use the `size` to determine the dimensions
# NDTensors.similar
similar(array::AbstractArray) = NDTensors.similar(typeof(array), size(array))

## similartype

function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, dims::Tuple)
return similartype(similartype(arraytype, eltype), dims)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, eltype::Type
) where {ArrayT; !IsWrappedArray{ArrayT}}
return set_eltype(arraytype, eltype)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, dims::Tuple
) where {ArrayT; !IsWrappedArray{ArrayT}}
return set_indstype(arraytype, dims)
end

function similartype(arraytype::Type{<:AbstractArray}, dims::DimOrInd...)
return similartype(arraytype, dims)
end

function similartype(arraytype::Type{<:AbstractArray})
return similartype(arraytype, eltype(arraytype))
end

## Wrapped arrays
@traitfn function similartype(
arraytype::Type{ArrayT}, eltype::Type
) where {ArrayT; IsWrappedArray{ArrayT}}
return similartype(unwrap_array_type(arraytype), eltype)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, dims::Tuple
) where {ArrayT; IsWrappedArray{ArrayT}}
return similartype(unwrap_array_type(arraytype), dims)
end

# This is for uniform `Diag` storage which uses
# a Number as the data type.
# TODO: Delete this when we change to using a
# `FillArray` instead. This is a stand-in
# to make things work with the current design.
function similartype(numbertype::Type{<:Number})
return numbertype
end

# Instances
function similartype(array::AbstractArray, eltype::Type, dims...)
return similartype(typeof(array), eltype, dims...)
end
similartype(array::AbstractArray, eltype::Type) = similartype(typeof(array), eltype)
similartype(array::AbstractArray, dims...) = similartype(typeof(array), dims...)
6 changes: 5 additions & 1 deletion NDTensors/src/blocksparse/blockdims.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using .TypeParameterAccessors: TypeParameterAccessors

"""
BlockDim

Expand All @@ -18,7 +20,9 @@ const BlockDims{N} = NTuple{N,BlockDim}

Base.ndims(ds::Type{<:BlockDims{N}}) where {N} = N

similartype(::Type{<:BlockDims}, ::Type{Val{N}}) where {N} = BlockDims{N}
function TypeParameterAccessors.similartype(::Type{<:BlockDims}, ::Type{Val{N}}) where {N}
return BlockDims{N}
end

Base.copy(ds::BlockDims) = ds

Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/blocksparse/blocksparsetensor.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using .TypeParameterAccessors: similartype

#
# BlockSparseTensor (Tensor using BlockSparse storage)
#
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/blocksparse/diagblocksparse.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using .TypeParameterAccessors: similartype

export DiagBlockSparse, DiagBlockSparseTensor

# DiagBlockSparse can have either Vector storage, in which case
Expand Down
2 changes: 2 additions & 0 deletions NDTensors/src/blocksparse/similar.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using .TypeParameterAccessors: similartype

# NDTensors.similar
function similar(storagetype::Type{<:BlockSparse}, blockoffsets::BlockOffsets, dims::Tuple)
data = similar(datatype(storagetype), nnz(blockoffsets, dims))
Expand Down
4 changes: 3 additions & 1 deletion NDTensors/src/diag/similar.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
using NDTensors.TypeParameterAccessors: TypeParameterAccessors

# NDTensors.similar
function similar(storagetype::Type{<:Diag}, dims::Dims)
return setdata(storagetype, similar(datatype(storagetype), mindim(dims)))
end

# TODO: Redesign UniformDiag to make it handled better
# by generic code.
function similartype(storagetype::Type{<:UniformDiag}, eltype::Type)
function TypeParameterAccessors.similartype(storagetype::Type{<:UniformDiag}, eltype::Type)
# This will also set the `datatype`.
return set_eltype(storagetype, eltype)
end
Expand Down
3 changes: 2 additions & 1 deletion NDTensors/src/dims.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using .DiagonalArrays: DiagonalArrays
using .TypeParameterAccessors: TypeParameterAccessors

export dense, dims, dim, mindim, diaglength

Expand Down Expand Up @@ -52,7 +53,7 @@ dim_to_stride(ds, k::Int) = dim_to_strides(ds)[k]
# code (it helps to construct a Tuple(::NTuple{N,Int}) where the
# only known thing for dispatch is a concrete type such
# as Dims{4})
similartype(::Type{<:Dims}, ::Type{Val{N}}) where {N} = Dims{N}
TypeParameterAccessors.similartype(::Type{<:Dims}, ::Type{Val{N}}) where {N} = Dims{N}

# This is to help with ITensor compatibility
dim(i::Int) = i
Expand Down
9 changes: 6 additions & 3 deletions NDTensors/src/empty/empty.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
using .TypeParameterAccessors: TypeParameterAccessors, set_eltype, similartype

#
# Represents a tensor order that could be set to any order.
#

struct EmptyOrder end

function similartype(StoreT::Type{<:TensorStorage{EmptyNumber}}, ElT::Type)
function TypeParameterAccessors.similartype(
StoreT::Type{<:TensorStorage{EmptyNumber}}, ElT::Type
)
return set_eltype(StoreT, ElT)
end

function similartype(
function TypeParameterAccessors.similartype(
StoreT::Type{<:TensorStorage{EmptyNumber}}, DataT::Type{<:AbstractArray}
)
return set_datatype(StoreT, DataT)
end

## TODO fix this similartype to use set eltype for BlockSparse
function similartype(
function TypeParameterAccessors.similartype(
::Type{StoreT}, ::Type{ElT}
) where {StoreT<:BlockSparse{EmptyNumber},ElT}
return BlockSparse{ElT,similartype(datatype(StoreT), ElT),ndims(StoreT)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ include("set_parameters.jl")
include("specify_parameters.jl")
include("default_parameters.jl")
include("base/abstractarray.jl")
include("base/similartype.jl")
include("base/array.jl")
include("base/linearalgebra.jl")
include("base/stridedviews.jl")
Expand Down
62 changes: 62 additions & 0 deletions NDTensors/src/lib/TypeParameterAccessors/src/base/similartype.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
`set_indstype` should be overloaded for
types with structured dimensions,
like `OffsetArrays` or named indices
(such as ITensors).
"""
function set_indstype(arraytype::Type{<:AbstractArray}, dims::Tuple)
return set_ndims(arraytype, length(dims))
end

function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, dims::Tuple)
return similartype(similartype(arraytype, eltype), dims)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, eltype::Type
) where {ArrayT; !IsWrappedArray{ArrayT}}
return set_eltype(arraytype, eltype)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, dims::Tuple
) where {ArrayT; !IsWrappedArray{ArrayT}}
return set_indstype(arraytype, dims)
end

function similartype(arraytype::Type{<:AbstractArray}, dims::Base.DimOrInd...)
return similartype(arraytype, dims)
end

function similartype(arraytype::Type{<:AbstractArray})
return similartype(arraytype, eltype(arraytype))
end

## Wrapped arrays
@traitfn function similartype(
arraytype::Type{ArrayT}, eltype::Type
) where {ArrayT; IsWrappedArray{ArrayT}}
return similartype(unwrap_array_type(arraytype), eltype)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, dims::Tuple
) where {ArrayT; IsWrappedArray{ArrayT}}
return similartype(unwrap_array_type(arraytype), dims)
end

# This is for uniform `Diag` storage which uses
# a Number as the data type.
# TODO: Delete this when we change to using a
# `FillArray` instead. This is a stand-in
# to make things work with the current design.
function similartype(numbertype::Type{<:Number})
return numbertype
end

# Instances
function similartype(array::AbstractArray, eltype::Type, dims...)
return similartype(typeof(array), eltype, dims...)
end
similartype(array::AbstractArray, eltype::Type) = similartype(typeof(array), eltype)
similartype(array::AbstractArray, dims...) = similartype(typeof(array), dims...)
1 change: 1 addition & 0 deletions NDTensors/src/lib/TypeParameterAccessors/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ using Test: @testset
include("test_defaults.jl")
include("test_custom_types.jl")
include("test_wrappers.jl")
include("test_similartype.jl")
end
end
15 changes: 15 additions & 0 deletions NDTensors/src/lib/TypeParameterAccessors/test/test_similartype.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
@eval module $(gensym())
using Test: @test, @test_broken, @testset
using LinearAlgebra: Adjoint
using NDTensors.TypeParameterAccessors: similartype
@testset "TypeParameterAccessors similartype" begin
@test similartype(Array, Float64, (2, 2)) == Matrix{Float64}
# TODO: Is this a good definition? Probably it should be left unspecified.
@test similartype(Array) == Array{Any}
@test similartype(Array, Float64) == Array{Float64}
@test similartype(Array, (2, 2)) == Matrix
@test similartype(Adjoint{Float32,Matrix{Float32}}, Float64, (2, 2, 2)) ==
Array{Float64,3}
@test similartype(Adjoint{Float32,Matrix{Float32}}, Float64) == Matrix{Float64}
end
end
2 changes: 1 addition & 1 deletion NDTensors/src/tensor/set_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end

# TODO: Modify the `storagetype` according to `inds`, such as the dimensions?
# TODO: Make a version that accepts `indstype::Type`?
function set_indstype(tensortype::Type{<:Tensor}, inds::Tuple)
function TypeParameterAccessors.set_indstype(tensortype::Type{<:Tensor}, inds::Tuple)
return Tensor{eltype(tensortype),length(inds),storagetype(tensortype),typeof(inds)}
end

Expand Down
6 changes: 4 additions & 2 deletions NDTensors/src/tensor/similar.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using .TypeParameterAccessors: TypeParameterAccessors, set_indstype, similartype

# NDTensors.similar
similar(tensor::Tensor) = setstorage(tensor, similar(storage(tensor)))

Expand Down Expand Up @@ -56,11 +58,11 @@ function Base.similar(tensor::Tensor, eltype::Type, dims::Dims)
return NDTensors.similar(tensor, eltype, dims)
end

function similartype(tensortype::Type{<:Tensor}, eltype::Type)
function TypeParameterAccessors.similartype(tensortype::Type{<:Tensor}, eltype::Type)
return set_storagetype(tensortype, similartype(storagetype(tensortype), eltype))
end

function similartype(tensortype::Type{<:Tensor}, dims::Tuple)
function TypeParameterAccessors.similartype(tensortype::Type{<:Tensor}, dims::Tuple)
tensortype_new_inds = set_indstype(tensortype, dims)
# Need to pass `dims` in case that information is needed to make a storage type,
# for example `BlockSparse` needs the number of dimensions.
Expand Down
8 changes: 6 additions & 2 deletions NDTensors/src/tensorstorage/similar.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using .TypeParameterAccessors: TypeParameterAccessors, set_ndims, similartype

# NDTensors.similar
similar(storage::TensorStorage) = setdata(storage, NDTensors.similar(data(storage)))

Expand Down Expand Up @@ -61,14 +63,16 @@ Base.similar(storage::TensorStorage, eltype::Type) = NDTensors.similar(storage,
## Base.similar(storage::TensorStorage, dims::Dims...) = NDTensors.similar(storage, dims...)
## Base.similar(storage::TensorStorage, dims::DimOrInd...) = NDTensors.similar(storage, dims...)

function similartype(storagetype::Type{<:TensorStorage}, eltype::Type)
function TypeParameterAccessors.similartype(
storagetype::Type{<:TensorStorage}, eltype::Type
)
# TODO: Don't convert to an `AbstractVector` with `set_ndims(datatype, 1)`, once we support
# more general data types.
# return set_datatype(storagetype, NDTensors.similartype(datatype(storagetype), eltype))
return set_datatype(storagetype, set_ndims(similartype(datatype(storagetype), eltype), 1))
end

function similartype(storagetype::Type{<:TensorStorage}, dims::Tuple)
function TypeParameterAccessors.similartype(storagetype::Type{<:TensorStorage}, dims::Tuple)
# TODO: In the future, set the dimensions of the data type based on `dims`, once
# more general data types beyond `AbstractVector` are supported.
# `similartype` unwraps any wrapped data.
Expand Down
Loading