Skip to content

Commit

Permalink
Reroute (Upper/Lower)Triangular * Diagonal through __muldiag (#55984
Browse files Browse the repository at this point in the history
)

Currently, `::Diagonal * ::AbstractMatrix` calls the method
`LinearAlgebra.__muldiag!` in general that scales the rows, and
similarly for the diagonal on the right. The implementation of
`__muldiag` was duplicating the logic in `LinearAlgebra.modify!` and the
methods for `MulAddMul`. This PR replaces the various branches with
calls to `modify!` instead. I've also extracted the multiplication logic
into its own function `__muldiag_nonzeroalpha!` so that this may be
specialized for matrix types, such as triangular ones.

Secondly, `::Diagonal * ::UpperTriangular` (and similarly, other
triangular matrices) was specialized to forward the multiplication to
the parent of the triangular. For strided matrices, however, it makes
more sense to use the structure and scale only the filled half of the
matrix. Firstly, this improves performance, and secondly, this avoids
errors in case the parent isn't fully initialized corresponding to the
structural zero elements.

Performance improvement:
```julia
julia> D = Diagonal(1:400);

julia> U = UpperTriangular(zeros(size(D)));

julia> @Btime $D * $U;
  314.944 μs (3 allocations: 1.22 MiB) # v"1.12.0-DEV.1288"
  195.960 μs (3 allocations: 1.22 MiB) # This PR
```
Fix:
```julia
julia> M = Matrix{BigFloat}(undef, 2, 2);

julia> M[1,1] = M[2,2] = M[1,2] = 3;

julia> U = UpperTriangular(M)
2×2 UpperTriangular{BigFloat, Matrix{BigFloat}}:
 3.0  3.0
  ⋅   3.0

julia> D = Diagonal(1:2);

julia> U * D # works after this PR
2×2 UpperTriangular{BigFloat, Matrix{BigFloat}}:
 3.0  6.0
  ⋅   6.0
```
  • Loading branch information
jishnub authored Oct 21, 2024
1 parent b01095e commit 04259da
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 66 deletions.
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,8 @@ matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = _matprod_dest_diag(A, TS)
matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = _matprod_dest_diag(B, TS)
matprod_dest(A::Diagonal, B::Diagonal, TS) = _matprod_dest_diag(B, TS)
_matprod_dest_diag(A, TS) = similar(A, TS)
_matprod_dest_diag(A::UnitUpperTriangular, TS) = UpperTriangular(similar(parent(A), TS))
_matprod_dest_diag(A::UnitLowerTriangular, TS) = LowerTriangular(similar(parent(A), TS))
function _matprod_dest_diag(A::SymTridiagonal, TS)
n = size(A, 1)
ev = similar(A, TS, max(0, n-1))
Expand Down
158 changes: 93 additions & 65 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,82 +396,120 @@ function lmul!(D::Diagonal, T::Tridiagonal)
return T
end

function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul)
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
end
end
out
end
_maybe_unwrap_tri(out, A) = out, A
_maybe_unwrap_tri(out::UpperTriangular, A::UpperOrUnitUpperTriangular) = parent(out), parent(A)
_maybe_unwrap_tri(out::LowerTriangular, A::LowerOrUnitLowerTriangular) = parent(out), parent(A)
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
isunit = B isa Union{UnitUpperTriangular, UnitLowerTriangular}
# if both B and out have the same upper/lower triangular structure,
# we may directly read and write from the parents
out_maybeparent, B_maybeparent = _maybe_unwrap_tri(out, B)
for j in axes(B, 2)
if isunit
_modify!(_add, D.diag[j] * B[j,j], out, (j,j))
end
rowrange = B isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(B,1))) : (j+isunit:size(B,1))
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
end
end
out
end
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul)
require_one_based_indexing(out, B)
alpha, beta = _add.alpha, _add.beta
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
if bis0
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
out[i,j] = D.diag[i] * B[i,j] * alpha
end
end
else
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta
end
end
end
__muldiag_nonzeroalpha!(out, D, B, _add)
end
return out
end
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}

@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
beta = _add.beta
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
@inbounds for j in axes(A, 2)
dja = _add(D.diag[j])
@simd for i in axes(A, 1)
_modify!(_add_aisone, A[i,j] * dja, out, (i,j))
end
end
out
end
@inline function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
isunit = A isa Union{UnitUpperTriangular, UnitLowerTriangular}
beta = _add.beta
# since alpha is multiplied to the diagonal element of D,
# we may skip alpha in the second multiplication by setting ais1 to true
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
# if both A and out have the same upper/lower triangular structure,
# we may directly read and write from the parents
out_maybeparent, A_maybeparent = _maybe_unwrap_tri(out, A)
@inbounds for j in axes(A, 2)
dja = _add(D.diag[j])
if isunit
_modify!(_add_aisone, A[j,j] * dja, out, (j,j))
end
rowrange = A isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(A,1))) : (j+isunit:size(A,1))
@simd for i in rowrange
_modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
end
end
out
end
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul)
require_one_based_indexing(out, A)
alpha, beta = _add.alpha, _add.beta
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
if bis0
@inbounds for j in axes(A, 2)
dja = D.diag[j] * alpha
@simd for i in axes(A, 1)
out[i,j] = A[i,j] * dja
end
end
else
@inbounds for j in axes(A, 2)
dja = D.diag[j] * alpha
@simd for i in axes(A, 1)
out[i,j] = A[i,j] * dja + out[i,j] * beta
end
end
end
__muldiag_nonzeroalpha!(out, A, D, _add)
end
return out
end
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}

@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
d1 = D1.diag
d2 = D2.diag
outd = out.diag
@inbounds @simd for i in eachindex(d1, d2, outd)
_modify!(_add, d1[i] * d2[i], outd, i)
end
out
end
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
alpha, beta = _add.alpha, _add.beta
if iszero(alpha)
_rmul_or_fill!(out.diag, beta)
else
if bis0
@inbounds @simd for i in eachindex(out.diag)
out.diag[i] = d1[i] * d2[i] * alpha
end
else
@inbounds @simd for i in eachindex(out.diag)
out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta
end
end
__muldiag_nonzeroalpha!(out, D1, D2, _add)
end
return out
end
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
require_one_based_indexing(out)
alpha, beta = _add.alpha, _add.beta
mA = size(D1, 1)
@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
d1 = D1.diag
d2 = D2.diag
@inbounds @simd for i in eachindex(d1, d2)
_modify!(_add, d1[i] * d2[i], out, (i,i))
end
out
end
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1}) where {ais1}
require_one_based_indexing(out)
alpha, beta = _add.alpha, _add.beta
_rmul_or_fill!(out, beta)
if !iszero(alpha)
@inbounds @simd for i in 1:mA
out[i,i] += d1[i] * d2[i] * alpha
end
_add_bis1 = MulAddMul{ais1,false,typeof(alpha),Bool}(alpha,true)
__muldiag_nonzeroalpha!(out, D1, D2, _add_bis1)
end
return out
end
Expand Down Expand Up @@ -658,31 +696,21 @@ for Tri in (:UpperTriangular, :LowerTriangular)
@eval $fun(A::$Tri, D::Diagonal) = $Tri($fun(A.data, D))
@eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($fun(A.data, D), $f, D.diag))
end
@eval *(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
@invoke *(A::AbstractMatrix, D::Diagonal)
@eval *(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
@invoke *(A::AbstractMatrix, D::Diagonal)
for (fun, f) in zip((:*, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv))
@eval $fun(D::Diagonal, A::$Tri) = $Tri($fun(D, A.data))
@eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($fun(D, A.data), $f, D.diag))
end
@eval *(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) =
@invoke *(D::Diagonal, A::AbstractMatrix)
@eval *(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) =
@invoke *(D::Diagonal, A::AbstractMatrix)
# 3-arg ldiv!
@eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data))
@eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag))
# 3-arg mul! is disambiguated in special.jl
# 5-arg mul!
@eval _mul!(C::$Tri, D::Diagonal, A::$Tri, _add) = $Tri(mul!(C.data, D, A.data, _add.alpha, _add.beta))
@eval function _mul!(C::$Tri, D::Diagonal, A::$UTri, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
α, β = _add.alpha, _add.beta
iszero(α) && return _rmul_or_fill!(C, β)
diag′ = bis0 ? nothing : diag(C)
data = mul!(C.data, D, A.data, α, β)
$Tri(_setdiag!(data, _add, D.diag, diag′))
end
@eval _mul!(C::$Tri, A::$Tri, D::Diagonal, _add) = $Tri(mul!(C.data, A.data, D, _add.alpha, _add.beta))
@eval function _mul!(C::$Tri, A::$UTri, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
α, β = _add.alpha, _add.beta
iszero(α) && return _rmul_or_fill!(C, β)
diag′ = bis0 ? nothing : diag(C)
data = mul!(C.data, A.data, D, α, β)
$Tri(_setdiag!(data, _add, D.diag, diag′))
end
end

@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
Expand Down
40 changes: 39 additions & 1 deletion stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,7 @@ end
@test oneunit(D3) isa typeof(D3)
end

@testset "AbstractTriangular" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
@testset "$Tri" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
A = randn(4, 4)
TriA = Tri(A)
UTriA = UTri(A)
Expand Down Expand Up @@ -1218,6 +1218,44 @@ end
@test outTri === mul!(outTri, D, UTriA, 2, 1)::Tri == mul!(out, D, Matrix(UTriA), 2, 1)
@test outTri === mul!(outTri, TriA, D, 2, 1)::Tri == mul!(out, Matrix(TriA), D, 2, 1)
@test outTri === mul!(outTri, UTriA, D, 2, 1)::Tri == mul!(out, Matrix(UTriA), D, 2, 1)

# we may write to a Unit triangular if the diagonal is preserved
ID = Diagonal(ones(size(UTriA,2)))
@test mul!(copy(UTriA), UTriA, ID) == UTriA
@test mul!(copy(UTriA), ID, UTriA) == UTriA

@testset "partly filled parents" begin
M = Matrix{BigFloat}(undef, 2, 2)
M[1,1] = M[2,2] = 3
isupper = Tri == UpperTriangular
M[1+!isupper, 1+isupper] = 3
D = Diagonal(1:2)
T = Tri(M)
TA = Array(T)
@test T * D == TA * D
@test D * T == D * TA
@test mul!(copy(T), T, D, 2, 3) == 2T * D + 3T
@test mul!(copy(T), D, T, 2, 3) == 2D * T + 3T

U = UTri(M)
UA = Array(U)
@test U * D == UA * D
@test D * U == D * UA
@test mul!(copy(T), U, D, 2, 3) == 2 * UA * D + 3TA
@test mul!(copy(T), D, U, 2, 3) == 2 * D * UA + 3TA

M2 = Matrix{BigFloat}(undef, 2, 2)
M2[1+!isupper, 1+isupper] = 3
U = UTri(M2)
UA = Array(U)
@test U * D == UA * D
@test D * U == D * UA
ID = Diagonal(ones(size(U,2)))
@test mul!(copy(U), U, ID) == U
@test mul!(copy(U), ID, U) == U
@test mul!(copy(U), U, ID, 2, -1) == U
@test mul!(copy(U), ID, U, 2, -1) == U
end
end

struct SMatrix1{T} <: AbstractArray{T,2}
Expand Down

0 comments on commit 04259da

Please sign in to comment.