Skip to content

Commit

Permalink
[TypeParameterAccessors] Fix similartype(Diagonal)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Nov 6, 2024
1 parent c105287 commit 6b5fc8d
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ include("unspecify_parameters.jl")
include("set_parameters.jl")
include("specify_parameters.jl")
include("default_parameters.jl")
include("ndims.jl")
include("base/abstractarray.jl")
include("base/similartype.jl")
include("base/array.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ end
function set_ndims(type::Type{<:AbstractArray}, param)
return set_type_parameter(type, ndims, param)
end
function set_ndims(type::Type{<:AbstractArray}, param::NDims)
return set_type_parameter(type, ndims, ndims(param))
end

using SimpleTraits: SimpleTraits, @traitdef, @traitimpl

Expand Down
38 changes: 32 additions & 6 deletions NDTensors/src/lib/TypeParameterAccessors/src/base/similartype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,23 @@ like `OffsetArrays` or named indices
(such as ITensors).
"""
function set_indstype(arraytype::Type{<:AbstractArray}, dims::Tuple)
return set_ndims(arraytype, length(dims))
return set_ndims(arraytype, NDims(length(dims)))
end

function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, ndims::NDims)
return similartype(similartype(arraytype, eltype), ndims)
end

function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, dims::Tuple)
return similartype(similartype(arraytype, eltype), dims)
end

@traitfn function similartype(
arraytype::Type{ArrayT}
) where {ArrayT; !IsWrappedArray{ArrayT}}
return arraytype
end

@traitfn function similartype(
arraytype::Type{ArrayT}, eltype::Type
) where {ArrayT; !IsWrappedArray{ArrayT}}
Expand All @@ -24,19 +34,29 @@ end
return set_indstype(arraytype, dims)
end

function similartype(arraytype::Type{<:AbstractArray}, dims::Base.DimOrInd...)
return similartype(arraytype, dims)
@traitfn function similartype(
arraytype::Type{ArrayT}, ndims::NDims
) where {ArrayT; !IsWrappedArray{ArrayT}}
return set_ndims(arraytype, ndims)
end

function similartype(arraytype::Type{<:AbstractArray})
return similartype(arraytype, eltype(arraytype))
function similartype(
arraytype::Type{<:AbstractArray}, dim1::Base.DimOrInd, dim_rest::Base.DimOrInd...
)
return similartype(arraytype, (dim1, dim_rest...))
end

## Wrapped arrays
@traitfn function similartype(
arraytype::Type{ArrayT}
) where {ArrayT; IsWrappedArray{ArrayT}}
return similartype(unwrap_array_type(arraytype), NDims(arraytype))
end

@traitfn function similartype(
arraytype::Type{ArrayT}, eltype::Type
) where {ArrayT; IsWrappedArray{ArrayT}}
return similartype(unwrap_array_type(arraytype), eltype)
return similartype(unwrap_array_type(arraytype), eltype, NDims(arraytype))
end

@traitfn function similartype(
Expand All @@ -45,6 +65,12 @@ end
return similartype(unwrap_array_type(arraytype), dims)
end

@traitfn function similartype(
arraytype::Type{ArrayT}, ndims::NDims
) where {ArrayT; IsWrappedArray{ArrayT}}
return similartype(unwrap_array_type(arraytype), ndims)
end

# This is for uniform `Diag` storage which uses
# a Number as the data type.
# TODO: Delete this when we change to using a
Expand Down
6 changes: 6 additions & 0 deletions NDTensors/src/lib/TypeParameterAccessors/src/ndims.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
struct NDims{ndims} end
Base.ndims(::NDims{ndims}) where {ndims} = ndims

NDims(ndims::Integer) = NDims{ndims}()
NDims(arraytype::Type{<:AbstractArray}) = NDims(ndims(arraytype))
NDims(array::AbstractArray) = NDims(typeof(array))
19 changes: 15 additions & 4 deletions NDTensors/src/lib/TypeParameterAccessors/test/test_similartype.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
@eval module $(gensym())
using Test: @test, @test_broken, @testset
using LinearAlgebra: Adjoint
using NDTensors.TypeParameterAccessors: similartype
using LinearAlgebra: Adjoint, Diagonal
using NDTensors.TypeParameterAccessors: NDims, similartype
@testset "TypeParameterAccessors similartype" begin
@test similartype(Array, Float64, (2, 2)) == Matrix{Float64}
# TODO: Is this a good definition? Probably it should be left unspecified.
@test similartype(Array) == Array{Any}
@test similartype(Array) == Array
@test similartype(Array, Float64) == Array{Float64}
@test similartype(Array, (2, 2)) == Matrix
@test similartype(Array, NDims(2)) == Matrix
@test similartype(Array, Float64, (2, 2)) == Matrix{Float64}
@test similartype(Array, Float64, NDims(2)) == Matrix{Float64}
@test similartype(Adjoint{Float32,Matrix{Float32}}, Float64, (2, 2, 2)) ==
Array{Float64,3}
@test similartype(Adjoint{Float32,Matrix{Float32}}, Float64, NDims(3)) == Array{Float64,3}
@test similartype(Adjoint{Float32,Matrix{Float32}}, Float64) == Matrix{Float64}
@test similartype(Diagonal{Float32,Vector{Float32}}) == Matrix{Float32}
@test similartype(Diagonal{Float32,Vector{Float32}}, Float64) == Matrix{Float64}
@test similartype(Diagonal{Float32,Vector{Float32}}, (2, 2, 2)) == Array{Float32,3}
@test similartype(Diagonal{Float32,Vector{Float32}}, NDims(3)) == Array{Float32,3}
@test similartype(Diagonal{Float32,Vector{Float32}}, Float64, (2, 2, 2)) ==
Array{Float64,3}
@test similartype(Diagonal{Float32,Vector{Float32}}, Float64, NDims(3)) ==
Array{Float64,3}
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using LinearAlgebra:
UnitUpperTriangular,
UpperTriangular
using NDTensors.TypeParameterAccessors:
NDims,
TypeParameter,
is_wrapped_array,
parenttype,
Expand All @@ -33,6 +34,7 @@ include("utils/test_inferred.jl")
@test_inferred set_eltype(array, Float32) array
@test_inferred set_eltype(Array{<:Any,2}, Float64) == Matrix{Float64}
@test_inferred set_ndims(Array{Float64}, 2) == Matrix{Float64} wrapped = true
@test_inferred set_ndims(Array{Float64}, NDims(2)) == Matrix{Float64} wrapped = true
@test_inferred set_ndims(Array{Float64}, TypeParameter(2)) == Matrix{Float64}
@test_inferred unwrap_array_type(array_type) == array_type
end
Expand Down

0 comments on commit 6b5fc8d

Please sign in to comment.