Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some issues with sampling #879

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 73 additions & 131 deletions src/sampling.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,39 @@
using Base: mightalias

# Compatibility shim: prefer Base's own `require_one_based_indexing` when it exists and
# define an equivalent fallback otherwise.
# TODO: use this directly once we require Julia 1.2+
if !isdefined(Base, :require_one_based_indexing)
    # Throw an `ArgumentError` as soon as any argument's first linear index is not 1.
    function require_one_based_indexing(xs...)
        return any(x -> !isone(firstindex(x)), xs) &&
            throw(ArgumentError("non 1-based arrays are not supported"))
    end
else
    using Base: require_one_based_indexing
end

"""
    _validate_sample_inputs(input::AbstractArray, output::AbstractArray, replace::Bool)

Check that `input` (the source) and `output` (the destination) can be used together
for sampling, and return `(n, k)` where `n = length(input)` and `k = length(output)`.

# Throws
- `ArgumentError` if `output` might share memory with `input`, or if either array is
  rejected by `require_one_based_indexing`.
- `DimensionMismatch` if `replace` is `false` and `k > n`.
"""
function _validate_sample_inputs(input::AbstractArray, output::AbstractArray, replace::Bool)
    mightalias(input, output) &&
        throw(ArgumentError("destination array must not share memory with the source array"))
    require_one_based_indexing(input, output)
    n = length(input)
    k = length(output)
    # Without replacement there must be at least as many source values as requested draws.
    if !replace && k > n
        throw(DimensionMismatch("cannot draw $k samples of $n values without replacement"))
    end
    return (n, k)
end

"""
    _validate_sample_inputs(input, weights::AbstractWeights, output, replace::Bool)

Validate the arguments of a weighted sampling call and return the `(n, k)` tuple
produced by the unweighted `(input, output, replace)` check.

Throws an `ArgumentError` if `output` might share memory with `weights`; any error
raised by the two delegated checks propagates unchanged.
"""
function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights,
                                 output::AbstractArray, replace::Bool)
    if mightalias(output, weights)
        throw(ArgumentError("destination array must not share memory with weights array"))
    end
    # Source/weights consistency is checked before the source/destination pair.
    _validate_sample_inputs(input, weights)
    dims = _validate_sample_inputs(input, output, replace)
    return dims
end

"""
    _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights)

Check that `input` and `weights` are consistent with each other and return
`n = length(input)`.

# Throws
- `ArgumentError` if `input` or `weights` is rejected by `require_one_based_indexing`.
- `DimensionMismatch` if `length(weights) != length(input)`.
"""
function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights)
    # Check `input` as well as `weights`: a position drawn from `weights` is used to
    # index `input` directly, which would silently pick the wrong element if `input`
    # were not 1-based.
    require_one_based_indexing(input, weights)
    n = length(input)
    nw = length(weights)
    nw == n ||
        throw(DimensionMismatch("source and weight arrays must have the same length, got $n and $nw"))
    return n
end

###########################################################
#
Expand All @@ -10,16 +46,15 @@ using Random: Sampler, Random.GLOBAL_RNG
### Algorithms for sampling with replacement

function direct_sample!(rng::AbstractRNG, a::UnitRange, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
s = Sampler(rng, 1:length(a))
n, k = _validate_sample_inputs(a, x, true)
s = Sampler(rng, 1:n)
b = a[1] - 1
if b == 0
for i = 1:length(x)
for i = 1:k
@inbounds x[i] = rand(rng, s)
end
else
for i = 1:length(x)
for i = 1:k
@inbounds x[i] = b + rand(rng, s)
end
end
Expand All @@ -36,12 +71,9 @@ and set `x[j] = a[i]`, with `n=length(a)` and `k=length(x)`.
This algorithm consumes `k` random numbers.
"""
function direct_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
s = Sampler(rng, 1:length(a))
for i = 1:length(x)
n, k = _validate_sample_inputs(a, x, true)
s = Sampler(rng, 1:n)
for i = 1:k
ararslan marked this conversation as resolved.
Show resolved Hide resolved
@inbounds x[i] = a[rand(rng, s)]
end
return x
Expand All @@ -61,11 +93,7 @@ storeindices(n, k, T) = false

# order results of a sampler that does not order automatically
function sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n, k = length(a), length(x)
n, k = _validate_sample_inputs(a, x, true)
# todo: if eltype(x) <: Real && eltype(a) <: Real,
# in some cases it might be faster to check
# issorted(a) to see if we can just sort x
Expand Down Expand Up @@ -140,13 +168,7 @@ memory space. Suitable for the case where memory is tight.
"""
function knuths_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray;
initshuffle::Bool=true)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n = length(a)
k = length(x)
k <= n || error("length(x) should not exceed length(a)")
n, k = _validate_sample_inputs(a, x, false)

# initialize
for i = 1:k
Expand Down Expand Up @@ -200,13 +222,7 @@ faster than Knuth's algorithm especially when `n` is greater than `k`.
It is ``O(n)`` for initialization, plus ``O(k)`` for random shuffling
"""
function fisher_yates_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n = length(a)
k = length(x)
k <= n || error("length(x) should not exceed length(a)")
n, k = _validate_sample_inputs(a, x, false)

inds = Vector{Int}(undef, n)
for i = 1:n
Expand Down Expand Up @@ -240,13 +256,7 @@ However, if `k` is large and approaches ``n``, the rejection rate would increase
drastically, resulting in poorer performance.
"""
function self_avoid_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n = length(a)
k = length(x)
k <= n || error("length(x) should not exceed length(a)")
n, k = _validate_sample_inputs(a, x, false)

s = Set{Int}()
sizehint!(s, k)
Expand Down Expand Up @@ -282,13 +292,7 @@ This algorithm consumes ``O(n)`` random numbers, with `n=length(a)`.
The outputs are ordered.
"""
function seqsample_a!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n = length(a)
k = length(x)
k <= n || error("length(x) should not exceed length(a)")
n, k = _validate_sample_inputs(a, x, false)

i = 0
j = 0
Expand Down Expand Up @@ -324,13 +328,7 @@ This algorithm consumes ``O(k^2)`` random numbers, with `k=length(x)`.
The outputs are ordered.
"""
function seqsample_c!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
n = length(a)
k = length(x)
k <= n || error("length(x) should not exceed length(a)")
n, k = _validate_sample_inputs(a, x, false)

i = 0
j = 0
Expand Down Expand Up @@ -370,13 +368,7 @@ This algorithm consumes ``O(k)`` random numbers, with `k=length(x)`.
The outputs are ordered.
"""
function seqsample_d!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
N = length(a)
n = length(x)
n <= N || error("length(x) should not exceed length(a)")
N, n = _validate_sample_inputs(a, x, false)

i = 0
j = 0
Expand Down Expand Up @@ -485,10 +477,7 @@ nor share memory with them, or the result may be incorrect.
"""
function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray;
replace::Bool=true, ordered::Bool=false)
1 == firstindex(a) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
k = length(x)
n, k = _validate_sample_inputs(a, x, replace)
k == 0 && return x

if replace # with replacement
Expand All @@ -499,8 +488,6 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray;
end

else # without replacement
k <= n || error("Cannot draw more samples without replacement.")

if ordered
if n > 10 * k * k
seqsample_c!(rng, a, x)
Expand Down Expand Up @@ -582,8 +569,7 @@ Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.GLOBAL_RNG`).
"""
function sample(rng::AbstractRNG, wv::AbstractWeights)
1 == firstindex(wv) ||
throw(ArgumentError("non 1-based arrays are not supported"))
require_one_based_indexing(wv)
t = rand(rng) * sum(wv)
n = length(wv)
i = 1
Expand All @@ -596,7 +582,10 @@ function sample(rng::AbstractRNG, wv::AbstractWeights)
end
# Convenience method: forwards to `sample(rng, wv)` with `Random.GLOBAL_RNG` as the RNG.
sample(wv::AbstractWeights) = sample(Random.GLOBAL_RNG, wv)

sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)]
function sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights)
_validate_sample_inputs(a, wv)
return a[sample(rng, wv)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's weird that this line isn't tested.

end
# Convenience method: forwards to `sample(rng, a, wv)` with `Random.GLOBAL_RNG` as the RNG.
sample(a::AbstractArray, wv::AbstractWeights) = sample(Random.GLOBAL_RNG, a, wv)

"""
Expand All @@ -613,15 +602,8 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm:
"""
function direct_sample!(rng::AbstractRNG, a::AbstractArray,
wv::AbstractWeights, x::AbstractArray)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
for i = 1:length(x)
_, k = _validate_sample_inputs(a, wv, x, true)
for i = 1:k
x[i] = a[sample(rng, wv)]
end
return x
Expand Down Expand Up @@ -702,14 +684,7 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm takes ``O(n \\log n)`` ti
for building the alias table, and then ``O(1)`` to draw each sample. It consumes ``2 k`` random numbers.
"""
function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
n, k = _validate_sample_inputs(a, wv, x, true)

# create alias table
ap = Vector{Float64}(undef, n)
Expand All @@ -718,7 +693,7 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights,

# sampling
s = Sampler(rng, 1:n)
for i = 1:length(x)
for i = 1:k
j = rand(rng, s)
x[i] = rand(rng) < ap[j] ? a[j] : a[alias[j]]
end
Expand All @@ -740,15 +715,8 @@ and has overall time complexity ``O(n k)``.
"""
function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
wv::AbstractWeights, x::AbstractArray)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
k = length(x)
n, k = _validate_sample_inputs(a, wv, x, false)
k > 0 || return x

w = Vector{Float64}(undef, n)
copyto!(w, wv)
Expand Down Expand Up @@ -786,20 +754,13 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers.
"""
function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
wv::AbstractWeights, x::AbstractArray)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)
n, k = _validate_sample_inputs(a, wv, x, false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this function really part of the official API, and does it need to check its arguments? IIRC I never intended it to be called directly by any user.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And if users use the (IMO) intended sample API then the arguments are already checked I assume.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't have thought it was intended to be user-facing at all except, as pointed out in #876, it's included in the manual. 😕

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it's not exported, is it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not, no. That said, there are implementations of three different Efraimidis-Spirakis algorithms (A, A-Res, and AExpJ), only one of which (AExpJ) is actually used internally by a function like sample. That suggests to me that these were intended to be used outside of the context of sample, but I could very well be mistaken as I don't know the history.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docs were added in #254.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was actually summer 2016 and you reviewed it

LOL amazing. My brain runs the GC often so 7 years ago is long gone.

I definitely buy the argument that the separate, non-exported functions that each implement specific algorithms should not be considered user-facing and thus shouldn't need to perform the same kind of safety checks as those intended to be called directly. What gets me nervous is that there's nothing saying they aren't user-facing, hence issues like #876 and #877. Perhaps we could add admonitions to the docstrings, e.g.

!!! note
    This function is not intended to be called directly and is not considered
    part of the package's API.

?

A bit tangential to this discussion but in the future we could do something for sampling algorithms as is done for sorting algorithms in Base: each algorithm gets a type that subtypes some abstract sampling algorithm type then the user may select a particular algorithm via a keyword argument to sample, e.g. sample(x, wv; alg=EfraimidisAExpJ()), and internally that dispatches to use e.g. efraimidis_aexpj_wsample_norep! (after doing any appropriate argument checking 😄).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, passing types via an alg keyword argument would be the best API.

Better to perform the checks anyway, unless this means we run the checks twice when called from sample. Is that the case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except if this means we run checks twice when called from sample. Is that the case?

Currently yes. I can add a flag to the internal checking function that makes it a no-op if called from sample but perhaps that's more complex than necessary.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How expensive are the checks? Is there a noticeable performance difference between calling sample and the internal function?

The alg keyword argument seems a reasonable suggestion for future refactorings.

For the time being I would prefer adding a warning or note to the docstrings of these internal functions. I think it was a mistake to add them to the docs at all (also based on the initial + follow-up PRs), so I would be fine even with just removing them from the docs. They're not exported and IMO have never been part of the official API (or at least they were not supposed to be).

k > 0 || return x

# calculate keys for all items
keys = randexp(rng, n)
for i in 1:n
@inbounds keys[i] = wv.values[i]/keys[i]
@inbounds keys[i] = wv[i]/keys[i]
devmotion marked this conversation as resolved.
Show resolved Hide resolved
end

# return items with largest keys
Expand Down Expand Up @@ -827,15 +788,7 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers.
"""
function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
wv::AbstractWeights, x::AbstractArray)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)
n, k = _validate_sample_inputs(a, wv, x, false)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
k > 0 || return x

# initialize priority queue
Expand All @@ -844,7 +797,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
s = 0
@inbounds for _s in 1:n
s = _s
w = wv.values[s]
w = wv[s]
devmotion marked this conversation as resolved.
Show resolved Hide resolved
w < 0 && error("Negative weight found in weight vector at index $s")
if w > 0
i += 1
Expand All @@ -859,7 +812,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
@inbounds threshold = pq[1].first

@inbounds for i in s+1:n
w = wv.values[i]
w = wv[i]
w < 0 && error("Negative weight found in weight vector at index $i")
w > 0 || continue
key = w/randexp(rng)
Expand Down Expand Up @@ -900,15 +853,7 @@ processing time to draw ``k`` elements. It consumes ``O(k \\log(n / k))`` random
function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
wv::AbstractWeights, x::AbstractArray;
ordered::Bool=false)
Base.mightalias(a, x) &&
throw(ArgumentError("output array x must not share memory with input array a"))
Base.mightalias(x, wv) &&
throw(ArgumentError("output array x must not share memory with weights array wv"))
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)
n, k = _validate_sample_inputs(a, wv, x, false)
k > 0 || return x

# initialize priority queue
Expand All @@ -917,7 +862,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
s = 0
@inbounds for _s in 1:n
s = _s
w = wv.values[s]
w = wv[s]
w < 0 && error("Negative weight found in weight vector at index $s")
if w > 0
i += 1
Expand All @@ -933,7 +878,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
X = threshold*randexp(rng)

@inbounds for i in s+1:n
w = wv.values[i]
w = wv[i]
w < 0 && error("Negative weight found in weight vector at index $i")
w > 0 || continue
X -= w
Expand Down Expand Up @@ -968,10 +913,8 @@ efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::Abstra

function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray;
replace::Bool=true, ordered::Bool=false)
1 == firstindex(a) == firstindex(wv) == firstindex(x) ||
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
k = length(x)
n, k = _validate_sample_inputs(a, wv, x, replace)
k > 0 || return x

if replace
if ordered
Expand All @@ -991,7 +934,6 @@ function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::Abs
end
end
else
k <= n || error("Cannot draw $k samples from $n samples without replacement.")
efraimidis_aexpj_wsample_norep!(rng, a, wv, x; ordered=ordered)
end
return x
Expand Down
Loading