diff --git a/src/sampling.jl b/src/sampling.jl index 609c7d48b..09c314a27 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -102,6 +102,29 @@ sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray, sampler!(rng, a, wv, x) end +""" +This algorithm generates a sorted sample with replacement by +adapting the classic result that the cumulative sum of n+1 +exponentially-distributed random numbers divided by the overall sum +(dropping the last) is a sorted sample from a uniform[0,1] +""" +function uniform_orderstat_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) + 1 == firstindex(a) == firstindex(x) || + throw(ArgumentError("non 1-based arrays are not supported")) + Base.mightalias(a, x) && + throw(ArgumentError("output array x must not share memory with input array a")) + n = length(a) + k = length(x) + exp_rands = randexp(rng, k) + sorted_rands = cumsum(exp_rands) + cum_step = (sorted_rands[end] + randexp(rng)) / n + @inbounds for i in eachindex(x) + j = ceil(Int, sorted_rands[i] / cum_step) + x[i] = a[j] + end + return x +end + ### draw a pair of distinct integers in [1:n] """ @@ -500,7 +523,11 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; if replace # with replacement if ordered - sample_ordered!(direct_sample!, rng, a, x) + if k <= 10 + sample_ordered!(direct_sample!, rng, a, x) + else + uniform_orderstat_sample!(rng, a, x) + end else direct_sample!(rng, a, x) end diff --git a/test/sampling.jl b/test/sampling.jl index 34b339dec..5630bdcc1 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -74,14 +74,19 @@ test_rng_use(direct_sample!, 1:10, zeros(Int, 6)) a = sample(3:12, n) check_sample_wrep(a, (3, 12), 5.0e-3; ordered=false) +rng = StableRNG(1) for rev in (true, false), T in (Int, Int16, Float64, Float16, BigInt, ComplexF64, Rational{Int}) r = rev ? reverse(3:12) : (3:12) r = T===Int ? r : T.(r) - aa = Int.(sample(r, n; ordered=true)) + aa = Int.(sample(rng, r, n; ordered=true)) check_sample_wrep(aa, (3, 12), 5.0e-3; ordered=true, rev=rev) - aa = Int.(sample(r, 10; ordered=true)) - check_sample_wrep(aa, (3, 12), 0; ordered=true, rev=rev) + aa = Int[] + for i in 1:Int(n/10) + bb = Int.(sample(rng, r, 10; ordered=true)) + append!(aa, bb) + end + check_sample_wrep(sort!(aa, rev=rev), (3, 12), 5.0e-3; ordered=true, rev=rev) end @test StatsBase._storeindices(1, 1, BigFloat) == StatsBase._storeindices(1, 1, BigFloat) == false