Skip to content

Commit

Permalink
Revert "Reroute (Upper/Lower)Triangular * Diagonal through `__muldi…
Browse files Browse the repository at this point in the history
…ag` (#55…"

This reverts commit 04259da.
  • Loading branch information
jishnub authored Oct 21, 2024
1 parent 04259da commit c7dbacb
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 134 deletions.
2 changes: 0 additions & 2 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,6 @@ matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = _matprod_dest_diag(A, TS)
matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = _matprod_dest_diag(B, TS)
matprod_dest(A::Diagonal, B::Diagonal, TS) = _matprod_dest_diag(B, TS)
_matprod_dest_diag(A, TS) = similar(A, TS)
_matprod_dest_diag(A::UnitUpperTriangular, TS) = UpperTriangular(similar(parent(A), TS))
_matprod_dest_diag(A::UnitLowerTriangular, TS) = LowerTriangular(similar(parent(A), TS))
function _matprod_dest_diag(A::SymTridiagonal, TS)
n = size(A, 1)
ev = similar(A, TS, max(0, n-1))
Expand Down
158 changes: 65 additions & 93 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,120 +396,82 @@ function lmul!(D::Diagonal, T::Tridiagonal)
return T
end

@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul)
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
end
end
out
end
_maybe_unwrap_tri(out, A) = out, A
_maybe_unwrap_tri(out::UpperTriangular, A::UpperOrUnitUpperTriangular) = parent(out), parent(A)
_maybe_unwrap_tri(out::LowerTriangular, A::LowerOrUnitLowerTriangular) = parent(out), parent(A)
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
isunit = B isa Union{UnitUpperTriangular, UnitLowerTriangular}
# if both B and out have the same upper/lower triangular structure,
# we may directly read and write from the parents
out_maybeparent, B_maybeparent = _maybe_unwrap_tri(out, B)
for j in axes(B, 2)
if isunit
_modify!(_add, D.diag[j] * B[j,j], out, (j,j))
end
rowrange = B isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(B,1))) : (j+isunit:size(B,1))
@inbounds @simd for i in rowrange
_modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
end
end
out
end
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul)
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
require_one_based_indexing(out, B)
alpha, beta = _add.alpha, _add.beta
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
__muldiag_nonzeroalpha!(out, D, B, _add)
end
return out
end

@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
beta = _add.beta
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
@inbounds for j in axes(A, 2)
dja = _add(D.diag[j])
@simd for i in axes(A, 1)
_modify!(_add_aisone, A[i,j] * dja, out, (i,j))
end
end
out
end
@inline function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
isunit = A isa Union{UnitUpperTriangular, UnitLowerTriangular}
beta = _add.beta
# since alpha is multiplied to the diagonal element of D,
# we may skip alpha in the second multiplication by setting ais1 to true
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
# if both A and out have the same upper/lower triangular structure,
# we may directly read and write from the parents
out_maybeparent, A_maybeparent = _maybe_unwrap_tri(out, A)
@inbounds for j in axes(A, 2)
dja = _add(D.diag[j])
if isunit
_modify!(_add_aisone, A[j,j] * dja, out, (j,j))
end
rowrange = A isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(A,1))) : (j+isunit:size(A,1))
@simd for i in rowrange
_modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
if bis0
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
out[i,j] = D.diag[i] * B[i,j] * alpha
end
end
else
@inbounds for j in axes(B, 2)
@simd for i in axes(B, 1)
out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta
end
end
end
end
out
return out
end
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul)
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
require_one_based_indexing(out, A)
alpha, beta = _add.alpha, _add.beta
if iszero(alpha)
_rmul_or_fill!(out, beta)
else
__muldiag_nonzeroalpha!(out, A, D, _add)
if bis0
@inbounds for j in axes(A, 2)
dja = D.diag[j] * alpha
@simd for i in axes(A, 1)
out[i,j] = A[i,j] * dja
end
end
else
@inbounds for j in axes(A, 2)
dja = D.diag[j] * alpha
@simd for i in axes(A, 1)
out[i,j] = A[i,j] * dja + out[i,j] * beta
end
end
end
end
return out
end

@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
d1 = D1.diag
d2 = D2.diag
outd = out.diag
@inbounds @simd for i in eachindex(d1, d2, outd)
_modify!(_add, d1[i] * d2[i], outd, i)
end
out
end
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
alpha, beta = _add.alpha, _add.beta
if iszero(alpha)
_rmul_or_fill!(out.diag, beta)
else
__muldiag_nonzeroalpha!(out, D1, D2, _add)
if bis0
@inbounds @simd for i in eachindex(out.diag)
out.diag[i] = d1[i] * d2[i] * alpha
end
else
@inbounds @simd for i in eachindex(out.diag)
out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta
end
end
end
return out
end
@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
d1 = D1.diag
d2 = D2.diag
@inbounds @simd for i in eachindex(d1, d2)
_modify!(_add, d1[i] * d2[i], out, (i,i))
end
out
end
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1}) where {ais1}
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
require_one_based_indexing(out)
alpha, beta = _add.alpha, _add.beta
mA = size(D1, 1)
d1 = D1.diag
d2 = D2.diag
_rmul_or_fill!(out, beta)
if !iszero(alpha)
_add_bis1 = MulAddMul{ais1,false,typeof(alpha),Bool}(alpha,true)
__muldiag_nonzeroalpha!(out, D1, D2, _add_bis1)
@inbounds @simd for i in 1:mA
out[i,i] += d1[i] * d2[i] * alpha
end
end
return out
end
Expand Down Expand Up @@ -696,21 +658,31 @@ for Tri in (:UpperTriangular, :LowerTriangular)
@eval $fun(A::$Tri, D::Diagonal) = $Tri($fun(A.data, D))
@eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($fun(A.data, D), $f, D.diag))
end
@eval *(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
@invoke *(A::AbstractMatrix, D::Diagonal)
@eval *(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
@invoke *(A::AbstractMatrix, D::Diagonal)
for (fun, f) in zip((:*, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv))
@eval $fun(D::Diagonal, A::$Tri) = $Tri($fun(D, A.data))
@eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($fun(D, A.data), $f, D.diag))
end
@eval *(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) =
@invoke *(D::Diagonal, A::AbstractMatrix)
@eval *(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) =
@invoke *(D::Diagonal, A::AbstractMatrix)
# 3-arg ldiv!
@eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data))
@eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag))
# 3-arg mul! is disambiguated in special.jl
# 5-arg mul!
@eval _mul!(C::$Tri, D::Diagonal, A::$Tri, _add) = $Tri(mul!(C.data, D, A.data, _add.alpha, _add.beta))
@eval function _mul!(C::$Tri, D::Diagonal, A::$UTri, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
α, β = _add.alpha, _add.beta
iszero(α) && return _rmul_or_fill!(C, β)
diag′ = bis0 ? nothing : diag(C)
data = mul!(C.data, D, A.data, α, β)
$Tri(_setdiag!(data, _add, D.diag, diag′))
end
@eval _mul!(C::$Tri, A::$Tri, D::Diagonal, _add) = $Tri(mul!(C.data, A.data, D, _add.alpha, _add.beta))
@eval function _mul!(C::$Tri, A::$UTri, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
α, β = _add.alpha, _add.beta
iszero(α) && return _rmul_or_fill!(C, β)
diag′ = bis0 ? nothing : diag(C)
data = mul!(C.data, A.data, D, α, β)
$Tri(_setdiag!(data, _add, D.diag, diag′))
end
end

@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
Expand Down
40 changes: 1 addition & 39 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,7 @@ end
@test oneunit(D3) isa typeof(D3)
end

@testset "$Tri" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
@testset "AbstractTriangular" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
A = randn(4, 4)
TriA = Tri(A)
UTriA = UTri(A)
Expand Down Expand Up @@ -1218,44 +1218,6 @@ end
@test outTri === mul!(outTri, D, UTriA, 2, 1)::Tri == mul!(out, D, Matrix(UTriA), 2, 1)
@test outTri === mul!(outTri, TriA, D, 2, 1)::Tri == mul!(out, Matrix(TriA), D, 2, 1)
@test outTri === mul!(outTri, UTriA, D, 2, 1)::Tri == mul!(out, Matrix(UTriA), D, 2, 1)

# we may write to a Unit triangular if the diagonal is preserved
ID = Diagonal(ones(size(UTriA,2)))
@test mul!(copy(UTriA), UTriA, ID) == UTriA
@test mul!(copy(UTriA), ID, UTriA) == UTriA

@testset "partly filled parents" begin
M = Matrix{BigFloat}(undef, 2, 2)
M[1,1] = M[2,2] = 3
isupper = Tri == UpperTriangular
M[1+!isupper, 1+isupper] = 3
D = Diagonal(1:2)
T = Tri(M)
TA = Array(T)
@test T * D == TA * D
@test D * T == D * TA
@test mul!(copy(T), T, D, 2, 3) == 2T * D + 3T
@test mul!(copy(T), D, T, 2, 3) == 2D * T + 3T

U = UTri(M)
UA = Array(U)
@test U * D == UA * D
@test D * U == D * UA
@test mul!(copy(T), U, D, 2, 3) == 2 * UA * D + 3TA
@test mul!(copy(T), D, U, 2, 3) == 2 * D * UA + 3TA

M2 = Matrix{BigFloat}(undef, 2, 2)
M2[1+!isupper, 1+isupper] = 3
U = UTri(M2)
UA = Array(U)
@test U * D == UA * D
@test D * U == D * UA
ID = Diagonal(ones(size(U,2)))
@test mul!(copy(U), U, ID) == U
@test mul!(copy(U), ID, U) == U
@test mul!(copy(U), U, ID, 2, -1) == U
@test mul!(copy(U), ID, U, 2, -1) == U
end
end

struct SMatrix1{T} <: AbstractArray{T,2}
Expand Down

0 comments on commit c7dbacb

Please sign in to comment.