Introduce KhatriRaoMap and FaceSplittingMap (#191)
dkarrasch authored Oct 12, 2022
1 parent bd10c6c commit 50f5bf9
Showing 8 changed files with 163 additions and 1 deletion.
13 changes: 13 additions & 0 deletions docs/src/history.md
@@ -7,6 +7,19 @@
such as [`Flux.jl`](https://fluxml.ai/Flux.jl/stable/). The reverse differentiation rule
makes `A::LinearMap` usable as a static, i.e., non-trainable, layer in a network, and
requires the adjoint `A'` of `A` to be defined.
* New map types called `KhatriRaoMap` and `FaceSplittingMap` are introduced. These
correspond to lazy representations of the [column-wise Kronecker product](https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Column-wise_Kronecker_product)
and the [row-wise Kronecker product](https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Face-splitting_product)
(or "transposed Khatri-Rao product"), respectively. They can be constructed from two
matrices `A` and `B` via `khatrirao(A, B)` and `facesplitting(A, B)`, respectively.
The first is particularly efficient as it makes use of the vec-trick for Kronecker
products and computes `y = khatrirao(A, B) * x` for a vector `x` as
`y = vec(B * Diagonal(x) * transpose(A))`. As such, the Khatri-Rao product can actually
be built for general `LinearMap`s, including function-based types. Even for moderate
sizes of 5 or more columns, this map-vector product is faster than first creating the
explicit Khatri-Rao product in memory and then multiplying with the vector; not to
mention the memory savings. Unfortunately, similar efficiency cannot be achieved for the
face-splitting product.
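
The vec-trick identity quoted above can be checked numerically with the standard library alone. The following is a minimal sketch, not the package's implementation: the matrices `A`, `B` and the vector `x` are small illustrative values chosen here, and the explicit product `K` is built only for comparison.

```julia
using LinearAlgebra

# Illustrative inputs (assumed for this sketch); A and B share the column count.
A = [1.0 2.0; 3.0 4.0; 5.0 6.0]   # 3×2
B = [1.0 0.5; -1.0 2.0]           # 2×2
x = [2.0, -3.0]

# Explicit Khatri-Rao product: column j equals kron(A[:, j], B[:, j]).
K = mapreduce(kron, hcat, eachcol(A), eachcol(B))   # 6×2

# Vec-trick: the same product-vector result without ever forming K.
y = vec(B * Diagonal(x) * transpose(A))

@assert K * x ≈ y
```

The identity holds because `kron(a, b) == vec(b * transpose(a))` for vectors, so the weighted sum of columns of `K` collapses into a single matrix product.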

## What's new in v3.8

12 changes: 12 additions & 0 deletions docs/src/types.md
@@ -115,6 +115,18 @@ Type for lazy inverse of another linear map.
LinearMaps.InverseMap
```

### `KhatriRaoMap` and `FaceSplittingMap`

Types for the lazy [column-wise](https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Column-wise_Kronecker_product)
and [row-wise](https://en.wikipedia.org/wiki/Khatri%E2%80%93Rao_product#Face-splitting_product)
Kronecker products, respectively, also referred to
as the Khatri-Rao and the transposed Khatri-Rao (or face-splitting) products.

```@docs
khatrirao
facesplitting
```
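
A short usage sketch for the two constructors follows. It assumes a LinearMaps.jl version that includes this commit; the random matrices and their sizes are illustrative choices, and `LinearMaps.KhatriRaoMap` is referenced by its qualified name because only `khatrirao` and `facesplitting` are exported.

```julia
using LinearMaps, LinearAlgebra

# Khatri-Rao product: both factors need the same number of columns.
A = rand(3, 4)
B = rand(5, 4)
K = khatrirao(A, B)                     # lazy map of size (3*5, 4)
x = rand(4)
@assert size(K) == (15, 4)
@assert K * x ≈ vec(B * Diagonal(x) * transpose(A))

# Face-splitting product: both factors need the same number of rows.
C = rand(4, 2)
D = rand(4, 3)
F = facesplitting(C, D)                 # lazy map of size (4, 2*3)
@assert size(F) == (4, 6)

# Taking the adjoint switches between the two map types.
@assert F' isa LinearMaps.KhatriRaoMap
```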

## Methods

### Multiplication methods
3 changes: 2 additions & 1 deletion src/LinearMaps.jl
@@ -1,7 +1,7 @@
module LinearMaps

export LinearMap
-export ⊗, squarekron, kronsum, ⊕, sumkronsum
+export ⊗, squarekron, kronsum, ⊕, sumkronsum, khatrirao, facesplitting
export FillMap
export InverseMap

@@ -344,6 +344,7 @@ include("composition.jl") # composition of linear maps
include("functionmap.jl") # using a function as linear map
include("blockmap.jl") # block linear maps
include("kronecker.jl") # Kronecker product of linear maps
include("khatrirao.jl") # Khatri-Rao and face-splitting products
include("fillmap.jl") # linear maps representing constantly filled matrices
include("embeddedmap.jl") # embedded linear maps
include("conversion.jl") # conversion of linear maps to matrices
99 changes: 99 additions & 0 deletions src/khatrirao.jl
@@ -0,0 +1,99 @@
struct KhatriRaoMap{T,A<:Tuple{MapOrVecOrMat,MapOrVecOrMat}} <: LinearMap{T}
    maps::A
    function KhatriRaoMap{T,As}(maps::As) where {T,As<:Tuple{MapOrVecOrMat,MapOrVecOrMat}}
        # report the eltypes of the passed maps; `A` is not defined inside this constructor
        @assert promote_type(T, map(eltype, maps)...) == T "eltypes $(map(eltype, maps)) cannot be promoted to $T in KhatriRaoMap constructor"
        @inbounds size(maps[1], 2) == size(maps[2], 2) || throw(ArgumentError("matrices need equal number of columns"))
        new{T,As}(maps)
    end
end
KhatriRaoMap{T}(maps::As) where {T, As} = KhatriRaoMap{T, As}(maps)

"""
    khatrirao(A::MapOrVecOrMat, B::MapOrVecOrMat) -> KhatriRaoMap

Construct a lazy representation of the Khatri-Rao (or column-wise Kronecker) product of two
maps or arrays `A` and `B`. For the application to vectors, the transpose action of `A` on
vectors needs to be defined.
"""
khatrirao(A::MapOrVecOrMat, B::MapOrVecOrMat) =
    KhatriRaoMap{Base.promote_op(*, eltype(A), eltype(B))}((A, B))

struct FaceSplittingMap{T,A<:Tuple{AbstractMatrix,AbstractMatrix}} <: LinearMap{T}
    maps::A
    function FaceSplittingMap{T,As}(maps::As) where {T,As<:Tuple{AbstractMatrix,AbstractMatrix}}
        # report the eltypes of the passed maps; `A` is not defined inside this constructor
        @assert promote_type(T, map(eltype, maps)...) == T "eltypes $(map(eltype, maps)) cannot be promoted to $T in FaceSplittingMap constructor"
        # the face-splitting product requires matching row counts (first dimension)
        @inbounds size(maps[1], 1) == size(maps[2], 1) || throw(ArgumentError("matrices need equal number of rows, got $(size(maps[1], 1)) and $(size(maps[2], 1))"))
        new{T,As}(maps)
    end
end
FaceSplittingMap{T}(maps::As) where {T, As} = FaceSplittingMap{T, As}(maps)

"""
    facesplitting(A::AbstractMatrix, B::AbstractMatrix) -> FaceSplittingMap

Construct a lazy representation of the face-splitting (or row-wise Kronecker) product of
two matrices `A` and `B`.
"""
facesplitting(A::AbstractMatrix, B::AbstractMatrix) =
    FaceSplittingMap{Base.promote_op(*, eltype(A), eltype(B))}((A, B))

Base.size(K::KhatriRaoMap) = ((A, B) = K.maps; (size(A, 1) * size(B, 1), size(A, 2)))
Base.size(K::FaceSplittingMap) = ((A, B) = K.maps; (size(A, 1), size(A, 2) * size(B, 2)))
Base.adjoint(K::KhatriRaoMap) = facesplitting(map(adjoint, K.maps)...)
Base.adjoint(K::FaceSplittingMap) = khatrirao(map(adjoint, K.maps)...)
Base.transpose(K::KhatriRaoMap) = facesplitting(map(transpose, K.maps)...)
Base.transpose(K::FaceSplittingMap) = khatrirao(map(transpose, K.maps)...)

LinearMaps.MulStyle(::Union{KhatriRaoMap,FaceSplittingMap}) = FiveArg()

function _unsafe_mul!(y, K::KhatriRaoMap, x::AbstractVector)
    A, B = K.maps
    Y = reshape(y, (size(B, 1), size(A, 1)))
    if size(B, 1) <= size(A, 1)
        mul!(Y, convert(Matrix, B * Diagonal(x)), transpose(A))
    else
        mul!(Y, B, transpose(convert(Matrix, A * transpose(Diagonal(x)))))
    end
    return y
end
function _unsafe_mul!(y, K::KhatriRaoMap, x::AbstractVector, α, β)
    A, B = K.maps
    Y = reshape(y, (size(B, 1), size(A, 1)))
    if size(B, 1) <= size(A, 1)
        mul!(Y, convert(Matrix, B * Diagonal(x)), transpose(A), α, β)
    else
        mul!(Y, B, transpose(convert(Matrix, A * transpose(Diagonal(x)))), α, β)
    end
    return y
end

function _unsafe_mul!(y, K::FaceSplittingMap, x::AbstractVector)
    A, B = K.maps
    @inbounds for m in eachindex(y)
        y[m] = zero(eltype(y))
        l = firstindex(x)
        for i in axes(A, 2)
            ai = A[m,i]
            @simd for k in axes(B, 2)
                y[m] += ai*B[m,k]*x[l]
                l += 1
            end
        end
    end
    return y
end
function _unsafe_mul!(y, K::FaceSplittingMap, x::AbstractVector, α, β)
    A, B = K.maps
    @inbounds for m in eachindex(y)
        y[m] *= β
        l = firstindex(x)
        for i in axes(A, 2)
            ai = A[m,i]
            @simd for k in axes(B, 2)
                y[m] += ai*B[m,k]*x[l]*α
                l += 1
            end
        end
    end
    return y
end
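
The double loop in the `FaceSplittingMap` methods above accumulates, for each row `m`, the sum over `i` and `k` of `A[m,i] * B[m,k] * x[l]`, where `l` runs over the columns of the row-wise Kronecker product. The following standalone sketch, with illustrative inputs assumed here and no dependency on the package, checks that this matches both the explicit face-splitting matrix and the equivalent bilinear form with `x` reshaped into a matrix `X`.

```julia
using LinearAlgebra

# Illustrative inputs (assumed for this sketch); A and B share the row count.
A = [1.0 2.0; 3.0 4.0; 5.0 6.0]                  # 3×2
B = [1.0 0.0 2.0; -1.0 1.0 0.5; 2.0 3.0 1.0]     # 3×3
x = collect(1.0:6.0)                             # length 2*3

# Explicit face-splitting product: row m equals kron(A[m, :]', B[m, :]').
F = mapreduce((a, b) -> kron(permutedims(a), permutedims(b)), vcat,
              eachrow(A), eachrow(B))            # 3×6

# Per-row bilinear form: y[m] = B[m, :]ᵀ * X * A[m, :] with X = reshape(x, nB, nA).
X = reshape(x, size(B, 2), size(A, 2))
y = [transpose(B[m, :]) * (X * A[m, :]) for m in axes(A, 1)]

@assert F * x ≈ y
```

Each output entry needs its own pass over all of `x`, which is consistent with the release note's remark that the face-splitting product does not admit a shortcut comparable to the vec-trick.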
6 changes: 6 additions & 0 deletions src/kronecker.jl
@@ -209,6 +209,12 @@ function _unsafe_mul!(y, L::OuterProductMap, x::AbstractVector)
    a, bt = L.maps
    mul!(y, a.lmap, bt.lmap * x)
end
function _unsafe_mul!(y, L::KroneckerMap{<:Any,<:Tuple{VectorMap,VectorMap}}, x::AbstractVector)
    a, b = L.maps
    kron!(y, a.lmap, b.lmap)
    rmul!(y, first(x))
    return y
end
function _unsafe_mul!(y, L::KroneckerMap2, x::AbstractVector)
    require_one_based_indexing(y)
    A, B = L.maps
26 changes: 26 additions & 0 deletions test/khatrirao.jl
@@ -0,0 +1,26 @@
using Test, LinearMaps, LinearAlgebra

@testset "KhatriRaoMap & FaceSplittingMap" begin
    for trans in (identity, complex), m in (2, 4)
        A = collect(reshape(trans(1:6), 3, 2))
        B = collect(reshape(trans(1:2m), m, 2))
        K = @inferred khatrirao(A, B)
        @test facesplitting(A', B')' === K
        M = mapreduce(kron, hcat, eachcol(A), eachcol(B))
        Mx = mapreduce((a, b) -> kron(permutedims(a), permutedims(b)), vcat, eachrow(A'), eachrow(B'))
        @test size(K) == size(M)
        @test size(@inferred adjoint(K)) == reverse(size(K))
        @test size(@inferred transpose(K)) == reverse(size(K))
        @test Matrix(K) == M
        @test Matrix(K') == Mx
        @test LinearMaps.MulStyle(K) === LinearMaps.MulStyle(K') === LinearMaps.FiveArg()
        @test (K')' === K
        @test transpose(transpose(K)) === K
        x = trans(rand(-10:10, size(K, 2)))
        y = trans(rand(-10:10, size(K, 1)))
        for α in (false, true, trans(rand(2:5))), β in (false, true, trans(rand(2:5)))
            @test mul!(copy(y), K, x, α, β) == y * β + K * x * α
            @test mul!(copy(x), K', y, α, β) == x * β + K' * y * α
        end
    end
end
3 changes: 3 additions & 0 deletions test/kronecker.jl
@@ -37,6 +37,9 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
        @test Matrix(L) ≈ M
        @test L * v ≈ M * v
    end
    L = ones(3) ⊗ (b = rand(ComplexF64, 4))
    @test L * [2] ≈ kron(ones(3), b) * 2
    @test Matrix(L) ≈ kron(ones(3), b) rtol=2eps(Float64)
    L = ones(3) ⊗ ones(ComplexF64, 4)'
    v = rand(4)
    @test Matrix(L) == ones(3,4)
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -41,3 +41,5 @@ include("getindex.jl")
include("inversemap.jl")

include("rrules.jl")

include("khatrirao.jl")
