Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some interface functions to support the new Gibbs sampler in Turing #144

Closed
wants to merge 56 commits into from

Conversation

sunxd3
Copy link
Member

@sunxd3 sunxd3 commented Jul 12, 2024

The recent new Gibbs sampler provides a way forward for the Turing inference stack.

A near-to-medium-range goal has been to further reduce the glue code between Turing and inference packages (ref TuringLang/Turing.jl#2281). The new Gibbs implementation laid a great plan to achieve this goal.

This PR is modeled after the interface of @torfjelde's recent PR. And in some aspects, it is a rehash of #86.

(the explanation here is outdated, please refer to #144 (comment))

The goal of this PR is to determine and implement some necessary interface improvements, so that, when we update the inference packages up to the interface, they will more or less "just work" with the new Gibbs implementation.

As a first step, we test-flight two new functions recompute_logprob!!(rng, model, sampler, state) and getparams(state):

  • recompute_logprob!!(rng, model, sampler, state) recomputes the logprob given the state
  • getparams(state) extract the parameter values

Some considerations:

  • This assumes a state is implemented with AbstractMCMC compatible inference packages. And a state at least stores values of parameters from the current iteration (traditionally, this is in the form of a Transition) and logprob.
  • recompute_logprob!!(rng, model, sampler, state)
    • do we need rng?
    • should we make model into AbstractMCMC.LogDensityModel or just LogDensityProblem (and make inference packages depend on LogDensityProblems in the latter case)? This should allow inference packages to be independent from DynamicPPL, we can use getparams to construct a varinfo in Turing
  • ~~getparams(state) ~~
    • What does this function return? A vector, a transition?
    • Do we need setparams?
  • Do we also need some interface functions for state like getstats?

Tor also says (in a Slack conversation) that the a condition(model, params) is needed, but better to be implemented by packages that defines the model, which I agree.

@sunxd3
Copy link
Member Author

sunxd3 commented Jul 12, 2024

@yebai @devmotion @cpfiffer

@devmotion
Copy link
Member

How is #86 related to this PR?

@torfjelde
Copy link
Member

Hmm, it's unclear to me whether it's worth adding these methods when they have "no use" unless some notion of conditioning is also added 😕

How is #86 related to this PR?

getparams is probably overlapping between the two PRs, but the recompute_logprob!! method is not

@sunxd3
Copy link
Member Author

sunxd3 commented Jul 16, 2024

I am for adding a condition interface, should we upstream this from AbstractPPL?

@yebai
Copy link
Member

yebai commented Jul 16, 2024

I think AbstractPPL imports AbstractMCMC, so it is also a good idea to define condition here and then reexport from AbstractPPL.

@sunxd3
Copy link
Member Author

sunxd3 commented Jul 18, 2024

Okay, now condition and decondition are moved to AbstractMCMC from AbstractPPL.

Do we want fix here?

@sunxd3
Copy link
Member Author

sunxd3 commented Jul 19, 2024

@devmotion @yebai @torfjelde @mhauru a penny for your thoughts?

@yebai
Copy link
Member

yebai commented Jul 19, 2024

Do we want fix here?

I'd keep it in DynamicPPL / AbstractPPL unless there is a reason to move here.

src/AbstractMCMC.jl Outdated Show resolved Hide resolved
src/AbstractMCMC.jl Outdated Show resolved Hide resolved
@torfjelde
Copy link
Member

torfjelde commented Jul 19, 2024

I'm still a bit uncertain about all of this tbh. I feel like right now we're just shoving condition and decondition (which I don't think we need for Gibbs?) into AbstractMCMC.jl to motivate the inclusion of recompute_logprob!! without much thought about whether it makes sense or not 😅

I think if this is the case, then I'm preferential to ignoring my original comment of "needing condition to motivate recompute_logprob!!", i.e. just leave it as you did originally (without condition and decondition).

@sunxd3
Copy link
Member Author

sunxd3 commented Jul 22, 2024

I removed condition (and decondition) and use the public keyword for the new interface functions. The latter will technically change the interface, so I bumped the minor version.

I also think we should add something like AbstractState to normalize the design of state. This will introduce types for state everywhere, I am unsure of the impact. What's your thoughts on this?

@torfjelde
Copy link
Member

I also think we should add something like AbstractState to normalize the design of state. This will introduce types for state everywhere, I am unsure of the impact. What's your thoughts on this?

Not for this PR at least:) If we want to discuss this, then we should open an issue and move discussion there.

@torfjelde
Copy link
Member

The latter will technically change the interface, so I bumped the minor version.

It seems you've bumped the major version, not the minor version?

Also, if we're making this part of the interface, we should probably document this?

@sunxd3
Copy link
Member Author

sunxd3 commented Jul 22, 2024

Oops, you're right.

we should probably document this?

By using the public keyword, maybe we can say "this is not official yet"? I am a little hesitate to add official documentation right now, because we don't yet have a crystal clear idea of what the interface should behave.
Will add docs.

@yebai
Copy link
Member

yebai commented Jul 23, 2024

Some high-level comments:

  • Let's introduce a setparams function to complete the getparams function.
  • Let's introduce some tests to test the interface and get a more grounded view of the design.
  • Think of an alternative name to recompute_logprob!!, which is a bit unintuitive in terms of what it means.

@sunxd3 please also take a careful look at

we want to push for merging these PRs and incorporate some nice ideas elsewhere in the ecosystem.

Copy link

codecov bot commented Sep 22, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.54%. Comparing base (2a77f53) to head (3ed5cb3).
Report is 5 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #144      +/-   ##
==========================================
+ Coverage   97.19%   97.54%   +0.34%     
==========================================
  Files           8        8              
  Lines         321      326       +5     
==========================================
+ Hits          312      318       +6     
+ Misses          9        8       -1     
Flag Coverage Δ
97.54% <ø> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@sunxd3
Copy link
Member Author

sunxd3 commented Oct 1, 2024

@yebai @torfjelde this has been updated, many issues in the previous version has been corrected (thanks for the discussion and code review). I also added another notes in the folder design_notes, comments are welcomed. Can you give it another read?

@sunxd3 sunxd3 marked this pull request as ready for review October 1, 2024 20:19
@sunxd3
Copy link
Member Author

sunxd3 commented Oct 1, 2024

In its current form, no interface change is made to AbstractMCMC, all the interface functions are from other packages.

@sunxd3
Copy link
Member Author

sunxd3 commented Oct 3, 2024

the test error seems to be Julia 1.6-only related

Copy link
Member

@yebai yebai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left some comments below. Some other high-level comments:

  1. unify MCMCState with MCMCTransition by introducing an AbstractMCMCState type
  2. replace Base.vec(state) with a special state AbstractMCMC.VectorMCMCState{T}<:AbstractVector{T}, which supports getindex and setindex. All sampler packages should explicitly implement a VectorMCMCState(state) type conversion funciton.
  3. is MCMCTransition replaceable with VectorMCMCState?

design_notes/on_gibbs_implementation.md Outdated Show resolved Hide resolved
design_notes/on_gibbs_implementation.md Outdated Show resolved Hide resolved
design_notes/on_gibbs_implementation.md Outdated Show resolved Hide resolved
design_notes/on_gibbs_implementation.md Outdated Show resolved Hide resolved
design_notes/on_gibbs_implementation.md Outdated Show resolved Hide resolved
docs/src/state_interface.md Outdated Show resolved Hide resolved
This function takes the state and returns a vector of the parameter values stored in the state.

```julia
state = StateType(state::StateType, logp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
state = StateType(state::StateType, logp)
state = StateType(state::StateType, logdensity=logp)

This function takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability.

These functions provide a minimal interface to interact with the `state` datatype, which a sampler package can optionally implement.
The interface facilitates the implementation of "meta-algorithms" that combine different samplers.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The interface facilitates the implementation of "meta-algorithms" that combine different samplers.
The interface facilitates the implementation of "high-order" MCMC sampling algorithms like Gibbs.

docs/src/state_interface.md Outdated Show resolved Hide resolved
@sunxd3
Copy link
Member Author

sunxd3 commented Oct 3, 2024

related reply from @devmotion TuringLang/Turing.jl#2304 (comment)

Copy link
Member

@torfjelde torfjelde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm very much in favour of using explicit functions for the interface and not overloading methods (in ways that the original methods were not intended for) 😕


Here, some alternative functions that achieve the same functionality as `getparams` and `recompute_logp!!` are proposed, but without introducing new interface functions.

For `getparams`, we can use `Base.vec`. It is a `Base` function, so there's no need to export anything from `AbstractMCMC`. Since `getparams` should return a vector, using `vec` makes sense. The concern is that, officially, `Base.vec` is defined for `AbstractArray`, so it remains a question whether we should only introduce `vec` in the absence of other `AbstractArray` interfaces.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd much prefer an explicit method in AbstractMCMC (uncertain if we want to export it 🤷 but probably make it public). Anyone implementing this interface already has AbstractMCMC loaded, so really doesn't cost anything + avoids misuse of Base.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can resonate, issue with public is they still count as public interface, unsure if we need to make minor release

Comment on lines +51 to +53
For `recompute_logp!!`, we could overload `LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, state::State; recompute_logp=true)` to compute the log probability. If `recompute_logp` is `true`, it should recompute the log probability of the state. Otherwise, it could use the log probability stored in the state. To allow updating the log probability stored in the state, samplers should define outer constructor for their state types `StateType(state::StateType, logdensity=logp)` that takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability.

While overloading `LogDensityProblems.logdensity` to take a state object instead of a vector for the second argument somewhat deviates from the interface in `LogDensityProblems`, it provides a clean and extensible solution for handling log probability recomputation within the existing interface.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But here we're introducing kwargs, etc. which is really not a part of the LogDensityProblems.logdensity interface. It would also mean we would have to depend on LogDensityProblems.jl, which we're currently not doing (AFIAK).

Why would we do this vs. just using recompute_logp!! for this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mainly to not make any changes to the interface, so these are just "recommendations"

I think AbstractMCMC depends on LogDensityProblems


## Proposed Interface

The two functions `getparams` and `recompute_logp!!` form a minimal interface to support the `Gibbs` implementation. However, there are concerns about introducing them directly into `AbstractMCMC`. The main reason is that `AbstractMCMC` is a root dependency of the `Turing` packages, so we want to be very careful with new releases.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason is that AbstractMCMC is a root dependency of the Turing packages, so we want to be very careful with new releases.

Fair, but if we now make a release where we assume that certain functionality is overloaded, then that seems strictly worse, no?

@torfjelde
Copy link
Member

torfjelde commented Oct 4, 2024

I do think the entire process of this would be quite a bit less painful if we did the following (I believe I've mentioned this before; if not, I apologize):

  1. Improve Add getparameters and setparameters!! #86 to a finalized form . This is useful, not just for Gibbs sampling.
  2. Make a separate package, e.g. AbstractMCMCGibbs.jl, which implements the Gibbs-only stuff, e.g. recompute_logprob!! and the sampler mapping stuff.

This is how we're doing it with MCMCTempering.jl, i.e. keep it as a separate package and slowly move pieces to AbstractMCMC.jl if it seems suitable. My problem, as stated before, is that the current Gibbs impls we're working with are really not good enough as I think is evident by a) issues that we've encountered with my Turing.jl-impl in TuringLang/Turing.jl#2328 (comment), and b) the amount of iterating you've done in this PR. This shit is complicated 😬 And I imagine it's really annoying iterating on this back and forth but without actually getting stuff merged..

So, I think a separate package would just make this entire process much easier @sunxd3 ; then we can iterate much faster on ideas (just make breaking releases), and then we can just upstream changes as we finalize things there + we can even inform about this in the official AbstractMCMC.jl docs and then people can easily support this via extensions.

@torfjelde
Copy link
Member

#85 (comment)

@sunxd3 sunxd3 marked this pull request as draft October 14, 2024 12:12
@sunxd3 sunxd3 closed this Oct 15, 2024
@yebai yebai deleted the sunxd/interface_for_gibbs branch October 22, 2024 20:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants