Skip to content

Commit

Permalink
Define mapfoldl/foldl for static arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Feb 29, 2020
1 parent c808bdd commit 621d50a
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 72 deletions.
2 changes: 1 addition & 1 deletion src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module StaticArrays
import Base: @_inline_meta, @_propagate_inbounds_meta, @_pure_meta, @propagate_inbounds, @pure

import Base: getindex, setindex!, size, similar, vec, show, length, convert, promote_op,
promote_rule, map, map!, reduce, mapreduce, broadcast,
promote_rule, map, map!, reduce, mapreduce, foldl, mapfoldl, broadcast,
broadcast!, conj, hcat, vcat, ones, zeros, one, reshape, fill, fill!, inv,
iszero, sum, prod, count, any, all, minimum, maximum, extrema,
copy, read, read!, write, reverse
Expand Down
145 changes: 74 additions & 71 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
"""
_InitialValue
A singleton type for representing "universal" initial value (identity element).
The idea is that, given `op` for `mapfoldl`, virtually, we define an "extended"
version of it by
op′(::_InitialValue, x) = x
op′(acc, x) = op(acc, x)
This is just a conceptually useful model to have in mind and we don't actually
define `op′` here (yet?). But see `Base.BottomRF` for how it might work in
action.
(It is related to that you can always turn a semigroup without an identity into
a monoid by "adjoining" an element that acts as the identity.)
"""
struct _InitialValue end

@inline _first(a1, as...) = a1

################
Expand Down Expand Up @@ -86,28 +106,21 @@ end
## mapreduce ##
###############

@inline function mapreduce(f, op, a::StaticArray, b::StaticArray...; dims=:,kw...)
_mapreduce(f, op, dims, kw.data, same_size(a, b...), a, b...)
@inline function mapreduce(f, op, a::StaticArray, b::StaticArray...; dims=:, init = _InitialValue())
_mapreduce(f, op, dims, init, same_size(a, b...), a, b...)
end

@generated function _mapreduce(f, op, dims::Colon, nt::NamedTuple{()},
::Size{S}, a::StaticArray...) where {S}
@inline _mapreduce(args::Vararg{Any,N}) where N = _mapfoldl(args...)

@generated function _mapfoldl(f, op, dims::Colon, init, ::Size{S}, a::StaticArray...) where {S}
tmp = [:(a[$j][1]) for j 1:length(a)]
expr = :(f($(tmp...)))
for i 2:prod(S)
tmp = [:(a[$j][$i]) for j 1:length(a)]
expr = :(op($expr, f($(tmp...))))
end
return quote
@_inline_meta
@inbounds return $expr
if init === _InitialValue
expr = :(Base.reduce_first(op, $expr))
else
expr = :(op(init, $expr))
end
end

@generated function _mapreduce(f, op, dims::Colon, nt::NamedTuple{(:init,)},
::Size{S}, a::StaticArray...) where {S}
expr = :(nt.init)
for i 1:prod(S)
for i 2:prod(S)
tmp = [:(a[$j][$i]) for j 1:length(a)]
expr = :(op($expr, f($(tmp...))))
end
Expand All @@ -117,24 +130,24 @@ end
end
end

@inline function _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S}
@inline function _mapreduce(f, op, D::Int, init, sz::Size{S}, a::StaticArray) where {S}
# Body of this function is split because constant propagation (at least
# as of Julia 1.2) can't always correctly propagate here and
# as a result the function is not type stable and very slow.
# This makes it at least fast for three dimensions but people should use
# for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
if D == 1
return _mapreduce(f, op, Val(1), nt, sz, a)
return _mapreduce(f, op, Val(1), init, sz, a)
elseif D == 2
return _mapreduce(f, op, Val(2), nt, sz, a)
return _mapreduce(f, op, Val(2), init, sz, a)
elseif D == 3
return _mapreduce(f, op, Val(3), nt, sz, a)
return _mapreduce(f, op, Val(3), init, sz, a)
else
return _mapreduce(f, op, Val(D), nt, sz, a)
return _mapreduce(f, op, Val(D), init, sz, a)
end
end

@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{()},
@generated function _mapfoldl(f, op, dims::Val{D}, init,
::Size{S}, a::StaticArray) where {S,D}
N = length(S)
Snew = ([n==D ? 1 : S[n] for n = 1:N]...,)
Expand All @@ -143,32 +156,12 @@ end
itr = [1:n for n Snew]
for i Base.product(itr...)
expr = :(f(a[$(i...)]))
for k = 2:S[D]
ik = collect(i)
ik[D] = k
expr = :(op($expr, f(a[$(ik...)])))
if init === _InitialValue
expr = :(Base.reduce_first(op, $expr))
else
expr = :(op(init, $expr))
end

exprs[i...] = expr
end

return quote
@_inline_meta
@inbounds elements = tuple($(exprs...))
@inbounds return similar_type(a, eltype(elements), Size($Snew))(elements)
end
end

@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{(:init,)},
::Size{S}, a::StaticArray) where {S,D}
N = length(S)
Snew = ([n==D ? 1 : S[n] for n = 1:N]...,)

exprs = Array{Expr}(undef, Snew)
itr = [1:n for n = Snew]
for i Base.product(itr...)
expr = :(nt.init)
for k = 1:S[D]
for k = 2:S[D]
ik = collect(i)
ik[D] = k
expr = :(op($expr, f(a[$(ik...)])))
Expand All @@ -188,20 +181,33 @@ end
## reduce ##
############

@inline reduce(op, a::StaticArray; dims=:, kw...) = _reduce(op, a, dims, kw.data)
@inline reduce(op, a::StaticArray; dims = :, init = _InitialValue()) =
_reduce(op, a, dims, init)

# disambiguation
reduce(::typeof(vcat), A::StaticArray{<:Tuple,<:AbstractVecOrMat}) =
Base._typed_vcat(mapreduce(eltype, promote_type, A), A)
reduce(::typeof(vcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
_reduce(vcat, A, :, NamedTuple())
_reduce(vcat, A, :, _InitialValue())

reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:AbstractVecOrMat}) =
Base._typed_hcat(mapreduce(eltype, promote_type, A), A)
reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
_reduce(hcat, A, :, NamedTuple())
_reduce(hcat, A, :, _InitialValue())

@inline _reduce(op, a::StaticArray, dims, kw::NamedTuple=NamedTuple()) = _mapreduce(identity, op, dims, kw, Size(a), a)
@inline _reduce(op, a::StaticArray, dims, init = _InitialValue()) =
_mapreduce(identity, op, dims, init, Size(a), a)

################
## (map)foldl ##
################

@inline mapfoldl(f, op::R, a::StaticArray; init = _InitialValue()) where {R} =
_mapfoldl(f, op, :, init, Size(a), a)
@inline foldl(op::R, a::StaticArray; init = _InitialValue()) where {R} =
_foldl(op, a, :, init)
@inline _foldl(op::R, a, dims, init = _InitialValue()) where {R} =
_mapfoldl(identity, op, dims, init, Size(a), a)

#######################
## related functions ##
Expand All @@ -227,37 +233,37 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
@inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true)

@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(+, a, dims)
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a)
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) # avoid ambiguity
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a)
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) # avoid ambiguity

@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(*, a, dims)
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, _InitialValue(), Size(a), a)
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, _InitialValue(), Size(a), a)

@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(+, a, dims)
@inline count(f, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, +, dims, NamedTuple(), Size(a), a)
@inline count(f, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, +, dims, _InitialValue(), Size(a), a)

@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, (init=true,)) # non-branching versions
@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, (init=true,), Size(a), a)
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, true) # non-branching versions
@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, true, Size(a), a)

@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, (init=false,)) # (benchmarking needed)
@inline any(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, (init=false,), Size(a), a) # (benchmarking needed)
@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, false) # (benchmarking needed)
@inline any(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, false, Size(a), a) # (benchmarking needed)

@inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, (init=false,), Size(a), a)
@inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, false, Size(a), a)

_mean_denom(a, dims::Colon) = length(a)
_mean_denom(a, dims::Int) = size(a, dims)
_mean_denom(a, ::Val{D}) where {D} = size(a, D)
_mean_denom(a, ::Type{Val{D}}) where {D} = size(a, D)

@inline mean(a::StaticArray; dims=:) = _reduce(+, a, dims) / _mean_denom(a, dims)
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) / _mean_denom(a, dims)
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) / _mean_denom(a, dims)

@inline minimum(a::StaticArray; dims=:) = _reduce(min, a, dims) # base has mapreduce(idenity, scalarmin, a)
@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, NamedTuple(), Size(a), a)
@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, _InitialValue(), Size(a), a)

@inline maximum(a::StaticArray; dims=:) = _reduce(max, a, dims) # base has mapreduce(idenity, scalarmax, a)
@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, NamedTuple(), Size(a), a)
@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, _InitialValue(), Size(a), a)

# Diff is slightly different
@inline diff(a::StaticArray; dims) = _diff(Size(a), a, dims)
Expand Down Expand Up @@ -286,8 +292,6 @@ end
end
end

struct _InitialValue end

_maybe_val(dims::Integer) = Val(Int(dims))
_maybe_val(dims) = dims
_valof(::Val{D}) where D = D
Expand All @@ -299,19 +303,18 @@ _valof(::Val{D}) where D = D
_accumulate(op, a, _maybe_val(dims), init)

@inline function _accumulate(op::F, a::StaticArray, dims::Union{Val,Colon}, init) where {F}
# Adjoin the initial value to `op`:
# Adjoin the initial value to `op` (one-line version of `Base.BottomRF`):
rf(x, y) = x isa _InitialValue ? Base.reduce_first(op, y) : op(x, y)

if isempty(a)
T = return_type(rf, Tuple{typeof(init), eltype(a)})
return similar_type(a, T)()
end

# StaticArrays' `reduce` is `foldl`:
results = _reduce(
results = _foldl(
a,
dims,
(init = (similar_type(a, Union{}, Size(0))(), init),),
(similar_type(a, Union{}, Size(0))(), init),
) do (ys, acc), x
y = rf(acc, x)
# Not using `push(ys, y)` here since we need to widen element type as
Expand Down
9 changes: 9 additions & 0 deletions test/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ using Statistics: mean
@test mapreduce(x->x^2, max, sa; dims=2, init=-1.) == SMatrix{I,1}(mapreduce(x->x^2, max, a, dims=2, init=-1.))
end

@testset "[map]foldl" begin
a = rand(4,3)
v1 = [2,4,6,8]; sv1 = SVector{4}(v1)
@test foldl(+, sv1) === foldl(+, v1)
@test foldl(+, sv1; init=0) === foldl(+, v1; init=0)
@test mapfoldl(-, +, sv1) === mapfoldl(-, +, v1)
@test mapfoldl(-, +, sv1; init=0) === mapfoldl(-, +, v1, init=0)
end

@testset "implemented by [map]reduce and [map]reducedim" begin
I, J, K = 2, 2, 2
OSArray = SArray{Tuple{I,J,K}} # original
Expand Down

0 comments on commit 621d50a

Please sign in to comment.