From a29e523febfe8b6ecbebfb0f651fd757e8a1a8eb Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Tue, 5 Oct 2021 19:53:09 +0200 Subject: [PATCH 1/7] Make `view(::AbstractWeights, ...)` return an `AbstractWeights` This is necessary to preserve the information regarding the type of weights. --- src/weights.jl | 6 ++++++ test/weights.jl | 13 ++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/weights.jl b/src/weights.jl index 34fe4cd77..111f19a55 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -37,6 +37,12 @@ end Base.getindex(wv::W, ::Colon) where {W <: AbstractWeights} = W(copy(wv.values), sum(wv)) +@propagate_inbounds function Base.view(wv::W, inds...) where {W <: AbstractWeights} + @boundscheck checkbounds(wv, inds...) + @inbounds v = view(wv.values, inds...) + W(v, sum(v)) +end + @propagate_inbounds function Base.setindex!(wv::AbstractWeights, v::Real, i::Int) s = v - wv[i] wv.values[i] = v diff --git a/test/weights.jl b/test/weights.jl index 7735e04f7..b84ff2c33 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -37,7 +37,7 @@ weight_funcs = (weights, aweights, fweights, pweights) @test sum(sa, wv) === 7.0 end -@testset "$f, setindex!" for f in weight_funcs +@testset "$f, getindex, view, setindex!" for f in weight_funcs w = [1., 2., 3.] wv = f(w) @@ -46,6 +46,17 @@ end @test sum(wv) === 6. @test wv == w + @test wv[[1, 3]] == w[[1, 3]] + @test typeof(wv[[1, 3]]) === typeof(wv) + @test sum( wv[[1, 3]]) === sum(w[[1, 3]]) + + # Check view + @test view(wv, [1, 3]) == view(wv, [true, false, true]) == w[[1, 3]] + @test typeof(view(wv, [1, 3])) === typeof(view(wv, [true, false, true])) === typeof(wv) + @test sum(view(wv, [1, 3])) === sum(view(wv, [true, false, true])) === sum(w[[1, 3]]) + @test_throws BoundsError view(wv, [1, 5]) + @test_throws BoundsError view(wv, [true, false, true, true]) + # Test setindex! success @test (wv[1] = 4) === 4 # setindex! returns original val @test wv[1] === 4. # value correctly converted and set From 04dd2385fc09c6b5c98fbc998837a367b9a24fa5 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 7 Nov 2021 14:21:33 +0100 Subject: [PATCH 2/7] Make `view` return a `SubArray` of `AbstractWeights` --- docs/src/weights.md | 1 + src/weights.jl | 14 +++++++++++--- test/pairwise.jl | 10 ++++++++-- test/weights.jl | 7 ++++++- 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/docs/src/weights.md b/docs/src/weights.md index 50f6c1bcc..7a2a4ef52 100644 --- a/docs/src/weights.md +++ b/docs/src/weights.md @@ -8,6 +8,7 @@ In statistical applications, it is not uncommon to assign weights to samples. To !!! note - The weight vector is a light-weight wrapper of the input vector. The input vector is NOT copied during construction. - The weight vector maintains the sum of weights, which is computed upon construction. If the value of the sum is pre-computed, one can supply it as the second argument to the constructor and save the time of computing the sum again. + - Views of weight vectors are also `AbstractWeights`, but their sum has to be recomputed by each call to `sum` as weights in the parent vector may have been mutated. ## Implementations diff --git a/src/weights.jl b/src/weights.jl index 111f19a55..838f2b552 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -14,6 +14,7 @@ macro weights(name) sum::S end $(esc(name))(values::AbstractVector{<:Real}) = $(esc(name))(values, sum(values)) + $(esc(:weightstype))(::Type{<:$(esc(name))}) = $(esc(name)) end end @@ -37,12 +38,19 @@ end Base.getindex(wv::W, ::Colon) where {W <: AbstractWeights} = W(copy(wv.values), sum(wv)) -@propagate_inbounds function Base.view(wv::W, inds...) where {W <: AbstractWeights} +@propagate_inbounds function Base.view(wv::W, inds...) where {S <: Real, W <: AbstractWeights{S}} @boundscheck checkbounds(wv, inds...) - @inbounds v = view(wv.values, inds...) - W(v, sum(v)) + @inbounds v = invoke(view, Tuple{AbstractArray, Vararg{Any}}, + wv, inds...) + # Sum is not actually used but compute the right type for clarity + weightstype(W)(v, zero(S)) end +# Always recompute the sum for views of AbstractWeights, as we cannot know whether +# the parent array has been mutated +Base.sum(wv::AbstractWeights{S, T, <:SubArray{T, <:Any, <:AbstractWeights}}) where + {S<:Real, T<:Real} = sum(wv.values) + @propagate_inbounds function Base.setindex!(wv::AbstractWeights, v::Real, i::Int) s = v - wv[i] wv.values[i] = v diff --git a/test/pairwise.jl b/test/pairwise.jl index 09699b276..fae07b55c 100644 --- a/test/pairwise.jl +++ b/test/pairwise.jl @@ -39,8 +39,14 @@ arbitrary_fun(x, y) = cor(x, y) @inferred pairwise(f, x, y) - @test_throws Union{ArgumentError,MethodError} pairwise(f, [Int[]], [Int[]]) - @test_throws Union{ArgumentError,MethodError} pairwise!(f, zeros(1, 1), [Int[]], [Int[]]) + if f === cov + @test pairwise(f, [Int[]], [Int[]]) == + pairwise!(f, zeros(1, 1), [Int[]], [Int[]]) == + reshape([0.0], 1, 1) + else + @test_throws Union{ArgumentError,MethodError} pairwise(f, [Int[]], [Int[]]) + @test_throws Union{ArgumentError,MethodError} pairwise!(f, zeros(1, 1), [Int[]], [Int[]]) + end res = pairwise(f, [], []) @test size(res) == (0, 0) diff --git a/test/weights.jl b/test/weights.jl index b84ff2c33..2030a1003 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -52,10 +52,15 @@ end # Check view @test view(wv, [1, 3]) == view(wv, [true, false, true]) == w[[1, 3]] - @test typeof(view(wv, [1, 3])) === typeof(view(wv, [true, false, true])) === typeof(wv) + @test typeof(view(wv, [1, 3])) === typeof(view(wv, [true, false, true])) <: StatsBase.weightstype(typeof(wv)) @test sum(view(wv, [1, 3])) === sum(view(wv, [true, false, true])) === sum(w[[1, 3]]) @test_throws BoundsError view(wv, [1, 5]) @test_throws BoundsError view(wv, [true, false, true, true]) + v = view(wv, [1, 3]) + wv[1] += 1 + @test sum(v) == sum(view(wv, [1, 3])) + v[1] -= 1 + @test wv[1] == 1 # Test setindex! success @test (wv[1] = 4) === 4 # setindex! returns original val From 192ccca63ce15843ba060e036df4f41c5d085022 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 7 Nov 2021 15:49:33 +0100 Subject: [PATCH 3/7] Make `copy` preserve weights type --- src/weights.jl | 17 ++++++++++++++--- test/weights.jl | 11 ++++++++++- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index 838f2b552..a7c75febe 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -36,7 +36,10 @@ end W(v, sum(v)) end -Base.getindex(wv::W, ::Colon) where {W <: AbstractWeights} = W(copy(wv.values), sum(wv)) +Base.getindex(wv::AbstractWeights, ::Colon) = copy(wv) + +Base.copy(wv::W) where {W <: AbstractWeights} = + weightstype(W)(copy(wv.values), sum(wv)) @propagate_inbounds function Base.view(wv::W, inds...) where {S <: Real, W <: AbstractWeights{S}} @boundscheck checkbounds(wv, inds...) @@ -51,6 +54,10 @@ end Base.sum(wv::AbstractWeights{S, T, <:SubArray{T, <:Any, <:AbstractWeights}}) where {S<:Real, T<:Real} = sum(wv.values) +Base.copy(wv::W) where + {S<:Real, T<:Real, W<:AbstractWeights{S, T, <:SubArray{T, <:Any, <:AbstractWeights}}} = + weightstype(W)(copy(view(parent(wv.values).values, parentindices(wv.values)...)), wv.sum) + @propagate_inbounds function Base.setindex!(wv::AbstractWeights, v::Real, i::Int) s = v - wv[i] wv.values[i] = v @@ -351,8 +358,12 @@ end for w in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights) @eval begin - Base.isequal(x::$w, y::$w) = isequal(x.sum, y.sum) && isequal(x.values, y.values) - Base.:(==)(x::$w, y::$w) = (x.sum == y.sum) && (x.values == y.values) + Base.isequal(x::$w, y::$w) = + (x.values isa SubArray || y.values isa SubArray || isequal(x.values, y.values)) && + isequal(x.values, y.values) + Base.:(==)(x::$w, y::$w) = + (x.values isa SubArray || y.values isa SubArray || (x.sum == y.sum)) && + (x.values == y.values) end end diff --git a/test/weights.jl b/test/weights.jl index 2030a1003..35c14eff0 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -50,9 +50,16 @@ end @test typeof(wv[[1, 3]]) === typeof(wv) @test sum( wv[[1, 3]]) === sum(w[[1, 3]]) + # Check copy + @test copy(wv) == wv + @test typeof(copy(wv)) === typeof(wv) + # Check view + @test view(wv, :) == wv + @test typeof(view(wv, :)) <: StatsBase.weightstype(typeof(wv)) @test view(wv, [1, 3]) == view(wv, [true, false, true]) == w[[1, 3]] - @test typeof(view(wv, [1, 3])) === typeof(view(wv, [true, false, true])) <: StatsBase.weightstype(typeof(wv)) + @test typeof(view(wv, [1, 3])) === typeof(view(wv, [true, false, true])) <: + StatsBase.weightstype(typeof(wv)) @test sum(view(wv, [1, 3])) === sum(view(wv, [true, false, true])) === sum(w[[1, 3]]) @test_throws BoundsError view(wv, [1, 5]) @test_throws BoundsError view(wv, [true, false, true, true]) @@ -61,6 +68,8 @@ end @test sum(v) == sum(view(wv, [1, 3])) v[1] -= 1 @test wv[1] == 1 + @test copy(v) == v[:] == v + @test typeof(copy(v)) === typeof(v[:]) == typeof(wv) # Test setindex! success @test (wv[1] = 4) === 4 # setindex! returns original val From 91208f46aced1eede5ebda881b123e1f93ca7365 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 7 Nov 2021 15:59:15 +0100 Subject: [PATCH 4/7] Revert unrelated changes --- test/pairwise.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/test/pairwise.jl b/test/pairwise.jl index fae07b55c..09699b276 100644 --- a/test/pairwise.jl +++ b/test/pairwise.jl @@ -39,14 +39,8 @@ arbitrary_fun(x, y) = cor(x, y) @inferred pairwise(f, x, y) - if f === cov - @test pairwise(f, [Int[]], [Int[]]) == - pairwise!(f, zeros(1, 1), [Int[]], [Int[]]) == - reshape([0.0], 1, 1) - else - @test_throws Union{ArgumentError,MethodError} pairwise(f, [Int[]], [Int[]]) - @test_throws Union{ArgumentError,MethodError} pairwise!(f, zeros(1, 1), [Int[]], [Int[]]) - end + @test_throws Union{ArgumentError,MethodError} pairwise(f, [Int[]], [Int[]]) + @test_throws Union{ArgumentError,MethodError} pairwise!(f, zeros(1, 1), [Int[]], [Int[]]) res = pairwise(f, [], []) @test size(res) == (0, 0) From efa81b6cd454ae6f59bb801071b9b88641236104 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 7 Nov 2021 16:19:50 +0100 Subject: [PATCH 5/7] Fixes --- src/weights.jl | 11 ++++++++--- test/weights.jl | 4 ++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index a7c75febe..b6f41362e 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -43,12 +43,17 @@ Base.copy(wv::W) where {W <: AbstractWeights} = @propagate_inbounds function Base.view(wv::W, inds...) where {S <: Real, W <: AbstractWeights{S}} @boundscheck checkbounds(wv, inds...) - @inbounds v = invoke(view, Tuple{AbstractArray, Vararg{Any}}, - wv, inds...) + @inbounds v = invoke(view, Tuple{AbstractArray, Vararg{Any}}, wv, inds...) # Sum is not actually used but compute the right type for clarity weightstype(W)(v, zero(S)) end +# This method is implemented for backward compatibility +@propagate_inbounds function Base.view(wv::W, inds::Integer) where {S <: Real, W <: AbstractWeights{S}} + @boundscheck checkbounds(wv, inds) + @inbounds invoke(view, Tuple{AbstractArray, Vararg{Any}}, wv, inds...) +end + # Always recompute the sum for views of AbstractWeights, as we cannot know whether # the parent array has been mutated Base.sum(wv::AbstractWeights{S, T, <:SubArray{T, <:Any, <:AbstractWeights}}) where @@ -362,7 +367,7 @@ for w in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights) (x.values isa SubArray || y.values isa SubArray || isequal(x.values, y.values)) && isequal(x.values, y.values) Base.:(==)(x::$w, y::$w) = - (x.values isa SubArray || y.values isa SubArray || (x.sum == y.sum)) && + (x.values isa SubArray || y.values isa SubArray || x.sum == y.sum) && (x.values == y.values) end end diff --git a/test/weights.jl b/test/weights.jl index 35c14eff0..87870b8e8 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -70,6 +70,10 @@ end @test wv[1] == 1 @test copy(v) == v[:] == v @test typeof(copy(v)) === typeof(v[:]) == typeof(wv) + @test view(wv, 1) == fill(wv[1]) + @test view(wv, 1) isa SubArray + @test view(wv, CartesianIndex(1, 1)) == fill(wv[CartesianIndex(1, 1)]) + @test view(wv, CartesianIndex(1, 1)) isa SubArray # Test setindex! success @test (wv[1] = 4) === 4 # setindex! returns original val From 95b3650f3891f1b1f870780c4ee9c8c0b3214697 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sun, 7 Nov 2021 16:29:50 +0100 Subject: [PATCH 6/7] More fixes --- src/weights.jl | 5 +++-- test/weights.jl | 6 ++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index b6f41362e..7720f809d 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -49,8 +49,9 @@ Base.copy(wv::W) where {W <: AbstractWeights} = end # This method is implemented for backward compatibility -@propagate_inbounds function Base.view(wv::W, inds::Integer) where {S <: Real, W <: AbstractWeights{S}} - @boundscheck checkbounds(wv, inds) +@propagate_inbounds function Base.view(wv::W, inds::Union{Integer, CartesianIndex}...) where + {S <: Real, W <: AbstractWeights{S}} + @boundscheck checkbounds(wv, inds...) @inbounds invoke(view, Tuple{AbstractArray, Vararg{Any}}, wv, inds...) end diff --git a/test/weights.jl b/test/weights.jl index 87870b8e8..33c3611d1 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -70,10 +70,16 @@ end @test wv[1] == 1 @test copy(v) == v[:] == v @test typeof(copy(v)) === typeof(v[:]) == typeof(wv) + @test view(wv, [1, 2], 1) == wv[[1, 2]] + @test typeof(view(wv, [1, 2], 1)) <: StatsBase.weightstype(typeof(wv)) @test view(wv, 1) == fill(wv[1]) @test view(wv, 1) isa SubArray + @test view(wv, 1, 1) == fill(wv[1]) + @test view(wv, 1, 1) isa SubArray @test view(wv, CartesianIndex(1, 1)) == fill(wv[CartesianIndex(1, 1)]) @test view(wv, CartesianIndex(1, 1)) isa SubArray + @test view(wv, CartesianIndex(1, 1), 1) == fill(wv[CartesianIndex(1, 1), 1]) + @test view(wv, CartesianIndex(1, 1), 1) isa SubArray # Test setindex! success @test (wv[1] = 4) === 4 # setindex! returns original val From f9db85a287e66f78229b88c8bb96a5899c7d3669 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Thu, 11 Nov 2021 17:30:45 +0100 Subject: [PATCH 7/7] Safer approach, more tests, review fixes --- src/weights.jl | 34 +++++---- test/cov.jl | 16 ++-- test/moments.jl | 39 ++++++---- test/weights.jl | 189 ++++++++++++++++++++++++++++-------------------- 4 files changed, 166 insertions(+), 112 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index 7720f809d..678b2b180 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -11,7 +11,11 @@ macro weights(name) return quote mutable struct $name{S<:Real, T<:Real, V<:AbstractVector{T}} <: AbstractWeights{S, T, V} values::V - sum::S + """ + Pre-computed sum. Private field, to be accessed only via `sum(wv)`. + Set to `missing` by `view` as we cannot know when the parent is mutated. + """ + sum::Union{S, Missing} end $(esc(name))(values::AbstractVector{<:Real}) = $(esc(name))(values, sum(values)) $(esc(:weightstype))(::Type{<:$(esc(name))}) = $(esc(name)) @@ -19,7 +23,7 @@ macro weights(name) end length(wv::AbstractWeights) = length(wv.values) -sum(wv::AbstractWeights) = wv.sum +sum(wv::AbstractWeights{S}) where {S<:Real} = wv.sum::S isempty(wv::AbstractWeights) = isempty(wv.values) size(wv::AbstractWeights) = size(wv.values) @@ -41,11 +45,11 @@ Base.getindex(wv::AbstractWeights, ::Colon) = copy(wv) Base.copy(wv::W) where {W <: AbstractWeights} = weightstype(W)(copy(wv.values), sum(wv)) -@propagate_inbounds function Base.view(wv::W, inds...) where {S <: Real, W <: AbstractWeights{S}} +@propagate_inbounds function Base.view(wv::W, inds...) where + {S <: Real, W <: AbstractWeights{S}} @boundscheck checkbounds(wv, inds...) @inbounds v = invoke(view, Tuple{AbstractArray, Vararg{Any}}, wv, inds...) - # Sum is not actually used but compute the right type for clarity - weightstype(W)(v, zero(S)) + weightstype(W){S, eltype(wv), typeof(v)}(v, missing) end # This method is implemented for backward compatibility @@ -57,12 +61,12 @@ end # Always recompute the sum for views of AbstractWeights, as we cannot know whether # the parent array has been mutated -Base.sum(wv::AbstractWeights{S, T, <:SubArray{T, <:Any, <:AbstractWeights}}) where - {S<:Real, T<:Real} = sum(wv.values) +Base.sum(wv::AbstractWeights{S, T, <:SubArray}) where {S<:Real, T<:Real} = + sum(wv.values) Base.copy(wv::W) where {S<:Real, T<:Real, W<:AbstractWeights{S, T, <:SubArray{T, <:Any, <:AbstractWeights}}} = - weightstype(W)(copy(view(parent(wv.values).values, parentindices(wv.values)...)), wv.sum) + weightstype(W)(copy(view(parent(wv.values).values, parentindices(wv.values)...)), sum(wv)) @propagate_inbounds function Base.setindex!(wv::AbstractWeights, v::Real, i::Int) s = v - wv[i] @@ -112,7 +116,7 @@ if `corrected=true`. @inline function varcorrection(w::Weights, corrected::Bool=false) corrected && throw(ArgumentError("Weights type does not support bias correction: " * "use FrequencyWeights, AnalyticWeights or ProbabilityWeights if applicable.")) - 1 / w.sum + 1 / sum(w) end @weights AnalyticWeights @@ -145,7 +149,7 @@ aweights(vs::RealArray) = AnalyticWeights(vec(vs)) * `corrected=false`: ``\\frac{1}{\\sum w}`` """ @inline function varcorrection(w::AnalyticWeights, corrected::Bool=false) - s = w.sum + s = sum(w) if corrected sum_sn = sum(x -> (x / s) ^ 2, w) @@ -183,7 +187,7 @@ fweights(vs::RealArray) = FrequencyWeights(vec(vs)) * `corrected=false`: ``\\frac{1}{\\sum w}`` """ @inline function varcorrection(w::FrequencyWeights, corrected::Bool=false) - s = w.sum + s = sum(w) if corrected 1 / (s - 1) @@ -221,7 +225,7 @@ pweights(vs::RealArray) = ProbabilityWeights(vec(vs)) * `corrected=false`: ``\\frac{1}{\\sum w}`` """ @inline function varcorrection(w::ProbabilityWeights, corrected::Bool=false) - s = w.sum + s = sum(w) if corrected n = count(!iszero, w) @@ -365,11 +369,11 @@ end for w in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights) @eval begin Base.isequal(x::$w, y::$w) = - (x.values isa SubArray || y.values isa SubArray || isequal(x.values, y.values)) && + (x.values isa SubArray || y.values isa SubArray || isequal(x.sum, y.sum)) && isequal(x.values, y.values) Base.:(==)(x::$w, y::$w) = (x.values isa SubArray || y.values isa SubArray || x.sum == y.sum) && - (x.values == y.values) + x.values == y.values end end @@ -700,7 +704,7 @@ function quantile(v::RealVector{V}, w::AbstractWeights{W}, p::RealVector) where isempty(p) && throw(ArgumentError("empty quantile array")) all(x -> 0 <= x <= 1, p) || throw(ArgumentError("input probability out of [0,1] range")) - w.sum == 0 && throw(ArgumentError("weight vector cannot sum to zero")) + sum(w) == 0 && throw(ArgumentError("weight vector cannot sum to zero")) length(v) == length(w) || throw(ArgumentError("data and weight vectors must be the same size," * "got $(length(v)) and $(length(w))")) for x in w.values diff --git a/test/cov.jl b/test/cov.jl index ab310276c..e437f909a 100644 --- a/test/cov.jl +++ b/test/cov.jl @@ -4,9 +4,15 @@ using LinearAlgebra, Random, Test struct EmptyCovarianceEstimator <: CovarianceEstimator end @testset "StatsBase.Covariance" begin + +function viewweights(f) + wv -> view(f([wv; 100]), axes(wv, 1)) +end + weight_funcs = (weights, aweights, fweights, pweights) -@testset "$f" for f in weight_funcs +@testset "$f with $viewf" for f in weight_funcs, viewf in (identity, viewweights) + fw = viewf(f) X = randn(3, 8) Z1 = X .- mean(X, dims = 1) @@ -21,8 +27,8 @@ weight_funcs = (weights, aweights, fweights, pweights) w2[1] += 1 end - wv1 = f(w1) - wv2 = f(w2) + wv1 = fw(w1) + wv2 = fw(w2) Z1w = X .- mean(X, wv1, dims=1) Z2w = X .- mean(X, wv2, dims=2) @@ -237,8 +243,8 @@ weight_funcs = (weights, aweights, fweights, pweights) end @testset "Correlation" begin - @test cor(X, f(ones(3)), 1) ≈ cor(X, dims = 1) - @test cor(X, f(ones(8)), 2) ≈ cor(X, dims = 2) + @test cor(X, fw(ones(3)), 1) ≈ cor(X, dims = 1) + @test cor(X, fw(ones(8)), 2) ≈ cor(X, dims = 2) cov1 = cov(X, wv1, 1; corrected=false) std1 = std(X, wv1, 1; corrected=false) diff --git a/test/moments.jl b/test/moments.jl index 97fda44ac..c2165c873 100644 --- a/test/moments.jl +++ b/test/moments.jl @@ -2,6 +2,11 @@ using StatsBase using Test @testset "StatsBase.Moments" begin + +function viewweights(f) + wv -> view(f([wv; 100]), axes(wv, 1)) +end + weight_funcs = (weights, aweights, fweights, pweights) ##### weighted var & std @@ -9,8 +14,9 @@ weight_funcs = (weights, aweights, fweights, pweights) x = [0.57, 0.10, 0.91, 0.72, 0.46, 0.0] w = [3.84, 2.70, 8.29, 8.91, 9.71, 0.0] -@testset "Uncorrected with $f" for f in weight_funcs - wv = f(w) +@testset "Uncorrected with $f and $viewf" for f in weight_funcs, + viewf in (identity, viewweights) + wv = viewf(f)(w) m = mean(x, wv) # expected uncorrected output @@ -52,8 +58,9 @@ end expected_var = [NaN, 0.0694434191182236, 0.05466601256158146, 0.06628969012045285] expected_std = sqrt.(expected_var) -@testset "Corrected with $(weight_funcs[i])" for i in eachindex(weight_funcs) - wv = weight_funcs[i](w) +@testset "Corrected with $(weight_funcs[i])" for i in eachindex(weight_funcs), + viewf in (identity, viewweights) + wv = viewf(weight_funcs[i])(w) m = mean(x, wv) @testset "Variance" begin @@ -107,9 +114,10 @@ x = rand(5, 6) w1 = [0.57, 5.10, 0.91, 1.72, 0.0] w2 = [3.84, 2.70, 8.29, 8.91, 9.71, 0.0] -@testset "Uncorrected with $f" for f in weight_funcs - wv1 = f(w1) - wv2 = f(w2) +@testset "Uncorrected with $f and $viewf" for f in weight_funcs, + viewf in (identity, viewweights) + wv1 = viewf(f)(w1) + wv2 = viewf(f)(w2) m1 = mean(x, wv1, dims=1) m2 = mean(x, wv2, dims=2) @@ -165,9 +173,10 @@ w2 = [3.84, 2.70, 8.29, 8.91, 9.71, 0.0] end end -@testset "Corrected with $f" for f in weight_funcs - wv1 = f(w1) - wv2 = f(w2) +@testset "Corrected with $f and $viewf" for f in weight_funcs, + viewf in (identity, viewweights) + wv1 = viewf(f)(w1) + wv2 = viewf(f)(w2) m1 = mean(x, wv1, dims=1) m2 = mean(x, wv2, dims=2) @@ -241,8 +250,9 @@ end end end -@testset "Skewness and Kurtosis with $f" for f in weight_funcs - wv = f(ones(5) * 2.0) +@testset "Skewness and Kurtosis with $f and $viewf" for f in weight_funcs, + viewf in (identity, viewweights) + wv = viewf(f)(ones(5) * 2.0) @test skewness(1:5) ≈ 0.0 @test skewness([1, 2, 3, 4, 5]) ≈ 0.0 @@ -258,7 +268,8 @@ end @test kurtosis([1, 2, 3, 4, 5], wv) ≈ -1.3 end -@testset "General Moments with $f" for f in weight_funcs +@testset "General Moments with $f and $viewf" for f in weight_funcs, + viewf in (identity, viewweights) x = collect(2.0:8.0) @test moment(x, 2) ≈ sum((x .- 5).^2) / length(x) @test moment(x, 3) ≈ sum((x .- 5).^3) / length(x) @@ -270,7 +281,7 @@ end @test moment(x, 4, 4.0) ≈ sum((x .- 4).^4) / length(x) @test moment(x, 5, 4.0) ≈ sum((x .- 4).^5) / length(x) - w = f([1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]) + w = viewf(f)([1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]) x2 = collect(2.0:6.0) @test moment(x, 2, w) ≈ sum((x2 .- 4).^2) / 5 @test moment(x, 3, w) ≈ sum((x2 .- 4).^3) / 5 diff --git a/test/weights.jl b/test/weights.jl index 33c3611d1..5be7ae2e0 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -2,6 +2,11 @@ using StatsBase using LinearAlgebra, Random, SparseArrays, Test @testset "StatsBase.Weights" begin + +function viewweights(f) + wv -> view(f([wv; 100]), axes(wv, 1)) +end + weight_funcs = (weights, aweights, fweights, pweights) ## Construction @@ -43,7 +48,7 @@ end # Check getindex & sum @test wv[1] === 1. - @test sum(wv) === 6. + @test @inferred(sum(wv)) === 6. @test wv == w @test wv[[1, 3]] == w[[1, 3]] @@ -54,33 +59,6 @@ end @test copy(wv) == wv @test typeof(copy(wv)) === typeof(wv) - # Check view - @test view(wv, :) == wv - @test typeof(view(wv, :)) <: StatsBase.weightstype(typeof(wv)) - @test view(wv, [1, 3]) == view(wv, [true, false, true]) == w[[1, 3]] - @test typeof(view(wv, [1, 3])) === typeof(view(wv, [true, false, true])) <: - StatsBase.weightstype(typeof(wv)) - @test sum(view(wv, [1, 3])) === sum(view(wv, [true, false, true])) === sum(w[[1, 3]]) - @test_throws BoundsError view(wv, [1, 5]) - @test_throws BoundsError view(wv, [true, false, true, true]) - v = view(wv, [1, 3]) - wv[1] += 1 - @test sum(v) == sum(view(wv, [1, 3])) - v[1] -= 1 - @test wv[1] == 1 - @test copy(v) == v[:] == v - @test typeof(copy(v)) === typeof(v[:]) == typeof(wv) - @test view(wv, [1, 2], 1) == wv[[1, 2]] - @test typeof(view(wv, [1, 2], 1)) <: StatsBase.weightstype(typeof(wv)) - @test view(wv, 1) == fill(wv[1]) - @test view(wv, 1) isa SubArray - @test view(wv, 1, 1) == fill(wv[1]) - @test view(wv, 1, 1) isa SubArray - @test view(wv, CartesianIndex(1, 1)) == fill(wv[CartesianIndex(1, 1)]) - @test view(wv, CartesianIndex(1, 1)) isa SubArray - @test view(wv, CartesianIndex(1, 1), 1) == fill(wv[CartesianIndex(1, 1), 1]) - @test view(wv, CartesianIndex(1, 1), 1) isa SubArray - # Test setindex! success @test (wv[1] = 4) === 4 # setindex! returns original val @test wv[1] === 4. # value correctly converted and set @@ -104,14 +82,16 @@ end @test wv == [1, 2, 3] # Test state of all values end -@testset "$f, isequal and ==" for f in weight_funcs - x = f([1, 2, 3]) +@testset "$f and $viewf with isequal and ==" for f in weight_funcs, + viewf in (identity, viewweights) + fw = viewf(f) + x = fw([1, 2, 3]) - y = f([1, 2, 3]) # same values, type and parameters + y = fw([1, 2, 3]) # same values, type and parameters @test isequal(x, y) @test x == y - y = f([1.0, 2.0, 3.0]) # same values and type, different parameters + y = fw([1.0, 2.0, 3.0]) # same values and type, different parameters @test isequal(x, y) @test x == y @@ -121,17 +101,61 @@ end @test x != y end - x = f([1, 2, NaN]) # isequal and == treat NaN differently - y = f([1, 2, NaN]) + x = fw([1, 2, NaN]) # isequal and == treat NaN differently + y = fw([1, 2, NaN]) @test isequal(x, y) @test x != y - x = f([1.0, 2.0, 0.0]) # isequal and == treat ±0.0 differently - y = f([1.0, 2.0, -0.0]) + x = fw([1.0, 2.0, 0.0]) # isequal and == treat ±0.0 differently + y = fw([1.0, 2.0, -0.0]) @test !isequal(x, y) @test x == y end +@testset "view of weights" for f in weight_funcs + w = [1., 2., 3.] + wv = f(w) + + @test view(wv, :) == wv + @test sum(view(wv, :)) == sum(wv) + @test typeof(view(wv, :)) <: StatsBase.weightstype(typeof(wv)) + @test view(wv, [1, 3]) == view(wv, [true, false, true]) == w[[1, 3]] + @test isequal(view(wv, [1, 3]), view(wv, [true, false, true])) + @test isequal(view(wv, [true, false, true]), w[[1, 3]]) + @test typeof(view(wv, [1, 3])) === typeof(view(wv, [true, false, true])) <: + StatsBase.weightstype(typeof(wv)) + @test sum(view(wv, [1, 3])) === sum(view(wv, [true, false, true])) === sum(w[[1, 3]]) + + @test_throws BoundsError view(wv, [1, 5]) + @test_throws BoundsError view(wv, [true, false, true, true]) + + v = view(wv, [1, 3]) + wv[1] += 1 + @test sum(v) == sum(view(wv, [1, 3])) + v[1] -= 1 + @test wv[1] == 1 + @test copy(v) == v[:] == v + @test sum(copy(v)) == sum(v[:]) == sum(v) + @test typeof(copy(v)) === typeof(v[:]) === typeof(wv) + + @test view(wv, [1, 2], 1) == wv[[1, 2]] + @test sum(view(wv, [1, 2], 1)) == sum(wv[[1, 2]]) + @test typeof(view(wv, [1, 2], 1)) <: StatsBase.weightstype(typeof(wv)) + @test copy(view(wv, [1, 2], 1)) == view(wv, [1, 2], 1) + @test sum(copy(view(wv, [1, 2], 1))) == sum(view(wv, [1, 2])) + @test typeof(copy(view(wv, [1, 2], 1))) === typeof(copy(v)) + + @test view(wv, 1) == fill(wv[1]) + @test view(wv, 1) isa SubArray + @test view(wv, 1, 1) == fill(wv[1]) + @test view(wv, 1, 1) isa SubArray + + @test view(wv, CartesianIndex(1, 1)) == fill(wv[CartesianIndex(1, 1)]) + @test view(wv, CartesianIndex(1, 1)) isa SubArray + @test view(wv, CartesianIndex(1, 1), 1) == fill(wv[CartesianIndex(1, 1), 1]) + @test view(wv, CartesianIndex(1, 1), 1) isa SubArray +end + @testset "Unit weights" begin wv = uweights(Float64, 3) @test wv[1] === 1. @@ -271,30 +295,35 @@ end a = reshape(1.0:27.0, 3, 3, 3) -@testset "Sum $f" for f in weight_funcs - @test sum([1.0, 2.0, 3.0], f([1.0, 0.5, 0.5])) ≈ 3.5 - @test sum(1:3, f([1.0, 1.0, 0.5])) ≈ 4.5 +@testset "Sum $f and $viewf" for f in weight_funcs, viewf in (identity, viewweights) + fw = viewf(f) + @test sum([1.0, 2.0, 3.0], fw([1.0, 0.5, 0.5])) ≈ 3.5 + @test sum(1:3, fw([1.0, 1.0, 0.5])) ≈ 4.5 for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0]) - @test sum(a, f(wt), dims=1) ≈ sum(a.*reshape(wt, length(wt), 1, 1), dims=1) - @test sum(a, f(wt), dims=2) ≈ sum(a.*reshape(wt, 1, length(wt), 1), dims=2) - @test sum(a, f(wt), dims=3) ≈ sum(a.*reshape(wt, 1, 1, length(wt)), dims=3) + @test sum(a, fw(wt), dims=1) ≈ sum(a.*reshape(wt, length(wt), 1, 1), dims=1) + @test sum(a, fw(wt), dims=2) ≈ sum(a.*reshape(wt, 1, length(wt), 1), dims=2) + @test sum(a, fw(wt), dims=3) ≈ sum(a.*reshape(wt, 1, 1, length(wt)), dims=3) end end -@testset "Mean $f" for f in weight_funcs - @test mean([1:3;], f([1.0, 1.0, 0.5])) ≈ 1.8 - @test mean(1:3, f([1.0, 1.0, 0.5])) ≈ 1.8 +@testset "Mean $f and $viewf" for f in weight_funcs, + viewf in (identity, viewweights) + fw = viewf(f) + + @test mean([1:3;], fw([1.0, 1.0, 0.5])) ≈ 1.8 + @test mean(1:3, fw([1.0, 1.0, 0.5])) ≈ 1.8 for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0]) - @test mean(a, f(wt), dims=1) ≈ sum(a.*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) - @test mean(a, f(wt), dims=2) ≈ sum(a.*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) - @test mean(a, f(wt), dims=3) ≈ sum(a.*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) - @test_throws ErrorException mean(a, f(wt), dims=4) + @test mean(a, fw(wt), dims=1) ≈ sum(a.*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) + @test mean(a, fw(wt), dims=2) ≈ sum(a.*reshape(wt, 1, length(wt), 1), dims=2)/sum(wt) + @test mean(a, fw(wt), dims=3) ≈ sum(a.*reshape(wt, 1, 1, length(wt)), dims=3)/sum(wt) + @test_throws ErrorException mean(a, fw(wt), dims=4) end end -@testset "Quantile fweights" begin +@testset "Quantile with fweights and $viewf" for viewf in (identity, viewweights) + fw = viewf(fweights) data = ( [7, 1, 2, 4, 10], [7, 1, 2, 4, 10], @@ -352,27 +381,29 @@ end end # quantile with fweights is the same as repeated vectors for i = 1:length(data) - @test quantile(data[i], fweights(wt[i]), p) ≈ quantile(_rep(data[i], wt[i]), p) + @test quantile(data[i], fw(wt[i]), p) ≈ quantile(_rep(data[i], wt[i]), p) end # quantile with fweights = 1 is the same as quantile for i = 1:length(data) - @test quantile(data[i], fweights(fill!(similar(wt[i]), 1)), p) ≈ quantile(data[i], p) + @test quantile(data[i], fw(fill!(similar(wt[i]), 1)), p) ≈ quantile(data[i], p) end # Issue #313 - @test quantile([1, 2, 3, 4, 5], fweights([0,1,2,1,0]), p) ≈ quantile([2, 3, 3, 4], p) - @test quantile([1, 2], fweights([1, 1]), 0.25) ≈ 1.25 - @test quantile([1, 2], fweights([2, 2]), 0.25) ≈ 1.0 + @test quantile([1, 2, 3, 4, 5], fw([0,1,2,1,0]), p) ≈ quantile([2, 3, 3, 4], p) + @test quantile([1, 2], fw([1, 1]), 0.25) ≈ 1.25 + @test quantile([1, 2], fw([2, 2]), 0.25) ≈ 1.0 # test non integer frequency weights - quantile([1, 2], fweights([1.0, 2.0]), 0.25) == quantile([1, 2], fweights([1, 2]), 0.25) - @test_throws ArgumentError quantile([1, 2], fweights([1.5, 2.0]), 0.25) + quantile([1, 2], fw([1.0, 2.0]), 0.25) == quantile([1, 2], fw([1, 2]), 0.25) + @test_throws ArgumentError quantile([1, 2], fw([1.5, 2.0]), 0.25) - @test_throws ArgumentError quantile([1, 2], fweights([1, 2]), nextfloat(1.0)) - @test_throws ArgumentError quantile([1, 2], fweights([1, 2]), prevfloat(0.0)) + @test_throws ArgumentError quantile([1, 2], fw([1, 2]), nextfloat(1.0)) + @test_throws ArgumentError quantile([1, 2], fw([1, 2]), prevfloat(0.0)) end -@testset "Quantile aweights, pweights and weights" for f in (aweights, pweights, weights) +@testset "Quantile aweights, pweights and weights" for f in (aweights, pweights, weights), + viewf in (identity, viewweights) + fw = viewf(f) data = ( [7, 1, 2, 4, 10], [7, 1, 2, 4, 10], @@ -440,22 +471,22 @@ end Random.seed!(10) for i = 1:length(data) - @test quantile(data[i], f(wt[i]), p) ≈ quantile_answers[i] atol = 1e-5 + @test quantile(data[i], fw(wt[i]), p) ≈ quantile_answers[i] atol = 1e-5 for j = 1:10 # order of p does not matter reorder = sortperm(rand(length(p))) - @test quantile(data[i], f(wt[i]), p[reorder]) ≈ quantile_answers[i][reorder] atol = 1e-5 + @test quantile(data[i], fw(wt[i]), p[reorder]) ≈ quantile_answers[i][reorder] atol = 1e-5 end for j = 1:10 # order of w does not matter reorder = sortperm(rand(length(data[i]))) - @test quantile(data[i][reorder], f(wt[i][reorder]), p) ≈ quantile_answers[i] atol = 1e-5 + @test quantile(data[i][reorder], fw(wt[i][reorder]), p) ≈ quantile_answers[i] atol = 1e-5 end end # All equal weights corresponds to base quantile for v in (1, 2, 345) for i = 1:length(data) - w = f(fill(v, length(data[i]))) + w = fw(fill(v, length(data[i]))) @test quantile(data[i], w, p) ≈ quantile(data[i], p) atol = 1e-5 for j = 1:10 prandom = rand(4) @@ -465,40 +496,42 @@ end end # test zeros are removed for i = 1:length(data) - @test quantile(vcat(1.0, data[i]), f(vcat(0.0, wt[i])), p) ≈ quantile_answers[i] atol = 1e-5 + @test quantile(vcat(1.0, data[i]), fw(vcat(0.0, wt[i])), p) ≈ quantile_answers[i] atol = 1e-5 end # Syntax v = [7, 1, 2, 4, 10] w = [1, 1/3, 1/3, 1/3, 1] answer = 6.0 - @test quantile(data[1], f(w), 0.5) ≈ answer atol = 1e-5 + @test quantile(data[1], fw(w), 0.5) ≈ answer atol = 1e-5 end -@testset "Median $f" for f in weight_funcs +@testset "Median with $f and $viewf" for f in weight_funcs, + viewf in (identity, viewweights) + fw = viewf(f) data = [4, 3, 2, 1] wt = [0, 0, 0, 0] - @test_throws ArgumentError median(data, f(wt)) - @test_throws ArgumentError median(Float64[], f(Float64[])) + @test_throws ArgumentError median(data, fw(wt)) + @test_throws ArgumentError median(Float64[], fw(Float64[])) wt = [1, 2, 3, 4, 5] - @test_throws ArgumentError median(data, f(wt)) + @test_throws ArgumentError median(data, fw(wt)) if VERSION >= v"1.0" - @test_throws MethodError median([4 3 2 1 0], f(wt)) - @test_throws MethodError median([[1 2] ; [4 5] ; [7 8] ; [10 11] ; [13 14]], f(wt)) + @test_throws MethodError median([4 3 2 1 0], fw(wt)) + @test_throws MethodError median([[1 2] ; [4 5] ; [7 8] ; [10 11] ; [13 14]], fw(wt)) end data = [1, 3, 2, NaN, 2] - @test isnan(median(data, f(wt))) + @test isnan(median(data, fw(wt))) wt = [1, 2, NaN, 4, 5] - @test_throws ArgumentError median(data, f(wt)) + @test_throws ArgumentError median(data, fw(wt)) data = [1, 3, 2, 1, 2] - @test_throws ArgumentError median(data, f(wt)) + @test_throws ArgumentError median(data, fw(wt)) wt = [-1, -1, -1, -1, -1] - @test_throws ArgumentError median(data, f(wt)) + @test_throws ArgumentError median(data, fw(wt)) wt = [-1, -1, -1, 0, 0] - @test_throws ArgumentError median(data, f(wt)) + @test_throws ArgumentError median(data, fw(wt)) data = [4, 3, 2, 1] wt = [1, 2, 3, 4] - @test median(data, f(wt)) ≈ quantile(data, f(wt), 0.5) atol = 1e-5 + @test median(data, fw(wt)) ≈ quantile(data, fw(wt), 0.5) atol = 1e-5 end @testset "Mismatched eltypes" begin