diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl index 39b4e885ac..357949f80d 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl @@ -589,3 +589,28 @@ macro view!(expr) @capture(expr, array_[indices__]) return :(view!($(esc(array)), $(esc.(indices)...))) end + +# SVD additions +# ------------- +using LinearAlgebra: Algorithm +using BlockArrays: BlockedMatrix + +# svd first calls `eigencopy_oftype` to create something that can be in-place SVD'd +# Here, we hijack this system to determine if there is any structure we can exploit +# default: SVD is most efficient with BlockedArray +function eigencopy_oftype(A::AbstractBlockArray, S) + return BlockedMatrix{S}(A) +end + +function svd!(A::BlockedMatrix; full::Bool=false, alg::Algorithm=default_svd_alg(A)) + F = svd!(parent(A); full, alg) + + # restore block pattern + m = length(F.S) + bax1, bax2, bax3 = axes(A, 1), blockedrange([m]), axes(A, 2) + + u = BlockedArray(F.U, (bax1, bax2)) + s = BlockedVector(F.S, (bax2,)) + vt = BlockedArray(F.Vt, (bax2, bax3)) + return SVD(u, s, vt) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index d0a1e4cdd7..683b3600a7 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -1,5 +1,13 @@ module BlockSparseArrays + +# factorizations +include("factorizations/svd.jl") +include("factorizations/tsvd.jl") + +# possible upstream contributions include("BlockArraysExtensions/BlockArraysExtensions.jl") + +# interface functions that don't have to specialize include("blocksparsearrayinterface/blocksparsearrayinterface.jl") include("blocksparsearrayinterface/linearalgebra.jl") include("blocksparsearrayinterface/blockzero.jl") @@ -7,6 +15,8 @@ include("blocksparsearrayinterface/broadcast.jl") include("blocksparsearrayinterface/map.jl") include("blocksparsearrayinterface/arraylayouts.jl") include("blocksparsearrayinterface/views.jl") + +# functions defined for any abstractblocksparsearray include("abstractblocksparsearray/abstractblocksparsearray.jl") include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl") include("abstractblocksparsearray/abstractblocksparsematrix.jl") @@ -17,8 +27,12 @@ include("abstractblocksparsearray/sparsearrayinterface.jl") include("abstractblocksparsearray/broadcast.jl") include("abstractblocksparsearray/map.jl") include("abstractblocksparsearray/linearalgebra.jl") + +# functions specifically for BlockSparseArray include("blocksparsearray/defaults.jl") include("blocksparsearray/blocksparsearray.jl") +include("blocksparsearray/blockdiagonalarray.jl") + include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl") include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl") include("../ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl") diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsematrix.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsematrix.jl index 0c2c578781..8741ec80a5 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsematrix.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/abstractblocksparsematrix.jl @@ -1 +1,22 @@ const AbstractBlockSparseMatrix{T} = AbstractBlockSparseArray{T,2} + +# SVD is implemented by trying to +# 1. Attempt to find a block-diagonal implementation by permuting +# 2. Fallback to AbstractBlockArray implementation via BlockedArray +function svd( + A::AbstractBlockSparseMatrix; full::Bool=false, alg::Algorithm=default_svd_alg(A) +) + T = LinearAlgebra.eigtype(eltype(A)) + A′ = try_to_blockdiagonal(A) + + if isnothing(A′) + # not block-diagonal, fall back to dense case + Adense = eigencopy_oftype(A, T) + return svd!(Adense; full, alg) + end + + # compute block-by-block and permute back + A″, (I, J) = A′ + F = svd!(eigencopy_oftype(A″, T); full, alg) + return SVD(F.U[Block.(I), Block.(J)], F.S, F.Vt) +end \ No newline at end of file diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blockdiagonalarray.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blockdiagonalarray.jl new file mode 100644 index 0000000000..e20eb8b0c7 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearray/blockdiagonalarray.jl @@ -0,0 +1,59 @@ +# type alias for block-diagonal +using LinearAlgebra: Diagonal + +const BlockDiagonal{T,A,Axes,V<:AbstractVector{A}} = BlockSparseMatrix{ + T,A,Diagonal{A,V},Axes +} + +function BlockDiagonal(blocks::AbstractVector{<:AbstractMatrix}) + return BlockSparseArray( + Diagonal(blocks), (blockedrange(size.(blocks, 1)), blockedrange(size.(blocks, 2))) + ) +end + +# Cast to block-diagonal implementation if permuted-blockdiagonal +function try_to_blockdiagonal_perm(A) + inds = map(x -> Int.(Tuple(x)), vec(collect(block_stored_indices(A)))) + I = first.(inds) + allunique(I) || return nothing + J = last.(inds) + p = sortperm(J) + Jsorted = J[p] + allunique(Jsorted) || return nothing + return Block.(I[p], Jsorted) +end + +""" + try_to_blockdiagonal(A) + +Attempt to find a permutation of blocks that makes `A` blockdiagonal. If unsuccesful, +returns nothing, otherwise returns both the blockdiagonal `B` as well as the permutation `I, J`. +""" +function try_to_blockdiagonal(A::AbstractBlockSparseMatrix) + perm = try_to_blockdiagonal_perm(A) + isnothing(perm) && return perm + I = first.(Tuple.(perm)) + J = last.(Tuple.(perm)) + diagblocks = map(invperm(I), J) do i, j + return A[Block(i, j)] + end + return BlockDiagonal(diagblocks), perm +end + +# SVD implementation +function eigencopy_oftype(A::BlockDiagonal, S) + diag = map(Base.Fix2(eigencopy_oftype, S), A.blocks.diag) + return BlockDiagonal(diag) +end + +function svd(A::BlockDiagonal; kwargs...) + return svd!(eigencopy_oftype(A, LinearAlgebra.eigtype(eltype(A))); kwargs...) +end +function svd!(A::BlockDiagonal; full::Bool=false, alg::Algorithm=default_svd_alg(A)) + # TODO: handle full + F = map(a -> svd!(a; full, alg), blocks(A).diag) + Us = map(Base.Fix2(getproperty, :U), F) + Ss = map(Base.Fix2(getproperty, :S), F) + Vts = map(Base.Fix2(getproperty, :Vt), F) + return SVD(BlockDiagonal(Us), mortar(Ss), BlockDiagonal(Vts)) +end \ No newline at end of file diff --git a/NDTensors/src/lib/BlockSparseArrays/src/factorizations/svd.jl b/NDTensors/src/lib/BlockSparseArrays/src/factorizations/svd.jl new file mode 100644 index 0000000000..feba45ca35 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/factorizations/svd.jl @@ -0,0 +1,206 @@ +using LinearAlgebra: + LinearAlgebra, Factorization, Algorithm, default_svd_alg, Adjoint, Transpose +using BlockArrays: AbstractBlockMatrix, BlockedArray, BlockedMatrix, BlockedVector +using BlockArrays: BlockLayout + +# Singular Value Decomposition: +# need new type to deal with U and V having possible different types +# this is basically a carbon copy of the LinearAlgebra implementation. +# additionally, by default we implement a fallback to the LinearAlgebra implementation +# in hope to support as many foreign types as possible that chose to extend those methods. + +# TODO: add this to MatrixFactorizations +# TODO: decide where this goes +# TODO: decide whether or not to restrict types to be blocked. +""" + SVD <: Factorization + +Matrix factorization type of the singular value decomposition (SVD) of a matrix `A`. +This is the return type of [`svd(_)`](@ref), the corresponding matrix factorization function. + +If `F::SVD` is the factorization object, `U`, `S`, `V` and `Vt` can be obtained +via `F.U`, `F.S`, `F.V` and `F.Vt`, such that `A = U * Diagonal(S) * Vt`. +The singular values in `S` are sorted in descending order. + +Iterating the decomposition produces the components `U`, `S`, and `V`. + +# Examples +```jldoctest +julia> A = [1. 0. 0. 0. 2.; 0. 0. 3. 0. 0.; 0. 0. 0. 0. 0.; 0. 2. 0. 0. 0.] +4×5 Matrix{Float64}: + 1.0 0.0 0.0 0.0 2.0 + 0.0 0.0 3.0 0.0 0.0 + 0.0 0.0 0.0 0.0 0.0 + 0.0 2.0 0.0 0.0 0.0 + +julia> F = svd(A) +SVD{Float64, Float64, Matrix{Float64}, Vector{Float64}} +U factor: +4×4 Matrix{Float64}: + 0.0 1.0 0.0 0.0 + 1.0 0.0 0.0 0.0 + 0.0 0.0 0.0 1.0 + 0.0 0.0 -1.0 0.0 +singular values: +4-element Vector{Float64}: + 3.0 + 2.23606797749979 + 2.0 + 0.0 +Vt factor: +4×5 Matrix{Float64}: + -0.0 0.0 1.0 -0.0 0.0 + 0.447214 0.0 0.0 0.0 0.894427 + 0.0 -1.0 0.0 0.0 0.0 + 0.0 0.0 0.0 1.0 0.0 + +julia> F.U * Diagonal(F.S) * F.Vt +4×5 Matrix{Float64}: + 1.0 0.0 0.0 0.0 2.0 + 0.0 0.0 3.0 0.0 0.0 + 0.0 0.0 0.0 0.0 0.0 + 0.0 2.0 0.0 0.0 0.0 + +julia> u, s, v = F; # destructuring via iteration + +julia> u == F.U && s == F.S && v == F.V +true +``` +""" +struct SVD{T,Tr,M<:AbstractArray{T},C<:AbstractVector{Tr},N<:AbstractArray{T}} <: + Factorization{T} + U::M + S::C + Vt::N + function SVD{T,Tr,M,C,N}( + U, S, Vt + ) where {T,Tr,M<:AbstractArray{T},C<:AbstractVector{Tr},N<:AbstractArray{T}} + Base.require_one_based_indexing(U, S, Vt) + return new{T,Tr,M,C,N}(U, S, Vt) + end +end +function SVD(U::AbstractArray{T}, S::AbstractVector{Tr}, Vt::AbstractArray{T}) where {T,Tr} + return SVD{T,Tr,typeof(U),typeof(S),typeof(Vt)}(U, S, Vt) +end +function SVD{T}(U::AbstractArray, S::AbstractVector{Tr}, Vt::AbstractArray) where {T,Tr} + return SVD( + convert(AbstractArray{T}, U), + convert(AbstractVector{Tr}, S), + convert(AbstractArray{T}, Vt), + ) +end + +function SVD{T}(F::SVD) where {T} + return SVD( + convert(AbstractMatrix{T}, F.U), + convert(AbstractVector{real(T)}, F.S), + convert(AbstractMatrix{T}, F.Vt), + ) +end +LinearAlgebra.Factorization{T}(F::SVD) where {T} = SVD{T}(F) + +# iteration for destructuring into components +Base.iterate(S::SVD) = (S.U, Val(:S)) +Base.iterate(S::SVD, ::Val{:S}) = (S.S, Val(:V)) +Base.iterate(S::SVD, ::Val{:V}) = (S.V, Val(:done)) +Base.iterate(::SVD, ::Val{:done}) = nothing + +function Base.getproperty(F::SVD, d::Symbol) + if d === :V + return getfield(F, :Vt)' + else + return getfield(F, d) + end +end + +function Base.propertynames(F::SVD, private::Bool=false) + return private ? (:V, fieldnames(typeof(F))...) : (:U, :S, :V, :Vt) +end + +Base.size(A::SVD, dim::Integer) = dim == 1 ? size(A.U, dim) : size(A.Vt, dim) +Base.size(A::SVD) = (size(A, 1), size(A, 2)) + +function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, F::SVD) + summary(io, F) + println(io) + println(io, "U factor:") + show(io, mime, F.U) + println(io, "\nsingular values:") + show(io, mime, F.S) + println(io, "\nVt factor:") + return show(io, mime, F.Vt) +end + +Base.adjoint(usv::SVD) = SVD(adjoint(usv.Vt), usv.S, adjoint(usv.U)) +Base.transpose(usv::SVD) = SVD(transpose(usv.Vt), usv.S, transpose(usv.U)) + +# Conversion +Base.AbstractMatrix(F::SVD) = (F.U * Diagonal(F.S)) * F.Vt +Base.AbstractArray(F::SVD) = AbstractMatrix(F) +Base.Matrix(F::SVD) = Array(AbstractArray(F)) +Base.Array(F::SVD) = Matrix(F) +SVD(usv::SVD) = usv +SVD(usv::LinearAlgebra.SVD) = SVD(usv.U, usv.S, usv.Vt) + +# functions default to LinearAlgebra +# ---------------------------------- +""" + svd!(A; full::Bool = false, alg::Algorithm = default_svd_alg(A)) -> SVD + +`svd!` is the same as [`svd`](@ref), but saves space by +overwriting the input `A`, instead of creating a copy. See documentation of [`svd`](@ref) for details. +""" +svd!(A; kwargs...) = SVD(LinearAlgebra.svd!(A; kwargs...)) + +""" + svd(A; full::Bool = false, alg::Algorithm = default_svd_alg(A)) -> SVD + +Compute the singular value decomposition (SVD) of `A` and return an `SVD` object. + +`U`, `S`, `V` and `Vt` can be obtained from the factorization `F` with `F.U`, +`F.S`, `F.V` and `F.Vt`, such that `A = U * Diagonal(S) * Vt`. +The algorithm produces `Vt` and hence `Vt` is more efficient to extract than `V`. +The singular values in `S` are sorted in descending order. + +Iterating the decomposition produces the components `U`, `S`, and `V`. + +If `full = false` (default), a "thin" SVD is returned. For an ``M +\\times N`` matrix `A`, in the full factorization `U` is ``M \\times M`` +and `V` is ``N \\times N``, while in the thin factorization `U` is ``M +\\times K`` and `V` is ``N \\times K``, where ``K = \\min(M,N)`` is the +number of singular values. + +`alg` specifies which algorithm and LAPACK method to use for SVD: +- `alg = DivideAndConquer()` (default): Calls `LAPACK.gesdd!`. +- `alg = QRIteration()`: Calls `LAPACK.gesvd!` (typically slower but more accurate) . + +!!! compat "Julia 1.3" + The `alg` keyword argument requires Julia 1.3 or later. + +# Examples +```jldoctest +julia> A = rand(4,3); + +julia> F = svd(A); # Store the Factorization Object + +julia> A ≈ F.U * Diagonal(F.S) * F.Vt +true + +julia> U, S, V = F; # destructuring via iteration + +julia> A ≈ U * Diagonal(S) * V' +true + +julia> Uonly, = svd(A); # Store U only + +julia> Uonly == U +true +``` +""" +svd(A; kwargs...) = + SVD(svd!(eigencopy_oftype(A, LinearAlgebra.eigtype(eltype(A))); kwargs...)) + +LinearAlgebra.svdvals(usv::SVD{<:Any,T}) where {T} = (usv.S)::Vector{T} + +# Added here to avoid type-piracy +eigencopy_oftype(A, S) = LinearAlgebra.eigencopy_oftype(A, S) \ No newline at end of file diff --git a/NDTensors/src/lib/BlockSparseArrays/src/factorizations/tsvd.jl b/NDTensors/src/lib/BlockSparseArrays/src/factorizations/tsvd.jl new file mode 100644 index 0000000000..b4c3251f54 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/factorizations/tsvd.jl @@ -0,0 +1,104 @@ +# Truncation schemes +# ------------------ +""" + TuncationScheme + +Abstract supertype for all truncated factorization schemes. +See also [`notrunc`](@ref), [`truncdim`](@ref) and [`truncbelow`](@ref). +""" +abstract type TruncationScheme end + +""" + NoTruncation <: TruncationScheme + notrunc() + +Truncation algorithm that represents no truncation. See also [`notrunc`](@ref) for easily +constructing instances of this type. +""" +struct NoTruncation <: TruncationScheme end +notrunc() = NoTruncation() + +""" + TruncAmount <: TruncationScheme + truncdim(num::Int; by=identity, lt=isless, rev=true) + +Truncation algorithm that truncates by keeping the first `num` singular values, sorted using +the kwargs `by`, `lt` and `rev` as passed to the `Base.sort` algorithm. +""" +struct TruncAmount{B,L} <: TruncationScheme + num::Int + by::B + lt::L + rev::Bool +end +truncdim(n::Int; by=identity, lt=isless, rev=true) = TruncAmount(n, by, lt, rev) + +""" + TruncFilter <: TruncationScheme + truncbelow(ϵ::Real) + +Truncation algorithm that truncates by filter, where `truncbelow` filters all values below a threshold `ϵ`. +""" +struct TruncFilter{F} <: TruncationScheme + f::F +end +function truncbelow(ϵ::Real) + @assert ϵ ≥ zero(ϵ) + return TruncFilter(≥(ϵ)) +end + +""" + truncate(F::Factorization; trunc::TruncationScheme) -> Factorization + +Truncate a factorization using the given truncation algorithm: +- `trunc = notrunc()` (default): Does nothing. +- `trunc = truncdim(n)`: Keeps the largest `n` values. +- `trunc = truncbelow(ϵ)`: Truncates all values below a threshold `ϵ`. +""" +truncate(F::SVD; trunc::TruncationScheme=notrunc()) = _truncate(F, trunc) + +# use _truncate to dispatch on `trunc` +_truncate(usv, ::NoTruncation) = usv + +# note: kept implementations separate for possible future ambiguity reasons +function _truncate(usv::SVD, trunc::TruncAmount) + keep = select_values(usv.S, trunc) + return SVD(usv.U[:, keep], usv.S[keep], usv.Vt[keep, :]) +end +function _truncate(usv::SVD, trunc::TruncFilter) + keep = select_values(usv.S, trunc) + return SVD(usv.U[:, keep], usv.S[keep], usv.Vt[keep, :]) +end + +function select_values(S, trunc::TruncAmount) + return partialsortperm(S, 1:(trunc.num); trunc.lt, trunc.by, trunc.rev) +end +select_values(S, trunc::TruncFilter) = findall(trunc.f, S) + +# For convenience, also add a method to both truncate and decompose +""" + tsvd(A; full::Bool=false, alg=default_svd_alg(A), trunc=notrunc()) + +Compute the truncated singular value decomposition (SVD) of `A`. +This is typically achieved by first computing the full SVD, followed by a filtering based on +the computed singular values. +""" +function tsvd(A; kwargs...) + return tsvd!(eigencopy_oftype(A, LinearAlgebra.eigtype(eltype(A))); kwargs...) +end + +""" + tsvd!(A; full::Bool=false, alg=default_svd_alg(A), trunc=notrunc()) + +Compute the truncated singular value decomposition (SVD) of `A`, saving space by +overwriting `A` in the process. See documentation of [`tsvd`](@ref) for details. +""" +function tsvd!(A; full::Bool=false, alg=default_svd_alg(A), trunc=notrunc()) + return _tsvd!(A, alg, trunc, full) +end + +# default implementation simply dispatches through to `svd` and `truncate`. +function _tsvd!(A, alg, trunc, full) + F = svd!(A; alg, full) + return truncate(F; trunc) +end \ No newline at end of file diff --git a/NDTensors/src/lib/BlockSparseArrays/test/indexing.jl b/NDTensors/src/lib/BlockSparseArrays/test/indexing.jl new file mode 100644 index 0000000000..2e7fa21fd4 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/test/indexing.jl @@ -0,0 +1,56 @@ +using BlockArrays: + Block, + BlockIndexRange, + BlockRange, + BlockSlice, + BlockVector, + BlockedOneTo, + BlockedUnitRange, + BlockedVector, + blockedrange, + blocklength, + blocklengths, + blocksize, + blocksizes, + blockaxes, + mortar +using LinearAlgebra: Adjoint, mul!, norm +using NDTensors.BlockSparseArrays: + @view!, + BlockSparseArray, + BlockView, + block_nstored, + block_reshape, + block_stored_indices, + view! +using NDTensors.SparseArrayInterface: nstored +using NDTensors.TensorAlgebra: contract + +using Test + +T = Float64 + +# scalar indexing +a = BlockSparseArray{T}([2, 3], [2, 2]) +for i in blockaxes(a, 1), j in blockaxes(a, 2) + a[i, j] = randn(T, blocksizes(a)[Int(i), Int(j)]) +end +a + +a[1, 2] +a[Block(1, 1)] +a[Block.(1:2), Block(1)] +aslice = a[Block.(1:2), 1] +axes(aslice, 1) +axes(aslice, 2) +length(axes(aslice)) == ndims(aslice) + +aslice = a[Block.(1:2), 1:3] +axes(aslice) + +mask = trues(size(a, 2)) +aslice = a[:, mask] +aslice = a[:, [1, 2]] + +a[Block(1, 1)] = randn(T, 2, 2) +a[Block(2, 2)] = randn(T, 2, 2) diff --git a/NDTensors/src/lib/BlockSparseArrays/test/svd_tests.jl b/NDTensors/src/lib/BlockSparseArrays/test/svd_tests.jl new file mode 100644 index 0000000000..74453dcd7a --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/test/svd_tests.jl @@ -0,0 +1,164 @@ +using Test +using NDTensors.BlockSparseArrays +using NDTensors.BlockSparseArrays: + BlockSparseArray, svd, tsvd, notrunc, truncbelow, truncdim, BlockDiagonal +using BlockArrays +using LinearAlgebra: LinearAlgebra, Diagonal, svdvals + +function test_svd(a, usv) + U, S, V = usv + + @test U * Diagonal(S) * V' ≈ a + @test U' * U ≈ LinearAlgebra.I + @test V' * V ≈ LinearAlgebra.I +end + +# regular matrix +# -------------- +sizes = ((3, 3), (4, 3), (3, 4)) +eltypes = (Float32, Float64, ComplexF64) +@testset "($m, $n) Matrix{$T}" for ((m, n), T) in Iterators.product(sizes, eltypes) + a = rand(m, n) + usv = @inferred svd(a) + test_svd(a, usv) + + # TODO: type unstable? + usv2 = tsvd(a) + test_svd(a, usv2) + + usv3 = tsvd(a; trunc=truncdim(2)) + @test length(usv3.S) == 2 + @test usv3.U' * usv3.U ≈ LinearAlgebra.I + @test usv3.Vt * usv3.V ≈ LinearAlgebra.I + + s = usv3.S[end] + usv4 = tsvd(a; trunc=truncbelow(s)) + @test length(usv4.S) == 2 + @test usv4.U' * usv4.U ≈ LinearAlgebra.I + @test usv4.Vt * usv4.V ≈ LinearAlgebra.I +end + +# block matrix +# ------------ +blockszs = (([2, 2], [2, 2]), ([2, 2], [2, 3]), ([2, 2, 1], [2, 3]), ([2, 3], [2])) +@testset "($m, $n) BlockMatrix{$T}" for ((m, n), T) in Iterators.product(blockszs, eltypes) + a = mortar([rand(T, i, j) for i in m, j in n]) + usv = svd(a) + test_svd(a, usv) + @test usv.U isa BlockedMatrix + @test usv.Vt isa BlockedMatrix + @test usv.S isa BlockedVector + + usv2 = tsvd(a) + test_svd(a, usv2) + @test usv.U isa BlockedMatrix + @test usv.Vt isa BlockedMatrix + @test usv.S isa BlockedVector + + usv3 = tsvd(a; trunc=truncdim(2)) + @test length(usv3.S) == 2 + @test usv3.U' * usv3.U ≈ LinearAlgebra.I + @test usv3.Vt * usv3.V ≈ LinearAlgebra.I + @test usv.U isa BlockedMatrix + @test usv.Vt isa BlockedMatrix + @test usv.S isa BlockedVector + + s = usv3.S[end] + usv4 = tsvd(a; trunc=truncbelow(s)) + @test length(usv4.S) == 2 + @test usv4.U' * usv4.U ≈ LinearAlgebra.I + @test usv4.Vt * usv4.V ≈ LinearAlgebra.I + @test usv.U isa BlockedMatrix + @test usv.Vt isa BlockedMatrix + @test usv.S isa BlockedVector +end + +# Block-Diagonal matrices +# ----------------------- +@testset "($m, $n) BlockDiagonal{$T}" for ((m, n), T) in + Iterators.product(blockszs, eltypes) + a = BlockDiagonal([rand(T, i, j) for (i, j) in zip(m, n)]) + usv = svd(a) + # TODO: `BlockDiagonal * Adjoint` errors + test_svd(a, usv) + @test usv.U isa BlockDiagonal + @test usv.Vt isa BlockDiagonal + @test usv.S isa BlockVector + + usv2 = tsvd(a) + test_svd(a, usv2) + @test usv.U isa BlockDiagonal + @test usv.Vt isa BlockDiagonal + @test usv.S isa BlockVector + + # TODO: need to find a slicing fix to make this work + # usv3 = tsvd(a; trunc=truncdim(2)) + # @test length(usv3.S) == 2 + # @test usv3.U' * usv3.U ≈ LinearAlgebra.I + # @test usv3.Vt * usv3.V ≈ LinearAlgebra.I + # @test usv.U isa BlockDiagonal + # @test usv.Vt isa BlockDiagonal + # @test usv.S isa BlockVector + + # @show s = usv3.S[end] + # usv4 = tsvd(a; trunc=truncbelow(s)) + # @test length(usv4.S) == 2 + # @test usv4.U' * usv4.U ≈ LinearAlgebra.I + # @test usv4.Vt * usv4.V ≈ LinearAlgebra.I + # @test usv.U isa BlockDiagonal + # @test usv.Vt isa BlockDiagonal + # @test usv.S isa BlockVector +end + +a = mortar([rand(2, 2) for i in 1:2, j in 1:3]) +usv = svd(a) +test_svd(a, usv) + +a = mortar([rand(2, 2) for i in 1:3, j in 1:2]) +usv = svd(a) +test_svd(a, usv) + +# blocksparse +# ----------- +@testset "($m, $n) BlockDiagonal{$T}" for ((m, n), T) in + Iterators.product(blockszs, eltypes) + a = BlockSparseArray{T}(m, n) + for i in LinearAlgebra.diagind(blocks(a)) + I = CartesianIndices(blocks(a))[i] + a[Block(I.I...)] = rand(T, size(blocks(a)[i])) + end + perm = Random.randperm(length(m)) + a = a[Block.(perm), Block.(1:length(n))] + + # errors because `blocks(a)[CartesianIndex.(...)]` is not implemented + usv = svd(a) + # TODO: `BlockDiagonal * Adjoint` errors + # test_svd(a, usv) + @test usv.U isa BlockDiagonal + @test usv.Vt isa BlockDiagonal + @test usv.S isa BlockVector + + # usv2 = tsvd(a) + test_svd(a, usv2) + @test usv.U isa BlockDiagonal + @test usv.Vt isa BlockDiagonal + @test usv.S isa BlockVector + + # TODO: need to find a slicing fix to make this work + # usv3 = tsvd(a; trunc=truncdim(2)) + # @test length(usv3.S) == 2 + # @test usv3.U' * usv3.U ≈ LinearAlgebra.I + # @test usv3.Vt * usv3.V ≈ LinearAlgebra.I + # @test usv.U isa BlockDiagonal + # @test usv.Vt isa BlockDiagonal + # @test usv.S isa BlockVector + + # @show s = usv3.S[end] + # usv4 = tsvd(a; trunc=truncbelow(s)) + # @test length(usv4.S) == 2 + # @test usv4.U' * usv4.U ≈ LinearAlgebra.I + # @test usv4.Vt * usv4.V ≈ LinearAlgebra.I + # @test usv.U isa BlockDiagonal + # @test usv.Vt isa BlockDiagonal + # @test usv.S isa BlockVector +end \ No newline at end of file