
Addition of step_warmup #117

Merged
merged 42 commits on Oct 4, 2024
Commits
0987f5f
added step_warmup which can be overloaded when convenient
torfjelde Mar 9, 2023
30c9f12
added step_warmup to docs
torfjelde Mar 9, 2023
7faa73f
Update src/interface.jl
torfjelde Mar 9, 2023
bd0bdc7
introduce new kwarg `num_warmup` to `sample` which uses `step_warmup`
torfjelde Mar 10, 2023
c620cca
updated docs
torfjelde Mar 10, 2023
572a286
allow combination of discard_initial and num_warmup
torfjelde Mar 10, 2023
6b842ee
added docstring for mcmcsample
torfjelde Mar 10, 2023
ca03832
Merge branch 'torfjelde/step-warmup' of github.com:TuringLang/Abstrac…
torfjelde Mar 10, 2023
0441773
Apply suggestions from code review
torfjelde Mar 10, 2023
ea369ff
Apply suggestions from code review
torfjelde Mar 10, 2023
8e0ca53
Update src/sample.jl
torfjelde Mar 10, 2023
6877978
removed docstring and deferred description of keyword arguments to th…
torfjelde Mar 10, 2023
b3b3148
Merge branch 'torfjelde/step-warmup' of github.com:TuringLang/Abstrac…
torfjelde Mar 10, 2023
ddc5254
Update src/sample.jl
torfjelde Mar 10, 2023
ffbd32f
Update src/sample.jl
torfjelde Mar 10, 2023
87480ff
added num_warmup to common keyword arguments docs
torfjelde Mar 10, 2023
76f2f23
also allow step_warmup for the initial step
torfjelde Mar 10, 2023
c00d0c9
Merge branch 'torfjelde/step-warmup' of github.com:TuringLang/Abstrac…
torfjelde Mar 10, 2023
ef09c19
simplify logic for discarding initial samples
torfjelde Mar 10, 2023
49b8406
Apply suggestions from code review
torfjelde Mar 10, 2023
f005746
also report progress for the discarded samples
torfjelde Mar 10, 2023
9dccd8a
Merge branch 'torfjelde/step-warmup' of github.com:TuringLang/Abstrac…
torfjelde Mar 10, 2023
ff00e6e
Apply suggestions from code review
torfjelde Mar 10, 2023
7ce9f6b
move progress-report to end of for-loop for discard samples
torfjelde Mar 10, 2023
3a217b2
move step_warmup to the inner while loops too
torfjelde Mar 13, 2023
de9bb2c
Update src/sample.jl
torfjelde Mar 13, 2023
85d938f
Apply suggestions from code review
torfjelde Apr 19, 2023
0a667a4
reverted to for-loop
torfjelde Apr 19, 2023
91f5a10
Update src/sample.jl
torfjelde Apr 19, 2023
7603171
added accidentally removed comment
torfjelde Apr 19, 2023
ef68d04
Update src/sample.jl
torfjelde Apr 19, 2023
25afc66
Merge branch 'master' into torfjelde/step-warmup
torfjelde Oct 24, 2023
1886fa8
Merge branch 'master' into torfjelde/step-warmup
torfjelde Oct 25, 2023
0ea293a
fixed formatting
torfjelde Oct 26, 2023
6e8f88e
fix typo
torfjelde Oct 26, 2023
44c55bb
Merge branch 'master' into torfjelde/step-warmup
torfjelde Oct 4, 2024
3b4f6db
Apply suggestions from code review
torfjelde Oct 4, 2024
f9142a6
Added testing of warmup steps
torfjelde Oct 4, 2024
295fdc1
Added checks as @devmotion requested
torfjelde Oct 4, 2024
e6acb1f
Removed unintended change in previous commit
torfjelde Oct 4, 2024
2e9fa5c
Bumped patch version
torfjelde Oct 4, 2024
366fceb
Bump minor version instead of patch version since this is a new feature
torfjelde Oct 4, 2024
2 changes: 1 addition & 1 deletion Project.toml
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probabilistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "5.3.0"
version = "5.4.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
10 changes: 7 additions & 3 deletions docs/src/api.md
@@ -71,9 +71,13 @@ Common keyword arguments for regular and parallel sampling are:
- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging
- `chain_type` (default: `Any`): determines the type of the returned chain
- `callback` (default: `nothing`): if `callback !== nothing`, then
`callback(rng, model, sampler, sample, state, iteration)` is called after every sampling step,
where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler
- `discard_initial` (default: `0`): number of initial samples that are discarded
`callback(rng, model, sampler, sample, iteration)` is called after every sampling step,
where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration
- `num_warmup` (default: `0`): number of "warm-up" steps to take before the first "regular" step,
i.e. number of times to call [`AbstractMCMC.step_warmup`](@ref) before the first call to
[`AbstractMCMC.step`](@ref).
- `discard_initial` (default: `num_warmup`): number of initial samples that are discarded. Note that
if `discard_initial < num_warmup`, warm-up samples will also be included in the resulting samples.
- `thinning` (default: `1`): factor by which to thin samples.
- `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to [`AbstractMCMC.step`](@ref)
is passed `initial_state` as the `state` argument.
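
As a sketch of how these keywords interact (the `model` and `sampler` below are placeholders for any AbstractMCMC-compatible pair, not definitions from this PR):

```julia
# With the default discard_initial = num_warmup, all warm-up samples are dropped
# and the returned chain contains only regular samples:
chain = sample(model, sampler, 1_000; num_warmup=100)

# With discard_initial < num_warmup, the last num_warmup - discard_initial
# warm-up samples (here 50) are kept at the start of the returned chain:
chain = sample(model, sampler, 1_000; num_warmup=100, discard_initial=50)
```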
9 changes: 9 additions & 0 deletions docs/src/design.md
@@ -63,6 +63,15 @@ the sampling step of the inference method.
AbstractMCMC.step
```

If a sampler also requires special handling during the warm-up stage of sampling, this can be specified by overloading

```@docs
AbstractMCMC.step_warmup
```

which will be used for the first `num_warmup` iterations, as specified by the `num_warmup` keyword argument to [`AbstractMCMC.sample`](@ref).
Note that this is optional; by default it simply calls [`AbstractMCMC.step`](@ref) from above.
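
For illustration, a downstream package might hook into this as follows (the sampler type and its adaptation rule are hypothetical, not part of AbstractMCMC):

```julia
using AbstractMCMC, Random

# Hypothetical sampler that adapts a tuning parameter only during warm-up.
mutable struct MyAdaptiveSampler <: AbstractMCMC.AbstractSampler
    stepsize::Float64
end

function AbstractMCMC.step_warmup(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.AbstractModel,
    sampler::MyAdaptiveSampler,
    state;
    kwargs...,
)
    # Take an ordinary step, then adapt the sampler's tuning parameter.
    sample, state = AbstractMCMC.step(rng, model, sampler, state; kwargs...)
    sampler.stepsize *= 1.01  # placeholder for a real adaptation rule
    return sample, state
end
```

With this in place, `sample(model, MyAdaptiveSampler(0.1), N; num_warmup=100)` would run the adapting method for the first 100 iterations and the plain `step` afterwards.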

## Collecting samples

!!! note
17 changes: 17 additions & 0 deletions src/interface.jl
@@ -73,6 +73,23 @@ current `state` of the sampler.
"""
function step end

"""
step_warmup(rng, model, sampler[, state; kwargs...])

Return a 2-tuple of the next sample and the next state of the MCMC `sampler` for `model`.

When sampling using [`sample`](@ref), this takes the place of [`AbstractMCMC.step`](@ref) in the first
`num_warmup` number of iterations, as specified by the `num_warmup` keyword to [`sample`](@ref).
This is useful if the sampler has an initial "warmup"-stage that is different from the
standard iteration.

By default, this simply calls [`AbstractMCMC.step`](@ref).
"""
step_warmup(rng, model, sampler; kwargs...) = step(rng, model, sampler; kwargs...)
function step_warmup(rng, model, sampler, state; kwargs...)
return step(rng, model, sampler, state; kwargs...)
end

"""
samples(sample, model, sampler[, N; kwargs...])

132 changes: 102 additions & 30 deletions src/sample.jl
@@ -43,6 +43,11 @@
```
isdone(rng, model, sampler, samples, state, iteration; kwargs...)
```
where `state` and `iteration` are the current state and iteration of the sampler, respectively.
It should return `true` when sampling should end, and `false` otherwise.

# Keyword arguments

See https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments for common keyword
arguments.
"""
function StatsBase.sample(
rng::Random.AbstractRNG,
@@ -80,6 +85,11 @@ end

Sample `nchains` Monte Carlo Markov chains from the `model` with the `sampler` in parallel
using the `parallel` algorithm, and combine them into a single chain.

# Keyword arguments

See https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments for common keyword
arguments.
"""
function StatsBase.sample(
rng::Random.AbstractRNG,
@@ -94,7 +104,6 @@ function StatsBase.sample(
end

# Default implementations of regular and parallel sampling.

function mcmcsample(
rng::Random.AbstractRNG,
model::AbstractModel,
@@ -103,15 +112,28 @@ function mcmcsample(
progress=PROGRESS[],
progressname="Sampling",
callback=nothing,
discard_initial=0,
num_warmup::Int=0,
discard_initial::Int=num_warmup,
thinning=1,
chain_type::Type=Any,
initial_state=nothing,
kwargs...,
)
# Check the number of requested samples.
N > 0 || error("the number of samples must be ≥ 1")
discard_initial >= 0 ||
throw(ArgumentError("number of discarded samples must be non-negative"))
num_warmup >= 0 ||
throw(ArgumentError("number of warm-up samples must be non-negative"))
Ntotal = thinning * (N - 1) + discard_initial + 1
Ntotal >= num_warmup || throw(
ArgumentError("number of warm-up samples exceeds the total number of samples")
)

# Determine how many samples to drop from `num_warmup` and the
# main sampling process before we start saving samples.
discard_from_warmup = min(num_warmup, discard_initial)
keep_from_warmup = num_warmup - discard_from_warmup

# Start the timer
start = time()
@@ -126,22 +148,41 @@
end

# Obtain the initial sample and state.
sample, state = if initial_state === nothing
step(rng, model, sampler; kwargs...)
sample, state = if num_warmup > 0
if initial_state === nothing
step_warmup(rng, model, sampler; kwargs...)
else
step_warmup(rng, model, sampler, initial_state; kwargs...)
end
else
step(rng, model, sampler, initial_state; kwargs...)
if initial_state === nothing
step(rng, model, sampler; kwargs...)
else
step(rng, model, sampler, initial_state; kwargs...)
end
end

# Update the progress bar.
itotal = 1
if progress && itotal >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end

# Discard initial samples.
for i in 1:discard_initial
# Update the progress bar.
if progress && i >= next_update
ProgressLogging.@logprogress i / Ntotal
next_update = i + threshold
for j in 1:discard_initial
# Obtain the next sample and state.
Member: I think these should be accounted for in the progress logger as well (as done currently).

torfjelde (Author): Should be good now 👍

sample, state = if j ≤ num_warmup
Member: This should be

Suggested change:
-    sample, state = if j ≤ num_warmup
+    sample, state = if j ≤ discard_num_warmup

shouldn't it? Maybe it could even be split into two sequential for loops?

torfjelde (Author):
> This should be

I think technically it doesn't matter, right? Since we have either

  1. discard_num_warmup = discard_initial when num_warmup >= discard_initial, or
  2. discard_num_warmup = num_warmup when num_warmup < discard_initial.

In both of those cases we get the same behavior in the above. But I think for readability's sake, I agree we should make the change! Just pointing out it shouldn't have been a cause of a bug.

> Maybe it could even be split into two sequential for loops?

Wait what, wasn't that what I had before? 😕

Member:
> Wait what, wasn't that what I had before? 😕

Really? I think you used a different logic initially but maybe I misremember 😄 In any case, I guess it does not matter.

torfjelde (Author): You mean

for i = 1:discard_num_warmup
    # ...
end

for i = discard_num_warmup + 1:discard_initial
    # ...
end

? Because you're probably right, I don't think I ever did this exactly 😬 I'm preferential to the current code for readability's sake, because it means the discard stepping looks the same as the proper stepping, code-wise.

step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end

# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
# Update the progress bar.
if progress && (itotal += 1) >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end
end

# Run callback.
@@ -151,19 +192,16 @@
samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...)
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)

# Update the progress bar.
itotal = 1 + discard_initial
if progress && itotal >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end

# Step through the sampler.
for i in 2:N
# Discard thinned samples.
for _ in 1:(thinning - 1)
# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if i ≤ keep_from_warmup
step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end

# Update progress bar.
if progress && (itotal += 1) >= next_update
@@ -173,7 +211,11 @@
end

# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if i ≤ keep_from_warmup
step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end

# Run callback.
callback === nothing ||
@@ -217,28 +259,51 @@ function mcmcsample(
progress=PROGRESS[],
progressname="Convergence sampling",
callback=nothing,
discard_initial=0,
num_warmup=0,
discard_initial=num_warmup,
thinning=1,
initial_state=nothing,
kwargs...,
)
# Check the number of requested samples.
discard_initial >= 0 ||
throw(ArgumentError("number of discarded samples must be non-negative"))
num_warmup >= 0 ||
throw(ArgumentError("number of warm-up samples must be non-negative"))

Member: Can you add the same/similar error checks as above?

torfjelde (Author): Done 👍

    # Determine how many samples to drop from `num_warmup` and the
    # main sampling process before we start saving samples.
discard_from_warmup = min(num_warmup, discard_initial)
keep_from_warmup = num_warmup - discard_from_warmup

# Start the timer
start = time()
local state

@ifwithprogresslogger progress name = progressname begin
# Obtain the initial sample and state.
sample, state = if initial_state === nothing
step(rng, model, sampler; kwargs...)
sample, state = if num_warmup > 0
if initial_state === nothing
step_warmup(rng, model, sampler; kwargs...)
else
step_warmup(rng, model, sampler, initial_state; kwargs...)
end
else
step(rng, model, sampler, state; kwargs...)
if initial_state === nothing
step(rng, model, sampler; kwargs...)
else
step(rng, model, sampler, initial_state; kwargs...)
end
end

# Discard initial samples.
for _ in 1:discard_initial
for j in 1:discard_initial
# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if j ≤ num_warmup
step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end
end

# Run callback.
@@ -250,16 +315,23 @@

# Step through the sampler until stopping.
i = 2

while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...)
# Discard thinned samples.
for _ in 1:(thinning - 1)
# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if i ≤ keep_from_warmup
Member: Shouldn't `i` be incremented at the top of the loop? Before it was 2 here, now it is 1.

torfjelde (Author): Ah yes, nice catch!

torfjelde (Author): I just reverted the initialization of `i` to 2.

step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end
end

# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if i ≤ keep_from_warmup
Member: We could merge this with the for-loop above AFAICT?

torfjelde (Author): Aye, but we can do that everywhere here, no? I can make this change, but I'll wait until you've had a final look (to make the diff clearer).

step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end

# Run callback.
callback === nothing ||
39 changes: 39 additions & 0 deletions test/sample.jl
@@ -575,6 +575,45 @@
@test all(chain[i].b == ref_chain[i + discard_initial].b for i in 1:N)
end

@testset "Warm-up steps" begin
# Create a chain and discard initial samples.
Random.seed!(1234)
N = 100
num_warmup = 50

# Everything should be discarded here.
chain = sample(MyModel(), MySampler(), N; num_warmup=num_warmup)
@test length(chain) == N
@test !ismissing(chain[1].a)

# Repeat sampling without discarding initial samples.
# On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
# https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
Random.seed!(1234)
ref_chain = sample(
MyModel(), MySampler(), N + num_warmup; progress=VERSION < v"1.6"
)
@test all(chain[i].a == ref_chain[i + num_warmup].a for i in 1:N)
@test all(chain[i].b == ref_chain[i + num_warmup].b for i in 1:N)

# Sample again, this time discarding only some of the warm-up samples.
Random.seed!(1234)
discard_initial = 10
chain_warmup = sample(
MyModel(),
MySampler(),
N;
num_warmup=num_warmup,
discard_initial=discard_initial,
)
@test length(chain_warmup) == N
@test all(chain_warmup[i].a == ref_chain[i + discard_initial].a for i in 1:N)
# Check that the first `num_warmup - discard_initial` samples are warmup samples.
@test all(
chain_warmup[i].is_warmup == (i <= num_warmup - discard_initial) for i in 1:N
)
end

@testset "Thin chain by a factor of `thinning`" begin
# Run a thinned chain with `N` samples thinned by factor of `thinning`.
Random.seed!(100)
18 changes: 18 additions & 0 deletions test/utils.jl
@@ -3,8 +3,11 @@ struct MyModel <: AbstractMCMC.AbstractModel end
struct MySample{A,B}
a::A
b::B
is_warmup::Bool
end

MySample(a, b) = MySample(a, b, false)

struct MySampler <: AbstractMCMC.AbstractSampler end
struct AnotherSampler <: AbstractMCMC.AbstractSampler end

@@ -16,6 +19,21 @@

MyChain(a, b) = MyChain(a, b, NamedTuple())

function AbstractMCMC.step_warmup(
rng::AbstractRNG,
model::MyModel,
sampler::MySampler,
state::Union{Nothing,Integer}=nothing;
loggers=false,
initial_params=nothing,
kwargs...,
)
transition, state = AbstractMCMC.step(
rng, model, sampler, state; loggers, initial_params, kwargs...
)
return MySample(transition.a, transition.b, true), state
end

function AbstractMCMC.step(
rng::AbstractRNG,
model::MyModel,