Skip to content

Commit

Permalink
put LinearAlgebra and SparseArrays in extensions (#841)
Browse files Browse the repository at this point in the history
* put LinearAlgebra and SparseArrays in extensions

* typo

* fix folder path

* a typo

* spelling

* imports

* imports
  • Loading branch information
rafaqz authored Nov 5, 2024
1 parent 175d54e commit 093d86f
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 152 deletions.
25 changes: 22 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,28 @@ Interfaces = "85a1e053-f937-4924-92a5-1367d23b7b87"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TableTraits = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[extensions]
DimensionalDataAlgebraOfGraphicsExt = "AlgebraOfGraphics"
DimensionalDataCategoricalArraysExt = "CategoricalArrays"
DimensionalDataMakie = "Makie"
DimensionalDataStatsBase = "StatsBase"
DimensionalDataSparseArraysExt = "SparseArrays"
DimensionalDataLinearAlgebraExt = "LinearAlgebra"

[compat]
Adapt = "2, 3.0, 4"
Expand Down Expand Up @@ -96,15 +98,32 @@ GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
ImageFiltering = "6a3955dd-da59-5b1f-98d4-e7296123deb5"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["AlgebraOfGraphics", "Aqua", "ArrayInterface", "BenchmarkTools", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations", "DataFrames", "Distributions", "Documenter", "GPUArrays", "ImageFiltering", "ImageTransformations", "JLArrays", "CairoMakie", "OffsetArrays", "Plots", "Random", "SafeTestsets", "StatsBase", "StatsPlots", "Test", "Unitful"]
test = [
"AlgebraOfGraphics", "Aqua", "ArrayInterface",
"BenchmarkTools",
"CairoMakie", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations",
"DataFrames", "Distributions", "Documenter",
"GPUArrays",
"ImageFiltering", "ImageTransformations",
"JLArrays",
"LinearAlgebra",
"OffsetArrays",
"Plots",
"Random",
"SafeTestsets", "SparseArrays", "StatsBase", "StatsPlots",
"Test",
"Unitful",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module DimensionalDataLinearAlgebraExt

using DimensionalData
using LinearAlgebra

const DD = DimensionalData

include("matmul.jl")
include("methods.jl")

end
105 changes: 105 additions & 0 deletions ext/DimensionalDataLinearAlgebraExt/matmul.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
using LinearAlgebra: AbstractTriangular, AbstractRotation

using DimensionalData: AnonDim, strict_matmul, comparedims

# Copied from symmetric.jl
const AdjTransVec = Union{Transpose{<:Any,<:AbstractVector},Adjoint{<:Any,<:AbstractVector}}
const RealHermSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}}
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}

# Ambiguities
for (a, b) in (
(AbstractDimVector, AbstractDimMatrix),
(AbstractDimMatrix, AbstractDimVector),
(AbstractDimMatrix, AbstractDimMatrix),
(AbstractDimMatrix, AbstractVector),
(AbstractDimVector, AbstractMatrix),
(AbstractDimMatrix, AbstractMatrix),
(AbstractMatrix, AbstractDimVector),
(AbstractVector, AbstractDimMatrix),
(AbstractMatrix, AbstractDimMatrix),
(AbstractDimVector, Adjoint{<:Any,<:AbstractMatrix}),
(AbstractDimVector, AdjTransVec),
(AbstractDimVector, Transpose{<:Any,<:AbstractMatrix}),
(AbstractDimMatrix, Diagonal),
(AbstractDimMatrix, Adjoint{<:Any,<:RealHermSymComplexHerm}),
(AbstractDimMatrix, Adjoint{<:Any,<:AbstractTriangular}),
(AbstractDimMatrix, Transpose{<:Any,<:AbstractTriangular}),
(AbstractDimMatrix, Transpose{<:Any,<:RealHermSymComplexSym}),
(AbstractDimMatrix, AbstractTriangular),
(Diagonal, AbstractDimVector),
(Diagonal, AbstractDimMatrix),
(Transpose{<:Any,<:AbstractTriangular}, AbstractDimVector),
(Transpose{<:Any,<:AbstractTriangular}, AbstractDimMatrix),
(Transpose{<:Any,<:AbstractVector}, AbstractDimVector),
(Transpose{<:Real,<:AbstractVector}, AbstractDimVector),
(Transpose{<:Any,<:AbstractVector}, AbstractDimMatrix),
(Transpose{<:Any,<:RealHermSymComplexSym}, AbstractDimMatrix),
(Transpose{<:Any,<:RealHermSymComplexSym}, AbstractDimVector),
(AbstractTriangular, AbstractDimVector),
(AbstractTriangular, AbstractDimMatrix),
(Adjoint{<:Any,<:AbstractTriangular}, AbstractDimVector),
(Adjoint{<:Any,<:AbstractVector}, AbstractDimMatrix),
(Adjoint{<:Any,<:RealHermSymComplexHerm}, AbstractDimMatrix),
(Adjoint{<:Any,<:AbstractTriangular}, AbstractDimMatrix),
(Adjoint{<:Number,<:AbstractVector}, AbstractDimVector{<:Number}),
(AdjTransVec, AbstractDimVector),
(Adjoint{<:Any,<:RealHermSymComplexHerm}, AbstractDimVector),
)
@eval Base.:*(A::$a, B::$b) = _rebuildmul(A, B)
end

Base.:*(A::AbstractDimVector, B::Adjoint{T,<:AbstractRotation}) where T = _rebuildmul(A, B)
Base.:*(A::Adjoint{T,<:AbstractRotation}, B::AbstractDimMatrix) where T = _rebuildmul(A, B)
Base.:*(A::Transpose{<:Any,<:AbstractMatrix{T}}, B::AbstractDimArray{S,1}) where {T,S} = _rebuildmul(A, B)
Base.:*(A::Adjoint{<:Any,<:AbstractMatrix{T}}, B::AbstractDimArray{S,1}) where {T,S} = _rebuildmul(A, B)

function _rebuildmul(A::AbstractDimVector, B::AbstractDimMatrix)
# Vector has no dim 2 to compare
rebuild(A, parent(A) * parent(B), (first(dims(A)), last(dims(B)),))
end
function _rebuildmul(A::AbstractDimMatrix, B::AbstractDimVector)
_comparedims_mul(A, B)
rebuild(A, parent(A) * parent(B), (first(dims(A)),))
end
function _rebuildmul(A::AbstractDimMatrix, B::AbstractDimMatrix)
_comparedims_mul(A, B)
rebuild(A, parent(A) * parent(B), (first(dims(A)), last(dims(B))))
end
function _rebuildmul(A::AbstractDimVector, B::AbstractMatrix)
rebuild(A, parent(A) * B, (first(dims(A)), AnonDim(Base.OneTo(size(B, 2)))))
end
function _rebuildmul(A::AbstractDimMatrix, B::AbstractVector)
newdata = parent(A) * B
if newdata isa AbstractArray
rebuild(A, parent(A) * B, (first(dims(A)),))
else
newdata
end
end
function _rebuildmul(A::AbstractDimMatrix, B::AbstractMatrix)
rebuild(A, parent(A) * B, (first(dims(A)), AnonDim(Base.OneTo(size(B, 2)))))
end
function _rebuildmul(A::AbstractVector, B::AbstractDimMatrix)
rebuild(B, A * parent(B), (AnonDim(Base.OneTo(size(A, 1))), last(dims(B))))
end
function _rebuildmul(A::AbstractMatrix, B::AbstractDimVector)
newdata = A * parent(B)
if newdata isa AbstractArray
rebuild(B, A * parent(B), (AnonDim(Base.OneTo(1)),))
else
newdata
end
end
function _rebuildmul(A::AbstractMatrix, B::AbstractDimMatrix)
rebuild(B, A * parent(B), (AnonDim(Base.OneTo(size(A, 1))), last(dims(B))))
end

function _comparedims_mul(a, b)
# Dont need to compare length if we compare values
isstrict = strict_matmul()
comparedims(last(dims(a)), first(dims(b));
order=isstrict, val=isstrict, length=false
)
end
20 changes: 20 additions & 0 deletions ext/DimensionalDataLinearAlgebraExt/methods.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Ambiguity
Base.copyto!(dst::AbstractDimArray{T,2} where T, src::LinearAlgebra.AbstractQ) =
(copyto!(parent(dst), src); dst)

# We need to override copy_similar because our `similar` doesn't work with size changes
# Fixed in Base in https://github.com/JuliaLang/julia/pull/53210
LinearAlgebra.copy_similar(A::AbstractDimArray, ::Type{T}) where {T} = copyto!(similar(A, T), A)

# See methods.jl
@eval begin
@inline LinearAlgebra.Transpose(A::AbstractDimArray{<:Any,2}) =
rebuild(A, LinearAlgebra.Transpose(parent(A)), reverse(dims(A)))
@inline LinearAlgebra.Transpose(A::AbstractDimArray{<:Any,1}) =
rebuild(A, LinearAlgebra.Transpose(parent(A)), (AnonDim(NoLookup(Base.OneTo(1))), dims(A)...))
@inline function LinearAlgebra.Transpose(s::AbstractDimStack)
maplayers(s) do l
ndims(l) > 1 ? LinearAlgebra.Transpose(l) : l
end
end
end
32 changes: 32 additions & 0 deletions ext/DimensionalDataSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
module DimensionalDataSparseArraysExt

using DimensionalData
using SparseArrays

# Ambiguity
Base.copyto!(dst::AbstractDimArray{T,2}, src::SparseArrays.CHOLMOD.Dense{T}) where T<:Union{Float64,ComplexF64} =
(copyto!(parent(dst), src); dst)
Base.copyto!(dst::AbstractDimArray{T}, src::SparseArrays.CHOLMOD.Dense{T}) where T<:Union{Float64,ComplexF64} =
(copyto!(parent(dst), src); dst)
Base.copyto!(dst::DimensionalData.AbstractDimArray, src::SparseArrays.CHOLMOD.Dense) =
(copyto!(parent(dst), src); dst)
Base.copyto!(dst::AbstractDimArray{T,2} where T, src::SparseArrays.AbstractSparseMatrixCSC) =
(copyto!(parent(dst), src); dst)
Base.copyto!(dst::SparseArrays.AbstractCompressedVector, src::AbstractDimArray{T, 1} where T) =
(copyto!(dst, parent(src)); dst)

function Base.copyto!(
dst::AbstractDimArray{<:Any,2},
dst_i::CartesianIndices{2, R} where R<:Tuple{OrdinalRange{Int64, Int64}, OrdinalRange{Int64, Int64}},
src::SparseArrays.AbstractSparseMatrixCSC{<:Any},
src_i::CartesianIndices{2, R} where R<:Tuple{OrdinalRange{Int64, Int64}, OrdinalRange{Int64, Int64}}
)
copyto!(parent(dst), dst_i, src, src_i)
return dst
end
Base.copy!(dst::SparseArrays.AbstractCompressedVector{T}, src::AbstractDimArray{T, 1}) where T =
(copy!(dst, parent(src)); dst)
Base.copy!(dst::SparseArrays.SparseVector, src::AbstractDimArray{T,1}) where T =
(copy!(dst, parent(src)); dst)

end
4 changes: 1 addition & 3 deletions src/DimensionalData.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ module DimensionalData

# Standard lib
using Dates,
LinearAlgebra,
Random,
Statistics,
SparseArrays
Statistics

using Base.Broadcast: Broadcasted, BroadcastStyle, DefaultArrayStyle, AbstractArrayStyle,
Unknown
Expand Down
27 changes: 0 additions & 27 deletions src/array/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,33 +313,6 @@ for (d, s) in ((:AbstractDimArray, :AbstractDimArray),
(copyto!(_maybeunwrap(dst), Rdst, _maybeunwrap(src), Rsrc); dst)
end
end
# Ambiguity
Base.copyto!(dst::AbstractDimArray{T,2}, src::SparseArrays.CHOLMOD.Dense{T}) where T<:Union{Float64,ComplexF64} =
(copyto!(parent(dst), src); dst)
Base.copyto!(dst::AbstractDimArray{T}, src::SparseArrays.CHOLMOD.Dense{T}) where T<:Union{Float64,ComplexF64} =
(copyto!(parent(dst), src); dst)
Base.copyto!(dst::DimensionalData.AbstractDimArray, src::SparseArrays.CHOLMOD.Dense) =
(copyto!(parent(dst), src); dst)
Base.copyto!(dst::SparseArrays.AbstractCompressedVector, src::AbstractDimArray{T, 1} where T) =
(copyto!(dst, parent(src)); dst)
Base.copyto!(dst::AbstractDimArray{T,2} where T, src::SparseArrays.AbstractSparseMatrixCSC) =
(copyto!(parent(dst), src); dst)
Base.copyto!(dst::AbstractDimArray{T,2} where T, src::LinearAlgebra.AbstractQ) =
(copyto!(parent(dst), src); dst)
function Base.copyto!(
dst::AbstractDimArray{<:Any,2},
dst_i::CartesianIndices{2, R} where R<:Tuple{OrdinalRange{Int64, Int64}, OrdinalRange{Int64, Int64}},
src::SparseArrays.AbstractSparseMatrixCSC{<:Any},
src_i::CartesianIndices{2, R} where R<:Tuple{OrdinalRange{Int64, Int64}, OrdinalRange{Int64, Int64}}
)
copyto!(parent(dst), dst_i, src, src_i)
return dst
end
Base.copy!(dst::SparseArrays.AbstractCompressedVector{T}, src::AbstractDimArray{T, 1}) where T =
(copy!(dst, parent(src)); dst)

Base.copy!(dst::SparseArrays.SparseVector, src::AbstractDimArray{T,1}) where T =
(copy!(dst, parent(src)); dst)
Base.copyto!(dst::PermutedDimsArray, src::AbstractDimArray) =
(copyto!(dst, parent(src)); dst)

Expand Down
105 changes: 0 additions & 105 deletions src/array/matmul.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using LinearAlgebra: AbstractTriangular, AbstractRotation

const STRICT_MATMUL_CHECKS = Ref(true)
const STRICT_MATMUL_DOCS = """
With `strict=true` we check [`Lookup`](@ref) [`Order`](@ref) and values
Expand Down Expand Up @@ -27,106 +25,3 @@ Set global matrix multiplication checks to `strict`, or not for all `AbstractDim
$STRICT_MATMUL_DOCS
"""
strict_matmul!(x::Bool) = STRICT_MATMUL_CHECKS[] = x

# Copied from symmetric.jl
const AdjTransVec = Union{Transpose{<:Any,<:AbstractVector},Adjoint{<:Any,<:AbstractVector}}
const RealHermSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}}
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}

# Ambiguities
for (a, b) in (
(AbstractDimVector, AbstractDimMatrix),
(AbstractDimMatrix, AbstractDimVector),
(AbstractDimMatrix, AbstractDimMatrix),
(AbstractDimMatrix, AbstractVector),
(AbstractDimVector, AbstractMatrix),
(AbstractDimMatrix, AbstractMatrix),
(AbstractMatrix, AbstractDimVector),
(AbstractVector, AbstractDimMatrix),
(AbstractMatrix, AbstractDimMatrix),
(AbstractDimVector, Adjoint{<:Any,<:AbstractMatrix}),
(AbstractDimVector, AdjTransVec),
(AbstractDimVector, Transpose{<:Any,<:AbstractMatrix}),
(AbstractDimMatrix, Diagonal),
(AbstractDimMatrix, Adjoint{<:Any,<:RealHermSymComplexHerm}),
(AbstractDimMatrix, Adjoint{<:Any,<:AbstractTriangular}),
(AbstractDimMatrix, Transpose{<:Any,<:AbstractTriangular}),
(AbstractDimMatrix, Transpose{<:Any,<:RealHermSymComplexSym}),
(AbstractDimMatrix, AbstractTriangular),
(Diagonal, AbstractDimVector),
(Diagonal, AbstractDimMatrix),
(Transpose{<:Any,<:AbstractTriangular}, AbstractDimVector),
(Transpose{<:Any,<:AbstractTriangular}, AbstractDimMatrix),
(Transpose{<:Any,<:AbstractVector}, AbstractDimVector),
(Transpose{<:Real,<:AbstractVector}, AbstractDimVector),
(Transpose{<:Any,<:AbstractVector}, AbstractDimMatrix),
(Transpose{<:Any,<:RealHermSymComplexSym}, AbstractDimMatrix),
(Transpose{<:Any,<:RealHermSymComplexSym}, AbstractDimVector),
(AbstractTriangular, AbstractDimVector),
(AbstractTriangular, AbstractDimMatrix),
(Adjoint{<:Any,<:AbstractTriangular}, AbstractDimVector),
(Adjoint{<:Any,<:AbstractVector}, AbstractDimMatrix),
(Adjoint{<:Any,<:RealHermSymComplexHerm}, AbstractDimMatrix),
(Adjoint{<:Any,<:AbstractTriangular}, AbstractDimMatrix),
(Adjoint{<:Number,<:AbstractVector}, AbstractDimVector{<:Number}),
(AdjTransVec, AbstractDimVector),
(Adjoint{<:Any,<:RealHermSymComplexHerm}, AbstractDimVector),
)
@eval Base.:*(A::$a, B::$b) = _rebuildmul(A, B)
end


Base.:*(A::AbstractDimVector, B::Adjoint{T,<:AbstractRotation}) where T = _rebuildmul(A, B)
Base.:*(A::Adjoint{T,<:AbstractRotation}, B::AbstractDimMatrix) where T = _rebuildmul(A, B)
Base.:*(A::Transpose{<:Any,<:AbstractMatrix{T}}, B::AbstractDimArray{S,1}) where {T,S} = _rebuildmul(A, B)
Base.:*(A::Adjoint{<:Any,<:AbstractMatrix{T}}, B::AbstractDimArray{S,1}) where {T,S} = _rebuildmul(A, B)

function _rebuildmul(A::AbstractDimVector, B::AbstractDimMatrix)
# Vector has no dim 2 to compare
rebuild(A, parent(A) * parent(B), (first(dims(A)), last(dims(B)),))
end
function _rebuildmul(A::AbstractDimMatrix, B::AbstractDimVector)
_comparedims_mul(A, B)
rebuild(A, parent(A) * parent(B), (first(dims(A)),))
end
function _rebuildmul(A::AbstractDimMatrix, B::AbstractDimMatrix)
_comparedims_mul(A, B)
rebuild(A, parent(A) * parent(B), (first(dims(A)), last(dims(B))))
end
function _rebuildmul(A::AbstractDimVector, B::AbstractMatrix)
rebuild(A, parent(A) * B, (first(dims(A)), AnonDim(Base.OneTo(size(B, 2)))))
end
function _rebuildmul(A::AbstractDimMatrix, B::AbstractVector)
newdata = parent(A) * B
if newdata isa AbstractArray
rebuild(A, parent(A) * B, (first(dims(A)),))
else
newdata
end
end
function _rebuildmul(A::AbstractDimMatrix, B::AbstractMatrix)
rebuild(A, parent(A) * B, (first(dims(A)), AnonDim(Base.OneTo(size(B, 2)))))
end
function _rebuildmul(A::AbstractVector, B::AbstractDimMatrix)
rebuild(B, A * parent(B), (AnonDim(Base.OneTo(size(A, 1))), last(dims(B))))
end
function _rebuildmul(A::AbstractMatrix, B::AbstractDimVector)
newdata = A * parent(B)
if newdata isa AbstractArray
rebuild(B, A * parent(B), (AnonDim(Base.OneTo(1)),))
else
newdata
end
end
function _rebuildmul(A::AbstractMatrix, B::AbstractDimMatrix)
rebuild(B, A * parent(B), (AnonDim(Base.OneTo(size(A, 1))), last(dims(B))))
end

function _comparedims_mul(a, b)
# Dont need to compare length if we compare values
isstrict = strict_matmul()
comparedims(last(dims(a)), first(dims(b));
order=isstrict, val=isstrict, length=false
)
end
Loading

0 comments on commit 093d86f

Please sign in to comment.