Skip to content

Commit

Permalink
Merge pull request #59 from JuliaPOMDP/distro-update
Browse files Browse the repository at this point in the history
fix #58 and provide recommended methods of rand for all distributions
  • Loading branch information
zsunberg authored Nov 12, 2021
2 parents c8f7437 + a4ccf01 commit 94b460d
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/distributions/bool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ end

pdf(d::BoolDistribution, s::Bool) = s ? d.p : 1.0-d.p

rand(rng::AbstractRNG, d::BoolDistribution) = rand(rng) <= d.p
rand(rng::AbstractRNG, s::Random.SamplerTrivial{BoolDistribution}) = rand(rng) <= s[].p

Base.iterate(d::BoolDistribution) = ((d.p, true), true)
function Base.iterate(d::BoolDistribution, state::Bool)
Expand Down
3 changes: 1 addition & 2 deletions src/distributions/deterministic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ struct Deterministic{T}
val::T
end

rand(rng::AbstractRNG, d::Deterministic) = d.val
rand(d::Deterministic) = d.val
rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:Deterministic}) = s[].val
support(d::Deterministic) = (d.val,)
sampletype(::Type{Deterministic{T}}) where T = T
Random.gentype(::Type{Deterministic{T}}) where T = T
Expand Down
3 changes: 2 additions & 1 deletion src/distributions/sparse_cat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ struct SparseCat{V, P}
probs::P
end

function rand(rng::AbstractRNG, d::SparseCat)
function rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat})
d = s[]
r = sum(d.probs)*rand(rng)
tot = zero(eltype(d.probs))
for (v, p) in d
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/uniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Random.gentype(::Type{UnsafeUniform{T}}) where T = eltype(T)

const Unif = Union{Uniform,UnsafeUniform}

rand(rng::AbstractRNG, d::Unif) = rand(rng, support(d))
rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:Unif}) = rand(rng, support(s[]))
mean(d::Unif) = mean(support(d))
mode(d::Unif) = mode(support(d))

Expand Down
3 changes: 3 additions & 0 deletions test/test_bool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,8 @@ let
# testing hash
@test hash(d) == hash(d2)

@test rand(d) in [true, false]
@test all(x in [true, false] for x in rand(d,2))

@test sprint((io,d)->show(io,MIME("text/plain"),d), d) == sprint((io,d)->showdistribution(io,d,title="BoolDistribution"), d)
end
2 changes: 2 additions & 0 deletions test/test_deterministic.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
d = Deterministic(1)

@test rand(d) == 1
@test all(rand(d, 2) .== 1)
@test rand(MersenneTwister(4), d) == 1
@test all(rand(MersenneTwister(4), d, 2) .== 1)
@test collect(support(d)) == [1]
@test Random.gentype(d) == typeof(1)
@test Random.gentype(typeof(d)) == typeof(1)
Expand Down
1 change: 1 addition & 0 deletions test/test_implicit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ m = IMDP()

td = transition(m, 1.0, 1)
@test 2 <= rand(td) <= 3
@test all(2 <= sp <= 3 for sp in rand(td, 2))

impldist(m) = ImplicitDistribution(m) do m, rng
return rand(rng, m)
Expand Down
7 changes: 7 additions & 0 deletions test/test_sparse_cat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ let
@test isapprox(count(samples.==:c)/N, pdf(d,:c), atol=0.005)
@test isapprox(count(samples.==:d)/N, pdf(d,:d), atol=0.005)

# rand(::SparseCat, ::Integer)
samples = rand(d, N)
@test isapprox(count(samples.==:a)/N, pdf(d,:a), atol=0.005)
@test isapprox(count(samples.==:b)/N, pdf(d,:b), atol=0.005)
@test isapprox(count(samples.==:c)/N, pdf(d,:c), atol=0.005)
@test isapprox(count(samples.==:d)/N, pdf(d,:d), atol=0.005)

@test_throws ErrorException rand(Random.GLOBAL_RNG, SparseCat([1], [0.0]))

@test sprint((io,d)->show(io,MIME("text/plain"),d), d) == sprint((io,d)->showdistribution(io,d,title="SparseCat distribution"), d)
Expand Down
4 changes: 4 additions & 0 deletions test/test_uniform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ d = Uniform([1])

@test rand(d) == 1
@test rand(MersenneTwister(4), d) == 1
@test all(rand(MersenneTwister(4), d, 2) .== 1)
@test collect(support(d)) == [1]
@test Random.gentype(d) == typeof(1)
@test Random.gentype(typeof(d)) == typeof(1)
Expand All @@ -16,6 +17,7 @@ d = Uniform([1])

d2 = Uniform((:symbol,))
@test rand(d2) == :symbol
@test all(rand(d2, 2) .== :symbol)
@test rand(MersenneTwister(4), d2) == :symbol
@test collect(support(d2)) == [:symbol]
@test Random.gentype(d2) == typeof(:symbol)
Expand All @@ -34,6 +36,7 @@ d3 = UnsafeUniform([1])

@test rand(d3) == 1
@test rand(MersenneTwister(4), d3) == 1
@test all(rand(MersenneTwister(4), d3, 2) .== 1)
@test collect(support(d3)) == [1]
@test Random.gentype(d3) == typeof(1)
@test Random.gentype(typeof(d3)) == typeof(1)
Expand All @@ -48,6 +51,7 @@ d3 = UnsafeUniform([1])
d4 = UnsafeUniform((:symbol,))
@test rand(d4) == :symbol
@test rand(MersenneTwister(4), d4) == :symbol
@test all(rand(MersenneTwister(4), d4, 2) .== :symbol)
@test collect(support(d4)) == [:symbol]
@test Random.gentype(d4) == typeof(:symbol)
@test Random.gentype(typeof(d4)) == typeof(:symbol)
Expand Down

0 comments on commit 94b460d

Please sign in to comment.