-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add some interface functions to support the new Gibbs sampler in Turing #144
Conversation
How is #86 related to this PR? |
Hmm, it's unclear to me whether it's worth adding these methods when they have "no use" unless some notion of conditioning is also added 😕
|
I am for adding a |
I think |
Okay, now Do we want |
@devmotion @yebai @torfjelde @mhauru a penny for your thoughts? |
I'd keep it in |
I'm still a bit uncertain about all of this tbh. I feel like right now we're just shoving I think if this is the case, then I'm preferential to ignoring my original comment of "needing |
I removed I also think we should add something like |
Not for this PR at least:) If we want to discuss this, then we should open an issue and move discussion there. |
It seems you've bumped the major version, not the minor version? Also, if we're making this part of the interface, we should probably document this? |
Oops, you're right.
By using the |
Some high-level comments:
@sunxd3 please also take a careful look at
we want to push for merging these PRs and incorporate some nice ideas elsewhere in the ecosystem. |
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #144 +/- ##
==========================================
+ Coverage 97.19% 97.54% +0.34%
==========================================
Files 8 8
Lines 321 326 +5
==========================================
+ Hits 312 318 +6
+ Misses 9 8 -1
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
@yebai @torfjelde this has been updated, many issues in the previous version has been corrected (thanks for the discussion and code review). I also added another notes in the folder |
In its current form, no interface change is made to AbstractMCMC, all the interface functions are from other packages. |
the test error seems to be Julia 1.6-only related |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I left some comments below. Some other high-level comments:
- unify MCMCState with MCMCTransition by introducing an
AbstractMCMCState
type - replace
Base.vec(state)
with a special stateAbstractMCMC.VectorMCMCState{T}<:AbstractVector{T}
, which supportsgetindex
andsetindex
. All sampler packages should explicitly implement aVectorMCMCState(state)
type conversion funciton. - is
MCMCTransition
replaceable withVectorMCMCState
?
This function takes the state and returns a vector of the parameter values stored in the state. | ||
|
||
```julia | ||
state = StateType(state::StateType, logp) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
state = StateType(state::StateType, logp) | |
state = StateType(state::StateType, logdensity=logp) |
This function takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability. | ||
|
||
These functions provide a minimal interface to interact with the `state` datatype, which a sampler package can optionally implement. | ||
The interface facilitates the implementation of "meta-algorithms" that combine different samplers. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The interface facilitates the implementation of "meta-algorithms" that combine different samplers. | |
The interface facilitates the implementation of "high-order" MCMC sampling algorithms like Gibbs. |
Co-authored-by: Hong Ge <[email protected]>
related reply from @devmotion TuringLang/Turing.jl#2304 (comment) |
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm very much in favour of using explicit functions for the interface and not overloading methods (in ways that the original methods were not intended for) 😕
|
||
Here, some alternative functions that achieve the same functionality as `getparams` and `recompute_logp!!` are proposed, but without introducing new interface functions. | ||
|
||
For `getparams`, we can use `Base.vec`. It is a `Base` function, so there's no need to export anything from `AbstractMCMC`. Since `getparams` should return a vector, using `vec` makes sense. The concern is that, officially, `Base.vec` is defined for `AbstractArray`, so it remains a question whether we should only introduce `vec` in the absence of other `AbstractArray` interfaces. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd much prefer an explicit method in AbstractMCMC (uncertain if we want to export it 🤷 but probably make it public
). Anyone implementing this interface already has AbstractMCMC loaded, so really doesn't cost anything + avoids misuse of Base
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can resonate, issue with public
is they still count as public interface, unsure if we need to make minor release
For `recompute_logp!!`, we could overload `LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, state::State; recompute_logp=true)` to compute the log probability. If `recompute_logp` is `true`, it should recompute the log probability of the state. Otherwise, it could use the log probability stored in the state. To allow updating the log probability stored in the state, samplers should define outer constructor for their state types `StateType(state::StateType, logdensity=logp)` that takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability. | ||
|
||
While overloading `LogDensityProblems.logdensity` to take a state object instead of a vector for the second argument somewhat deviates from the interface in `LogDensityProblems`, it provides a clean and extensible solution for handling log probability recomputation within the existing interface. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But here we're introducing kwargs, etc. which is really not a part of the LogDensityProblems.logdensity
interface. It would also mean we would have to depend on LogDensityProblems.jl, which we're currently not doing (AFIAK).
Why would we do this vs. just using recompute_logp!!
for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mainly to not make any changes to the interface, so these are just "recommendations"
I think AbstractMCMC
depends on LogDensityProblems
|
||
## Proposed Interface | ||
|
||
The two functions `getparams` and `recompute_logp!!` form a minimal interface to support the `Gibbs` implementation. However, there are concerns about introducing them directly into `AbstractMCMC`. The main reason is that `AbstractMCMC` is a root dependency of the `Turing` packages, so we want to be very careful with new releases. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main reason is that
AbstractMCMC
is a root dependency of theTuring
packages, so we want to be very careful with new releases.
Fair, but if we now make a release where we assume that certain functionality is overloaded, then that seems strictly worse, no?
I do think the entire process of this would be quite a bit less painful if we did the following (I believe I've mentioned this before; if not, I apologize):
This is how we're doing it with MCMCTempering.jl, i.e. keep it as a separate package and slowly move pieces to AbstractMCMC.jl if it seems suitable. My problem, as stated before, is that the current Gibbs impls we're working with are really not good enough as I think is evident by a) issues that we've encountered with my Turing.jl-impl in TuringLang/Turing.jl#2328 (comment), and b) the amount of iterating you've done in this PR. This shit is complicated 😬 And I imagine it's really annoying iterating on this back and forth but without actually getting stuff merged.. So, I think a separate package would just make this entire process much easier @sunxd3 ; then we can iterate much faster on ideas (just make breaking releases), and then we can just upstream changes as we finalize things there + we can even inform about this in the official AbstractMCMC.jl docs and then people can easily support this via extensions. |
The recent new Gibbs sampler provides a way forward for the Turing inference stack.
A near-to-medium-range goal has been to further reduce the glue code between Turing and inference packages (ref TuringLang/Turing.jl#2281). The new Gibbs implementation laid a great plan to achieve this goal.
This PR is modeled after the interface of @torfjelde's recent PR. And in some aspects, it is a rehash of #86.
(the explanation here is outdated, please refer to #144 (comment))
The goal of this PR is to determine and implement some necessary interface improvements, so that, when we update the inference packages up to the interface, they will more or less "just work" with the new Gibbs implementation.As a first step, we test-flight two new functionsrecompute_logprob!!(rng, model, sampler, state)
andgetparams(state)
:recompute_logprob!!(rng, model, sampler, state)
recomputes the logprob given thestate
getparams(state)
extract the parameter valuesSome considerations:This assumes astate
is implemented withAbstractMCMC
compatible inference packages. And astate
at least stores values of parameters from the current iteration (traditionally, this is in the form of aTransition
) and logprob.recompute_logprob!!(rng, model, sampler, state)
do we needrng
?should we makemodel
intoAbstractMCMC.LogDensityModel
or justLogDensityProblem
(and make inference packages depend onLogDensityProblems
in the latter case)? This should allow inference packages to be independent from DynamicPPL, we can usegetparams
to construct avarinfo
in Turinggetparams(state)
~~What does this function return? A vector, atransition
?Do we needsetparams
?Do we also need some interface functions forstate
likegetstats
?Tor also says (in a Slack conversation) that the acondition(model, params)
is needed, but better to be implemented by packages that defines themodel
, which I agree.