More accurate mean #45186

Open
wsshin opened this issue May 4, 2022 · 5 comments
Labels
maths Mathematical functions

Comments

@wsshin
Contributor

wsshin commented May 4, 2022

This issue is meant to start a discussion on how to solve the problem reported in JuliaStats/StatsBase.jl#196.

In StatsBase.jl, we calculate the Z-scores of data with zscore(). The Z-scores are a shifted and scaled version of the data with zero mean and unit standard deviation. For example, the Z-scores of a vector x are calculated as follows:

using Statistics

μ = mean(x)
σ = std(x, mean=μ)
z = (x .- μ) ./ σ
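A self-contained version of the three lines above, wrapped in a hypothetical helper zscore_sketch (not StatsBase's zscore implementation), makes the two defining properties easy to check on ordinary, non-constant data:

```julia
using Statistics

# Hypothetical helper mirroring the three lines above; not StatsBase's zscore.
function zscore_sketch(x::AbstractVector{<:Real})
    μ = mean(x)              # sample mean
    σ = std(x; mean=μ)       # sample standard deviation, reusing μ
    return (x .- μ) ./ σ     # shift to zero mean, then scale to unit std
end

z = zscore_sketch([1.0, 2.0, 3.0, 4.0, 5.0])
# mean(z) is 0 and std(z) is 1, up to rounding
```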

The StatsBase.jl issue linked above reports that the Z-scores are calculated inaccurately when all the entries of x are identical. For x with all-identical entries x0, in exact arithmetic, μ should be x0 and σ should be 0, so the entries of z should be 0/0, which is NaN. However, in floating-point arithmetic, the Z-scores are not NaN due to rounding errors. An example taken from the original issue:

julia> using Statistics

julia> x0 = log(1e-5)
-11.512925464970229

julia> x = fill(x0, 8)
8-element Vector{Float64}:
 -11.512925464970229
 -11.512925464970229
 -11.512925464970229
 -11.512925464970229
 -11.512925464970229
 -11.512925464970229
 -11.512925464970229
 -11.512925464970229

julia> μ = mean(x)  # slightly different from x0
-11.512925464970227

julia> σ = std(x)  # not 0 because μ ≠ x0
1.89900533991096e-15

julia> z = (x .- μ) ./ σ  # not NaN because μ ≠ x0 and σ ≠ 0
8-element Vector{Float64}:
 -0.9354143466934853
 -0.9354143466934853
 -0.9354143466934853
 -0.9354143466934853
 -0.9354143466934853
 -0.9354143466934853
 -0.9354143466934853
 -0.9354143466934853
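The printed value is not arbitrary. Write the computed mean as μ = x0 + δ. Here δ > 0, because μ = -11.512925464970227 is greater than x0, and δ = 2^-49 ≈ 1.78e-15 is one unit in the last place of x0. Every deviation x0 - μ is then exactly -δ, so for n = 8 identical entries (up to rounding inside std itself):

```math
\sigma = \sqrt{\frac{1}{n-1}\sum_{i=1}^{n}(x_0-\mu)^2} = |\delta|\sqrt{\frac{n}{n-1}} \approx 1.899\times 10^{-15},
\qquad
z_i = \frac{x_0-\mu}{\sigma} = -\frac{\delta}{|\delta|}\sqrt{\frac{n-1}{n}} = -\sqrt{\frac{7}{8}} \approx -0.93541
```

This reproduces both the printed σ and the printed Z-scores.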

The problem can be avoided if mean() is computed accurately when x has all-identical entries. The most obvious option is to return x[1] when all entries of x are identical. However, this check is expensive when x is long and its entries really are all identical, because all() below then cannot short-circuit and must scan the whole vector. For 10,000 entries it takes about 5 times as long as mean() itself:

julia> using BenchmarkTools, Statistics

julia> x = fill(rand(), 10000);

julia> @btime mean($x);
  1.186 μs (0 allocations: 0 bytes)

julia> @btime all(y->(y==$(x[1])), $x);  # much slower than mean()
  6.311 μs (0 allocations: 0 bytes)
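For reference, the "obvious option" can be written as a one-line guard. The name mean_exact_if_constant is hypothetical. On a constant vector the all() check visits every element before returning, which is the cost measured above:

```julia
using Statistics

# Hypothetical guard: return the common value when every entry equals the first,
# otherwise fall back to the ordinary mean. On constant input `all` cannot
# short-circuit, so it scans the whole vector.
mean_exact_if_constant(x) = all(==(first(x)), x) ? first(x) : mean(x)

x0 = log(1e-5)
x = fill(x0, 8)
mean_exact_if_constant(x) == x0   # true: the common value is returned unchanged
```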

Another option proposed in the original issue is to refine the computed mean: compute the mean of x .- μ as a correction ∆μ and add it to μ:

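function mean_refined(x)
    μ = mean(x)               # first pass: ordinary mean
    ∆μ = mean(y->(y-μ), x)    # second pass: mean of the residuals x .- μ
    μ += ∆μ                   # correct μ by the residual mean

    return μ
end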

This is about 2× slower than mean() because it calls mean() twice internally, but it calculates the mean accurately for x with all-identical entries, except in very extreme cases where length(x) is on the order of 1e15 for double-precision x (see the explanation at JuliaStats/StatsBase.jl#196 (comment)). A demonstration of the successful refinement on the previous example:

julia> x0 = log(1e-5)
-11.512925464970229

julia> x = fill(x0, 8);

julia> μ = mean_refined(x)  # μ == x0
-11.512925464970229

julia> σ = std(x, mean=μ)  # σ == 0
0.0

julia> z = (x .- μ) ./ σ  # correctly NaN
8-element Vector{Float64}:
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
 NaN
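One way to see why a single refinement step recovers x0 exactly (a sketch of the argument, not a quote from the linked comment): with μ = x0 + δ as before, μ and x0 are within a factor of 2 of each other, so by Sterbenz's lemma each subtraction y - μ is exact. Summing n copies of -δ stays exact as long as every partial sum kδ is representable, which only breaks down for extremely large n:

```math
\begin{aligned}
\mathrm{fl}(x_0 - \mu) &= -\delta \quad \text{(exact, by Sterbenz's lemma)},\\
\Delta\mu = \mathrm{fl}\Big(\tfrac{1}{n}\textstyle\sum_{i=1}^{n}(-\delta)\Big) &= -\delta \quad \text{(exact while } n\delta \text{ needs no rounding)},\\
\mathrm{fl}(\mu + \Delta\mu) &= \mathrm{fl}(x_0 + \delta - \delta) = x_0 .
\end{aligned}
```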

To minimize the risk of performance degradation, we could run mean_refined() only when the first few entries of x are identical (since the refinement is meaningful only for x with nearly identical entries). Alternatively, we could add a keyword argument refine::Bool to mean(), so that mean() performs the refinement only when refine = true.
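Both ideas can be sketched in a few lines. The function mean_maybe_refined and its refine and k keywords are hypothetical names chosen for illustration; Statistics.mean has no such keyword, and adding one to it directly would be type piracy outside Statistics.jl itself:

```julia
using Statistics

function mean_refined(x)
    μ = mean(x)
    Δμ = mean(y -> y - μ, x)   # second pass: mean of the residuals
    return μ + Δμ
end

# Hypothetical: refine when the first `k` entries are identical (cheap heuristic),
# or whenever the caller asks for it with `refine = true`.
function mean_maybe_refined(x; refine::Bool = false, k::Int = 4)
    isempty(x) && return mean(x)
    nearly_constant = all(==(first(x)), Iterators.take(x, k))
    return (refine || nearly_constant) ? mean_refined(x) : mean(x)
end

x0 = log(1e-5)
mean_maybe_refined(fill(x0, 8)) == x0   # heuristic triggers the refinement
```

The heuristic costs at most k comparisons, so non-constant data pays almost nothing. A vector whose first k entries happen to coincide pays for the second pass even though it may not need it.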

What would be the best approach to this issue? Note that the original issue also suggested implementing a more accurate sum() inside mean().
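To make the "more accurate sum()" option concrete, here is a sketch that uses Neumaier's variant of Kahan (compensated) summation. The name mean_neumaier is hypothetical, and this is not how Statistics.jl computes mean(). Compensation improves the accuracy of the sum in general, at the cost of a few extra floating-point operations per element. Whether it also makes the constant case exact for every length is a separate question that this sketch does not settle:

```julia
# Neumaier's compensated summation, then divide by the length.
# `mean_neumaier` is a hypothetical name, not a Statistics.jl function.
function mean_neumaier(x::AbstractVector{<:AbstractFloat})
    isempty(x) && throw(ArgumentError("mean of an empty collection"))
    s = zero(eltype(x))   # running sum
    c = zero(eltype(x))   # running compensation (low-order bits lost by s)
    for xi in x
        t = s + xi
        if abs(s) >= abs(xi)
            c += (s - t) + xi   # low-order bits of xi were lost
        else
            c += (xi - t) + s   # low-order bits of s were lost
        end
        s = t
    end
    return (s + c) / length(x)
end

mean_neumaier([1.0, 1e100, 1.0, -1e100])  # 0.5; plain left-to-right summation gives 0.0
```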

@oscardssmith
Member

I think you should use AccurateArithmetic.jl if you want a higher precision sum.

@mcabbott
Contributor

mcabbott commented May 7, 2022

This does the same refinement with no speed cost I can detect:

function mean_refined2(x)
    α = first(x)
    mean(y->(y-α), x) + α
end

But I don't think it's more accurate in general, only in the constant case.
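A quick check of the constant case for mean_refined2, repeated here so the block is self-contained. Every shifted entry y - α is exactly 0.0, so the result is α itself, whatever the length:

```julia
using Statistics

# Shift by the first element before averaging, as in mean_refined2 above.
function mean_refined2(x)
    α = first(x)
    mean(y -> (y - α), x) + α
end

x0 = log(1e-5)
# For constant input the shifted values are all exactly 0.0, so the result is α == x0.
all(n -> mean_refined2(fill(x0, n)) == x0, 1:1000)   # true
```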

@oscardssmith
Member

yeah. This doesn't generally improve numeric conditioning.

@kshyatt kshyatt added the maths Mathematical functions label May 7, 2022
@mgkuhn
Contributor

mgkuhn commented May 10, 2022

There is also the online algorithm

function mean_refined3(x)
    m = first(x)
    for k = firstindex(x)+1:lastindex(x)
        m += (x[k] - m) / k
    end
    m
end

which can be numerically more stable. It makes only a single pass over the array, but is harder to parallelize efficiently than a simple mapreduce call.
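A self-contained sketch of the same running-mean update, under the hypothetical name mean_online. It counts elements with n rather than using the array index k, so it does not rely on x starting at index 1. It also shows that a constant input comes back unchanged and that ordinary input agrees with mean():

```julia
using Statistics

# Running (online) mean: after seeing n elements, m holds their mean.
# `mean_online` is a hypothetical name used only for this sketch.
function mean_online(x)
    m = float(first(x))
    n = 1
    for xi in Iterators.drop(x, 1)
        n += 1
        m += (xi - m) / n   # move m a fraction 1/n of the way toward xi
    end
    return m
end

x0 = log(1e-5)
mean_online(fill(x0, 8)) == x0   # true: each update adds exactly 0.0
```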

@oscardssmith
Member

That's exactly the same as the one mcabbott posted above. It's only more stable if all the terms are of similar magnitude and sign.

Projects
None yet
Development

No branches or pull requests

5 participants