Skip to content

Commit

Permalink
Backport 3-arg rand to AbstractPPL 0.5.x (#80)
Browse files Browse the repository at this point in the history
* Back-port 3-arg `rand` interface.

* Bump version for interface change.

* Update CI.yml

* Fix test errors.

---------

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
sunxd3 and yebai authored Feb 25, 2023
1 parent d6d898b commit 7d78781
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 7 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ on:
- trying
# Build the main branch.
- main
pull_request:
branches:
- main
- releases-0.5.x

jobs:
test:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
keywords = ["probablistic programming"]
license = "MIT"
desc = "Common interfaces for probabilistic programming"
version = "0.5.3"
version = "0.5.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
20 changes: 20 additions & 0 deletions src/abstractprobprog.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using AbstractMCMC
using DensityInterface
using Random


"""
Expand Down Expand Up @@ -60,3 +61,22 @@ m = decondition(condition(m, obs))
should hold for generative models `m` and arbitrary `obs`.
"""
function condition end


"""
rand([rng=Random.default_rng()], [T=NamedTuple], model::AbstractProbabilisticProgram) -> T
Draw a sample from the joint distribution of the model specified by the probabilistic program.
The sample will be returned as format specified by `T`.
"""
Base.rand(rng::Random.AbstractRNG, ::Type, model::AbstractProbabilisticProgram)
function Base.rand(rng::Random.AbstractRNG, model::AbstractProbabilisticProgram)
return rand(rng, NamedTuple, model)
end
function Base.rand(::Type{T}, model::AbstractProbabilisticProgram) where {T}
return rand(Random.default_rng(), T, model)
end
function Base.rand(model::AbstractProbabilisticProgram)
return rand(Random.default_rng(), NamedTuple, model)
end
10 changes: 7 additions & 3 deletions src/graphinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,9 @@ function Random.rand!(m::AbstractPPL.GraphPPL.Model{T}) where T
end

"""
rand!(rng::AbstractRNG, m::Model)
rand(m::Model)
Draw random samples from the model and mutate the node values.
Draw random samples from the model and return the samples as NamedTuple.
# Examples
Expand All @@ -470,11 +470,15 @@ julia> rand(m)
(μ = 1.0, s2 = 1.0907695400401212, y = 0.05821954440386368)
```
"""
function Random.rand(rng::AbstractRNG, sm::Random.SamplerTrivial{Model{Tnames, Tinput, Tvalue, Teval, Tkind}}) where {Tnames, Tinput, Tvalue, Teval, Tkind}
function Base.rand(rng::AbstractRNG, sm::Random.SamplerTrivial{Model{Tnames, Tinput, Tvalue, Teval, Tkind}}) where {Tnames, Tinput, Tvalue, Teval, Tkind}
m = deepcopy(sm[])
get_model_values(rand!(rng, m))
end

function Base.rand(rng::AbstractRNG, ::Type{NamedTuple}, m::Model)
rand(rng, Random.SamplerTrivial(m))
end

"""
logdensityof(m::Model)
Expand Down
6 changes: 4 additions & 2 deletions src/varname.jl
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,10 @@ resolved as `VarName` only supports non-dynamic indexing as determined by
```jldoctest
julia> # Dynamic indexing is not allowed in `VarName`
@varname(x[end])
ERROR: UndefVarError: x not defined
[...]
ERROR: UndefVarError: `x` not defined
Stacktrace:
[1] top-level scope
@ none:1
julia> # To be able to resolve `end` we need `x` to be available.
x = randn(2); @varname(x[end])
Expand Down
37 changes: 37 additions & 0 deletions test/abstractprobprog.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using AbstractPPL
using Random
using Test

mutable struct RandModel <: AbstractProbabilisticProgram
rng
T
end

function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::RandModel) where {T}
model.rng = rng
model.T = T
return nothing
end

@testset "AbstractProbabilisticProgram" begin
@testset "rand defaults" begin
model = RandModel(nothing, nothing)
rand(model)
@test model.rng == Random.default_rng()
@test model.T === NamedTuple
rngs = [Random.default_rng(), Random.MersenneTwister(42)]
Ts = [NamedTuple, Dict]
@testset for T in Ts
model = RandModel(nothing, nothing)
rand(T, model)
@test model.rng == Random.default_rng()
@test model.T === T
end
@testset for rng in rngs
model = RandModel(nothing, nothing)
rand(rng, model)
@test model.rng === rng
@test model.T === NamedTuple
end
end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using Test

@testset "AbstractPPL.jl" begin
include("deprecations.jl")
include("abstractprobprog.jl")
include("graphinfo/graphinfo.jl")
@testset "doctests" begin
DocMeta.setdocmeta!(
Expand All @@ -20,6 +21,6 @@ using Test
:(using AbstractPPL);
recursive=true,
)
doctest(AbstractPPL; manual=false)
doctest(AbstractPPL; manual=false, fix=true)
end
end

2 comments on commit 7d78781

@yebai
Copy link
Member

@yebai yebai commented on 7d78781 Feb 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sunxd3, you can make a new release by commenting:

@JuliaRegistrator register

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/78510

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.4 -m "<description of version>" 7d78781ebcd3dbe6dbab22c52657bf773da4fc43
git push origin v0.5.4

Please sign in to comment.