diff --git a/src/distributions/bool.jl b/src/distributions/bool.jl index 7399d5b..5114645 100644 --- a/src/distributions/bool.jl +++ b/src/distributions/bool.jl @@ -11,7 +11,7 @@ end pdf(d::BoolDistribution, s::Bool) = s ? d.p : 1.0-d.p -rand(rng::AbstractRNG, d::BoolDistribution) = rand(rng) <= d.p +rand(rng::AbstractRNG, s::Random.SamplerTrivial{BoolDistribution}) = rand(rng) <= s[].p Base.iterate(d::BoolDistribution) = ((d.p, true), true) function Base.iterate(d::BoolDistribution, state::Bool) diff --git a/src/distributions/deterministic.jl b/src/distributions/deterministic.jl index cdcb205..d044397 100644 --- a/src/distributions/deterministic.jl +++ b/src/distributions/deterministic.jl @@ -9,8 +9,7 @@ struct Deterministic{T} val::T end -rand(rng::AbstractRNG, d::Deterministic) = d.val -rand(d::Deterministic) = d.val +rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:Deterministic}) = s[].val support(d::Deterministic) = (d.val,) sampletype(::Type{Deterministic{T}}) where T = T Random.gentype(::Type{Deterministic{T}}) where T = T diff --git a/src/distributions/sparse_cat.jl b/src/distributions/sparse_cat.jl index fc80287..29c9960 100644 --- a/src/distributions/sparse_cat.jl +++ b/src/distributions/sparse_cat.jl @@ -12,7 +12,8 @@ struct SparseCat{V, P} probs::P end -function rand(rng::AbstractRNG, d::SparseCat) +function rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat}) + d = s[] r = sum(d.probs)*rand(rng) tot = zero(eltype(d.probs)) for (v, p) in d diff --git a/src/distributions/uniform.jl b/src/distributions/uniform.jl index 5c56473..44e464e 100644 --- a/src/distributions/uniform.jl +++ b/src/distributions/uniform.jl @@ -56,7 +56,7 @@ Random.gentype(::Type{UnsafeUniform{T}}) where T = eltype(T) const Unif = Union{Uniform,UnsafeUniform} -rand(rng::AbstractRNG, d::Unif) = rand(rng, support(d)) +rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:Unif}) = rand(rng, support(s[])) mean(d::Unif) = mean(support(d)) mode(d::Unif) = mode(support(d)) diff --git a/test/test_bool.jl b/test/test_bool.jl index d151947..60a3a69 100644 --- a/test/test_bool.jl +++ b/test/test_bool.jl @@ -15,5 +15,8 @@ let # testing hash @test hash(d) == hash(d2) + @test rand(d) in [true, false] + @test all(x in [true, false] for x in rand(d,2)) + @test sprint((io,d)->show(io,MIME("text/plain"),d), d) == sprint((io,d)->showdistribution(io,d,title="BoolDistribution"), d) end diff --git a/test/test_deterministic.jl b/test/test_deterministic.jl index de2c2ec..1f1c5a7 100644 --- a/test/test_deterministic.jl +++ b/test/test_deterministic.jl @@ -1,7 +1,9 @@ d = Deterministic(1) @test rand(d) == 1 +@test all(rand(d, 2) .== 1) @test rand(MersenneTwister(4), d) == 1 +@test all(rand(MersenneTwister(4), d, 2) .== 1) @test collect(support(d)) == [1] @test Random.gentype(d) == typeof(1) @test Random.gentype(typeof(d)) == typeof(1) diff --git a/test/test_implicit.jl b/test/test_implicit.jl index b93b0b8..7513186 100644 --- a/test/test_implicit.jl +++ b/test/test_implicit.jl @@ -10,6 +10,7 @@ m = IMDP() td = transition(m, 1.0, 1) @test 2 <= rand(td) <= 3 +@test all(2 <= sp <= 3 for sp in rand(td, 2)) impldist(m) = ImplicitDistribution(m) do m, rng return rand(rng, m) diff --git a/test/test_sparse_cat.jl b/test/test_sparse_cat.jl index 9d6cd81..172af69 100644 --- a/test/test_sparse_cat.jl +++ b/test/test_sparse_cat.jl @@ -42,6 +42,13 @@ let @test isapprox(count(samples.==:c)/N, pdf(d,:c), atol=0.005) @test isapprox(count(samples.==:d)/N, pdf(d,:d), atol=0.005) + # rand(::SparseCat, ::Integer) + samples = rand(d, N) + @test isapprox(count(samples.==:a)/N, pdf(d,:a), atol=0.005) + @test isapprox(count(samples.==:b)/N, pdf(d,:b), atol=0.005) + @test isapprox(count(samples.==:c)/N, pdf(d,:c), atol=0.005) + @test isapprox(count(samples.==:d)/N, pdf(d,:d), atol=0.005) + @test_throws ErrorException rand(Random.GLOBAL_RNG, SparseCat([1], [0.0])) @test sprint((io,d)->show(io,MIME("text/plain"),d), d) == sprint((io,d)->showdistribution(io,d,title="SparseCat distribution"), d) diff --git a/test/test_uniform.jl b/test/test_uniform.jl index df39602..8e382c2 100644 --- a/test/test_uniform.jl +++ b/test/test_uniform.jl @@ -2,6 +2,7 @@ d = Uniform([1]) @test rand(d) == 1 @test rand(MersenneTwister(4), d) == 1 +@test all(rand(MersenneTwister(4), d, 2) .== 1) @test collect(support(d)) == [1] @test Random.gentype(d) == typeof(1) @test Random.gentype(typeof(d)) == typeof(1) @@ -16,6 +17,7 @@ d = Uniform([1]) d2 = Uniform((:symbol,)) @test rand(d2) == :symbol +@test all(rand(d2, 2) .== :symbol) @test rand(MersenneTwister(4), d2) == :symbol @test collect(support(d2)) == [:symbol] @test Random.gentype(d2) == typeof(:symbol) @@ -34,6 +36,7 @@ d3 = UnsafeUniform([1]) @test rand(d3) == 1 @test rand(MersenneTwister(4), d3) == 1 +@test all(rand(MersenneTwister(4), d3, 2) .== 1) @test collect(support(d3)) == [1] @test Random.gentype(d3) == typeof(1) @test Random.gentype(typeof(d3)) == typeof(1) @@ -48,6 +51,7 @@ d3 = UnsafeUniform([1]) d4 = UnsafeUniform((:symbol,)) @test rand(d4) == :symbol @test rand(MersenneTwister(4), d4) == :symbol +@test all(rand(MersenneTwister(4), d4, 2) .== :symbol) @test collect(support(d4)) == [:symbol] @test Random.gentype(d4) == typeof(:symbol) @test Random.gentype(typeof(d4)) == typeof(:symbol)