From cd28425977ddaa44c3cae88d59ab3a75c787054c Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 22 Nov 2023 12:35:04 +0100 Subject: [PATCH] Improve broadcasting of `PDMat` and `PDiagMat` (#197) --- Project.toml | 2 +- src/pdiagmat.jl | 3 +++ src/pdmat.jl | 3 +++ test/pdmtypes.jl | 28 ++++++++++++++++++++++++++++ test/specialarrays.jl | 16 ++++++++++++++++ 5 files changed, 51 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 64011c8..d9b4481 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "PDMats" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.29" +version = "0.11.30" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/pdiagmat.jl b/src/pdiagmat.jl index a14cf3f..cb43a70 100644 --- a/src/pdiagmat.jl +++ b/src/pdiagmat.jl @@ -32,6 +32,9 @@ Base.Matrix(a::PDiagMat) = Matrix(Diagonal(a.diag)) LinearAlgebra.diag(a::PDiagMat) = copy(a.diag) LinearAlgebra.cholesky(a::PDiagMat) = Cholesky(Diagonal(map(sqrt, a.diag)), 'U', 0) +### Treat as a `Diagonal` matrix in broadcasting since that is better supported +Base.broadcastable(a::PDiagMat) = Base.broadcastable(Diagonal(a.diag)) + ### Inheriting from AbstractMatrix function Base.getindex(a::PDiagMat, i::Integer) diff --git a/src/pdmat.jl b/src/pdmat.jl index 0e44780..abaec69 100644 --- a/src/pdmat.jl +++ b/src/pdmat.jl @@ -46,6 +46,9 @@ Base.Matrix(a::PDMat) = Matrix(a.mat) LinearAlgebra.diag(a::PDMat) = diag(a.mat) LinearAlgebra.cholesky(a::PDMat) = a.chol +### Work with the underlying matrix in broadcasting +Base.broadcastable(a::PDMat) = Base.broadcastable(a.mat) + ### Inheriting from AbstractMatrix Base.getindex(a::PDMat, i::Int) = getindex(a.mat, i) diff --git a/test/pdmtypes.jl b/test/pdmtypes.jl index 471b826..c749cf8 100644 --- a/test/pdmtypes.jl +++ b/test/pdmtypes.jl @@ -256,4 +256,32 @@ using Test @test_throws DimensionMismatch PDSparseMat(A[1:(end - 1), 1:(end - 1)], C) end end + + @testset "Subtraction" begin + # This falls back to the generic method in Julia based on broadcasting + dim = 4 + x = rand(dim, dim) + A = PDMat(x' * x + I) + @test Base.broadcastable(A) == A.mat + + B = PDiagMat(rand(dim)) + @test Base.broadcastable(B) == Diagonal(B.diag) + + for X in (A, B), Y in (A, B) + @test X - Y isa (X === Y === B ? Diagonal{Float64, Vector{Float64}} : Matrix{Float64}) + @test X - Y ≈ Matrix(X) - Matrix(Y) + end + + C = ScalMat(dim, rand()) + @test A - C isa Matrix{Float64} + @test A - C ≈ Matrix(A) - Matrix(C) + @test C - A isa Matrix{Float64} + @test C - A ≈ Matrix(C) - Matrix(A) + + # ScalMat does not behave nicely with PDiagMat + @test_broken B - C isa Diagonal{Float64, Vector{Float64}} + @test B - C ≈ Matrix(B) - Matrix(C) + @test_broken C - B isa Diagonal{Float64, Vector{Float64}} + @test C - B ≈ Matrix(C) - Matrix(B) + end end diff --git a/test/specialarrays.jl b/test/specialarrays.jl index 75468ba..be6cbe8 100644 --- a/test/specialarrays.jl +++ b/test/specialarrays.jl @@ -84,6 +84,22 @@ using StaticArrays @test Xt_invA_X(A, Y) isa Symmetric{Float64,<:SMatrix{10, 10, Float64}} @test Xt_invA_X(A, Y) ≈ Matrix(Y)' * (Matrix(A) \ Matrix(Y)) end + + # Subtraction falls back to the generic method in Base which is based on broadcasting + @test Base.broadcastable(PDS) == PDS.mat + @test Base.broadcastable(D) == Diagonal(D.diag) + for A in (PDS, D), B in (PDS, D) + @test A - B isa SMatrix{4, 4, Float64} + @test A - B ≈ Matrix(A) - Matrix(B) + end + + # ScalMat does not behave nicely with broadcasting currently + for A in (PDS, D) + @test_broken A - E isa SMatrix{4, 4, Float64} + @test_broken E - A isa SMatrix{4, 4, Float64} + @test A - E ≈ Matrix(A) - Matrix(E) + @test E - A ≈ Matrix(E) - Matrix(A) + end end @testset "BandedMatrices" begin