Add support for user-supplied RNG state in all interfaces #520
base: modular-rng
Updates the Gen function interface to use the standard Julia pattern for user-supplied RNG states, i.e.:

    myfunc(args...) = myfunc(default_rng(), args...)
    myfunc(rng::AbstractRNG, args...) = ...

This is applied to all function interfaces which use rng. Inference algorithms instead provide a keyword argument `rng`, which tends to be more common for higher-level function interfaces.
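For readers unfamiliar with the pattern, here is a minimal sketch of both conventions described above. The function names `noisy_sum` and `average_noisy_sum` are hypothetical and are not part of Gen; they only illustrate the optional positional RNG for low-level functions and the `rng` keyword for higher-level ones.

```julia
using Random: AbstractRNG, MersenneTwister, default_rng

# Low-level function: the RNG is an optional first positional argument.
noisy_sum(rng::AbstractRNG, n::Integer) = sum(randn(rng) for _ in 1:n)
# Convenience method that forwards to the global default RNG.
noisy_sum(n::Integer) = noisy_sum(default_rng(), n)

# Higher-level function: the RNG is a keyword argument with a default.
function average_noisy_sum(n::Integer, reps::Integer; rng::AbstractRNG = default_rng())
    return sum(noisy_sum(rng, n) for _ in 1:reps) / reps
end

# Supplying identically seeded RNGs gives reproducible results.
a = average_noisy_sum(10, 5; rng = MersenneTwister(123))
b = average_noisy_sum(10, 5; rng = MersenneTwister(123))
```

Existing callers that never mention an RNG keep working through the forwarding method, which is the backwards-compatibility property this PR relies on.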
Hi @bgroenks96, thank you for this thorough PR. I support merging these improvements to the API. Could you please take a look at the failing continuous integration tests? There appears to be a "Not implemented" error in one of the tests.
@fsaad Fixed in the last two commits. All tests are passing for me locally.
Regarding the fix in e39459b, I am not sure that I fully understand this part of the API, but I am assuming by the name that deterministic functions should not need access to the RNG. If this is wrong, then we would unfortunately need to break this part of the API, I think.
Thanks for this PR! I'll review this in the next couple of days, but intuitively this seems like the right way to modify the interfaces and I've checked that other implementations of Gen (e.g. the work-in-progress JAX implementation) also adopt a similar interface for control over the RNG / RNG seed.
Thanks for making this PR! I've left comments in most of the places where changes are needed. Here's a summary of the changes that still need to be made:
- Modify the `GF...State` types to become parametric types.
- For the dynamic DSL, make sure in `traceat` calls that the RNG is passed down to recursive calls to GFI methods like `simulate`, `generate`, etc.
- Add a fallback implementation (with a warning) from the RNG version of the GFI methods to the non-RNG version of the GFI methods. This is to prevent breaking existing code outside of Gen.jl that only implements the non-RNG version (same goes for the definition of `random`).
- Update the GFI docstrings or documentation to mention that a custom RNG can be provided by the user.
- For the static modeling language, ensure that the RNG is passed down to nested GFI calls for all GFI methods.
- For the static modeling language, replace the `rng` variable name with a globally gen-symed variable name, to avoid name collisions.
In addition, it appears that these parts of the Gen.jl library need to be updated to make use of custom RNGs:
- The rest of the combinators, like `Map`, `Unfold`, and `CallAt`
- Most of the inference library:
  - Importance (re)sampling
  - Particle filtering
  - Trace translators
  - Elliptical slice sampling
  - Trace kernel DSL, since it allows users to write code that randomly decides between MCMC kernels.

I understand that this is a fair amount of work, and that we should possibly break it up into separate PRs. One way to do this might be for us to create a separate branch of Gen dedicated to merging this broader set of changes, and then each PR can focus on supporting custom RNGs for various portions of the code base. If that sounds good to you, I can go ahead and create a branch called `modular-rng` (as you've called your branch).
src/dynamic/dynamic.jl
Outdated
@@ -47,10 +47,13 @@ accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad

 mutable struct GFUntracedState
     params::Dict{Symbol,Any}
+    rng::AbstractRNG
I think we should change all these `GF...State` structs to be parametric in the type of the RNG, to avoid potential performance regressions due to type instability. To be specific, I would replace this with:

    mutable struct GFUntracedState{R <: AbstractRNG}
        params::Dict{Symbol, Any}
        rng::R
    end
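To illustrate the type-stability concern, here is a small sketch contrasting an abstractly typed field with a parametric one. The struct names `AbstractFieldState` and `ParametricState` are hypothetical stand-ins, not Gen types.

```julia
using Random: AbstractRNG, MersenneTwister

# Field declared with an abstract type: the compiler cannot know the
# concrete RNG type when it compiles code that reads `state.rng`.
mutable struct AbstractFieldState
    rng::AbstractRNG
end

# Parametric version: the concrete RNG type is part of the struct's type.
mutable struct ParametricState{R <: AbstractRNG}
    rng::R
end

draw(state) = rand(state.rng)

s_abstract = AbstractFieldState(MersenneTwister(1))
s_param = ParametricState(MersenneTwister(1))

# Both produce the same value; only the inferability of the field differs.
same_value = draw(s_abstract) == draw(s_param)
```

Inspecting `draw` on each state with `@code_warntype` should show an abstractly typed field access in the first case and a concrete one in the second.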
src/dynamic/dynamic.jl
Outdated
@@ -85,7 +88,7 @@ end
     gen_fn(args...)

 @inline traceat(state::GFUntracedState, dist::Distribution, args, key) =
-    random(dist, args...)
+    random(state.rng, dist, args...)

 @inline splice(state::GFUntracedState, gen_fn::DynamicDSLFunction, args::Tuple) =
     gen_fn(args...)
`splice` should also pass `state.rng` as the first argument to `gen_fn`.
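A toy sketch of what threading the RNG through nested calls might look like. `ToyState`, `toy_traceat`, `toy_splice`, `inner`, and `outer` are all hypothetical names used for illustration; they only mimic the shape of the dynamic DSL's state handling.

```julia
using Random: AbstractRNG, MersenneTwister

struct ToyState{R <: AbstractRNG}
    rng::R
end

# Leaf random choice: sample using the RNG stored in the state.
toy_traceat(state::ToyState, sampler, args...) = sampler(state.rng, args...)

# Nested call: forward the state's RNG as the callee's first argument,
# so the callee does not silently fall back to the global RNG.
toy_splice(state::ToyState, callee, args...) = callee(state.rng, args...)

uniform(rng::AbstractRNG, lo, hi) = lo + (hi - lo) * rand(rng)
inner(rng::AbstractRNG, n) = sum(rand(rng) for _ in 1:n)

function outer(rng::AbstractRNG, n)
    state = ToyState(rng)
    a = toy_traceat(state, uniform, 0.0, 1.0)
    b = toy_splice(state, inner, n)
    return a + b
end

# Every draw goes through the supplied RNG, so equal seeds give equal results.
r1 = outer(MersenneTwister(7), 3)
r2 = outer(MersenneTwister(7), 3)
```

If `toy_splice` dropped the RNG and `inner` used the global one, the two results would generally differ even with identical seeds.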
src/dynamic/dynamic.jl
Outdated
@@ -85,7 +88,7 @@ end
     gen_fn(args...)
This should pass `state.rng` as the first argument to `gen_fn`.
@@ -78,9 +79,12 @@ function splice(state::GFGenerateState, gen_fn::DynamicDSLFunction,
     retval
On line 59, the recursive call to `generate` needs to pass `state.rng` to the callee function.
@@ -55,8 +56,8 @@ function splice(state::GFProposeState, gen_fn::DynamicDSLFunction, args::Tuple)
     retval
On line 40, `state.rng` needs to be passed to the recursive call to `propose`.
@@ -59,29 +59,35 @@ function assess(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap
     (weight, value)
 end

-function propose(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple) where {T,K}
+propose(gen_fn::ChoiceAtCombinator, args::Tuple) = propose(default_rng(), gen_fn, args)
I believe the `CallAtCombinator` in `call_at.jl` also needs to be updated to pass down `rng` to any nested generative function calls. Same goes for `Map`, `Unfold`, etc.
@@ -131,8 +131,9 @@ end
 # TODO
 accepts_output_grad(::Recurse) = false

-function (gen_fn::Recurse)(args...)
-    (_, _, retval) = propose(gen_fn, args)
+(gen_fn::Recurse)(args...) = gen_fn(default_rng(), args...)
Let's make the same changes for `Map` and `Unfold` (which are much more widely used combinators).

+Calls `random` with the default global RNG.
+"""
+random(dist::Distribution, args...) = random(default_rng(), dist, args...)
For similar reasons to the comments I made on the definition of `simulate`, I believe we also need a fallback in the other direction, along with a warning. There are a fair number of custom distributions that people have written with Gen (see e.g. https://github.com/probcomp/GenDistributions.jl), and we need to add a fallback from a version with the RNG to the version without the RNG.
Also, I don't think we need another docstring for this definition of `random`, since `random` is already documented above. The original docstring should just be modified to note that the user can supply a custom RNG.
src/static_ir/simulate.jl
Outdated
@@ -25,7 +25,7 @@ function process!(state::StaticIRSimulateState, node::RandomChoiceNode, options)
     incr = gensym("logpdf")
     addr = QuoteNode(node.addr)
     dist = QuoteNode(node.dist)
-    push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :random))($dist, $(args...))))
+    push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :random))(rng, $dist, $(args...))))
In addition to this change, we need to make sure that recursive calls to `simulate` also pass along the RNG, e.g. on line 43 of this file, there is a recursive call to `simulate` that should be passed `rng` as the first argument.
This should be done for all of the GFI functions.
Also, I'll make this point again later below, but to be safe and avoid name collisions, I believe `rng` here should be replaced with a globally gen-symed variable name called `STATIC_RNG`. Otherwise, if the user happens to define their own variable called `rng` in their function definition, the generated code may end up being buggy.
src/static_ir/static_ir.jl
Outdated
@@ -63,18 +63,18 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati
     $(GlobalRef(Gen, :get_options))(::Type{$gen_fn_type_name}) = $(QuoteNode(options))
     # Generate GFI definitions
     (gen_fn::$gen_fn_type_name)(args...) = $(GlobalRef(Gen, :propose))(gen_fn, args)[3]
-    @generated function $(GlobalRef(Gen, :simulate))(gen_fn::$gen_fn_type_name, args::$(QuoteNode(Tuple)))
+    @generated function $(GlobalRef(Gen, :simulate))(rng::$AbstractRNG, gen_fn::$gen_fn_type_name, args::$(QuoteNode(Tuple)))
As noted above, rather than calling the variable `rng`, I believe it will be safer to add this global definition somewhere near the top of this file:

    "Global reference to the RNG variable for the static modeling language."
    const STATIC_RNG = gensym("rng")

And then change the above line to:

    @generated function $(GlobalRef(Gen, :simulate))($STATIC_RNG::$AbstractRNG, gen_fn::$gen_fn_type_name, args::$(QuoteNode(Tuple)))

`$STATIC_RNG` should then be used whenever generating code that needs some reference to the RNG.
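To make the name-collision concern concrete, here is a small self-contained sketch. `TOY_RNG`, `sampler_expr`, and `toy_sample` are hypothetical names standing in for `STATIC_RNG` and the static DSL's code generator; the real generator is considerably more involved.

```julia
using Random: AbstractRNG, MersenneTwister

# Hypothetical stand-in for the proposed STATIC_RNG constant: a unique
# symbol that user code cannot accidentally write by hand.
const TOY_RNG = gensym("rng")

# Build a function definition whose RNG argument carries the gensym'd name,
# with user-written statements spliced into the body.
function sampler_expr(fname::Symbol, user_stmts::Vector{Any})
    return quote
        function $fname($TOY_RNG::AbstractRNG)
            $(user_stmts...)
            return rand($TOY_RNG)
        end
    end
end

# The user's code happens to define its own variable called `rng`.
user_stmts = Any[:(rng = "not an RNG")]
eval(sampler_expr(:toy_sample, user_stmts))

# The generated code still samples from the RNG argument, not the user's string.
x = toy_sample(MersenneTwister(1))
```

Had the generated argument been named plain `rng`, the user's assignment would have overwritten it and the final `rand` call would have operated on the string instead of the RNG.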
Thanks for the detailed review, @ztangent. I was afraid this would turn out to be more complicated than it initially looked... I will try to go through the comments later this week. In the meantime, it would be good to go ahead and create that branch, and then I can change the PR to target this instead of
Hi @ztangent, I apologize for the long delay. I had some other more pressing deadlines to attend to. I think that I have addressed your first set of comments, as well as the issues with the combinators. I still need to look more closely at the inference algorithms. Please let me know if I have missed anything or if I did not fully address any of the issues. EDIT: Note that I have verified that all tests are passing on my machine (as of 0539d93).
Awesome, thank you! I should have time to look more at this the week after next. I've also created the
Note that this PR should be fully backwards compatible with all tests and existing Gen code, since method dispatches with `default_rng()` are universally provided.
Resolves #33