From 6b5fc8df7661468d5fb8570a90730fc31ce397bc Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 6 Nov 2024 12:51:50 -0500 Subject: [PATCH] [TypeParameterAccessors] Fix similartype(Diagonal) --- .../src/TypeParameterAccessors.jl | 1 + .../src/base/abstractarray.jl | 3 ++ .../src/base/similartype.jl | 38 ++++++++++++++++--- .../lib/TypeParameterAccessors/src/ndims.jl | 6 +++ .../test/test_similartype.jl | 19 ++++++++-- .../test/test_wrappers.jl | 2 + 6 files changed, 59 insertions(+), 10 deletions(-) create mode 100644 NDTensors/src/lib/TypeParameterAccessors/src/ndims.jl diff --git a/NDTensors/src/lib/TypeParameterAccessors/src/TypeParameterAccessors.jl b/NDTensors/src/lib/TypeParameterAccessors/src/TypeParameterAccessors.jl index adafed6567..3333cd417b 100644 --- a/NDTensors/src/lib/TypeParameterAccessors/src/TypeParameterAccessors.jl +++ b/NDTensors/src/lib/TypeParameterAccessors/src/TypeParameterAccessors.jl @@ -12,6 +12,7 @@ include("unspecify_parameters.jl") include("set_parameters.jl") include("specify_parameters.jl") include("default_parameters.jl") +include("ndims.jl") include("base/abstractarray.jl") include("base/similartype.jl") include("base/array.jl") diff --git a/NDTensors/src/lib/TypeParameterAccessors/src/base/abstractarray.jl b/NDTensors/src/lib/TypeParameterAccessors/src/base/abstractarray.jl index 14286b65cb..89e38d3b2f 100644 --- a/NDTensors/src/lib/TypeParameterAccessors/src/base/abstractarray.jl +++ b/NDTensors/src/lib/TypeParameterAccessors/src/base/abstractarray.jl @@ -13,6 +13,9 @@ end function set_ndims(type::Type{<:AbstractArray}, param) return set_type_parameter(type, ndims, param) end +function set_ndims(type::Type{<:AbstractArray}, param::NDims) + return set_type_parameter(type, ndims, ndims(param)) +end using SimpleTraits: SimpleTraits, @traitdef, @traitimpl diff --git a/NDTensors/src/lib/TypeParameterAccessors/src/base/similartype.jl b/NDTensors/src/lib/TypeParameterAccessors/src/base/similartype.jl index c51bf80b6e..f6fed09885 100644 --- a/NDTensors/src/lib/TypeParameterAccessors/src/base/similartype.jl +++ b/NDTensors/src/lib/TypeParameterAccessors/src/base/similartype.jl @@ -5,13 +5,23 @@ like `OffsetArrays` or named indices (such as ITensors). """ function set_indstype(arraytype::Type{<:AbstractArray}, dims::Tuple) - return set_ndims(arraytype, length(dims)) + return set_ndims(arraytype, NDims(length(dims))) +end + +function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, ndims::NDims) + return similartype(similartype(arraytype, eltype), ndims) end function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, dims::Tuple) return similartype(similartype(arraytype, eltype), dims) end +@traitfn function similartype( + arraytype::Type{ArrayT} +) where {ArrayT; !IsWrappedArray{ArrayT}} + return arraytype +end + @traitfn function similartype( arraytype::Type{ArrayT}, eltype::Type ) where {ArrayT; !IsWrappedArray{ArrayT}} @@ -24,19 +34,29 @@ end return set_indstype(arraytype, dims) end -function similartype(arraytype::Type{<:AbstractArray}, dims::Base.DimOrInd...) - return similartype(arraytype, dims) +@traitfn function similartype( + arraytype::Type{ArrayT}, ndims::NDims +) where {ArrayT; !IsWrappedArray{ArrayT}} + return set_ndims(arraytype, ndims) end -function similartype(arraytype::Type{<:AbstractArray}) - return similartype(arraytype, eltype(arraytype)) +function similartype( + arraytype::Type{<:AbstractArray}, dim1::Base.DimOrInd, dim_rest::Base.DimOrInd... +) + return similartype(arraytype, (dim1, dim_rest...)) end ## Wrapped arrays +@traitfn function similartype( + arraytype::Type{ArrayT} +) where {ArrayT; IsWrappedArray{ArrayT}} + return similartype(unwrap_array_type(arraytype), NDims(arraytype)) +end + @traitfn function similartype( arraytype::Type{ArrayT}, eltype::Type ) where {ArrayT; IsWrappedArray{ArrayT}} - return similartype(unwrap_array_type(arraytype), eltype) + return similartype(unwrap_array_type(arraytype), eltype, NDims(arraytype)) end @traitfn function similartype( @@ -45,6 +65,12 @@ end return similartype(unwrap_array_type(arraytype), dims) end +@traitfn function similartype( + arraytype::Type{ArrayT}, ndims::NDims +) where {ArrayT; IsWrappedArray{ArrayT}} + return similartype(unwrap_array_type(arraytype), ndims) +end + # This is for uniform `Diag` storage which uses # a Number as the data type. # TODO: Delete this when we change to using a diff --git a/NDTensors/src/lib/TypeParameterAccessors/src/ndims.jl b/NDTensors/src/lib/TypeParameterAccessors/src/ndims.jl new file mode 100644 index 0000000000..f640763583 --- /dev/null +++ b/NDTensors/src/lib/TypeParameterAccessors/src/ndims.jl @@ -0,0 +1,6 @@ +struct NDims{ndims} end +Base.ndims(::NDims{ndims}) where {ndims} = ndims + +NDims(ndims::Integer) = NDims{ndims}() +NDims(arraytype::Type{<:AbstractArray}) = NDims(ndims(arraytype)) +NDims(array::AbstractArray) = NDims(typeof(array)) diff --git a/NDTensors/src/lib/TypeParameterAccessors/test/test_similartype.jl b/NDTensors/src/lib/TypeParameterAccessors/test/test_similartype.jl index 703f9ba1e1..0813fc917b 100644 --- a/NDTensors/src/lib/TypeParameterAccessors/test/test_similartype.jl +++ b/NDTensors/src/lib/TypeParameterAccessors/test/test_similartype.jl @@ -1,15 +1,26 @@ @eval module $(gensym()) using Test: @test, @test_broken, @testset -using LinearAlgebra: Adjoint -using NDTensors.TypeParameterAccessors: similartype +using LinearAlgebra: Adjoint, Diagonal +using NDTensors.TypeParameterAccessors: NDims, similartype @testset "TypeParameterAccessors similartype" begin @test similartype(Array, Float64, (2, 2)) == Matrix{Float64} - # TODO: Is this a good definition? Probably it should be left unspecified. - @test similartype(Array) == Array{Any} + @test similartype(Array) == Array @test similartype(Array, Float64) == Array{Float64} @test similartype(Array, (2, 2)) == Matrix + @test similartype(Array, NDims(2)) == Matrix + @test similartype(Array, Float64, (2, 2)) == Matrix{Float64} + @test similartype(Array, Float64, NDims(2)) == Matrix{Float64} @test similartype(Adjoint{Float32,Matrix{Float32}}, Float64, (2, 2, 2)) == Array{Float64,3} + @test similartype(Adjoint{Float32,Matrix{Float32}}, Float64, NDims(3)) == Array{Float64,3} @test similartype(Adjoint{Float32,Matrix{Float32}}, Float64) == Matrix{Float64} + @test similartype(Diagonal{Float32,Vector{Float32}}) == Matrix{Float32} + @test similartype(Diagonal{Float32,Vector{Float32}}, Float64) == Matrix{Float64} + @test similartype(Diagonal{Float32,Vector{Float32}}, (2, 2, 2)) == Array{Float32,3} + @test similartype(Diagonal{Float32,Vector{Float32}}, NDims(3)) == Array{Float32,3} + @test similartype(Diagonal{Float32,Vector{Float32}}, Float64, (2, 2, 2)) == + Array{Float64,3} + @test similartype(Diagonal{Float32,Vector{Float32}}, Float64, NDims(3)) == + Array{Float64,3} end end diff --git a/NDTensors/src/lib/TypeParameterAccessors/test/test_wrappers.jl b/NDTensors/src/lib/TypeParameterAccessors/test/test_wrappers.jl index d253f92c77..ed8e223922 100644 --- a/NDTensors/src/lib/TypeParameterAccessors/test/test_wrappers.jl +++ b/NDTensors/src/lib/TypeParameterAccessors/test/test_wrappers.jl @@ -11,6 +11,7 @@ using LinearAlgebra: UnitUpperTriangular, UpperTriangular using NDTensors.TypeParameterAccessors: + NDims, TypeParameter, is_wrapped_array, parenttype, @@ -33,6 +34,7 @@ include("utils/test_inferred.jl") @test_inferred set_eltype(array, Float32) ≈ array @test_inferred set_eltype(Array{<:Any,2}, Float64) == Matrix{Float64} @test_inferred set_ndims(Array{Float64}, 2) == Matrix{Float64} wrapped = true + @test_inferred set_ndims(Array{Float64}, NDims(2)) == Matrix{Float64} wrapped = true @test_inferred set_ndims(Array{Float64}, TypeParameter(2)) == Matrix{Float64} @test_inferred unwrap_array_type(array_type) == array_type end