-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Addition of step_warmup
#117
Changes from 35 commits
0987f5f
30c9f12
7faa73f
bd0bdc7
c620cca
572a286
6b842ee
ca03832
0441773
ea369ff
8e0ca53
6877978
b3b3148
ddc5254
ffbd32f
87480ff
76f2f23
c00d0c9
ef09c19
49b8406
f005746
9dccd8a
ff00e6e
7ce9f6b
3a217b2
de9bb2c
85d938f
0a667a4
91f5a10
7603171
ef68d04
25afc66
1886fa8
0ea293a
6e8f88e
44c55bb
3b4f6db
f9142a6
295fdc1
e6acb1f
2e9fa5c
366fceb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -40,6 +40,11 @@ | |||||
``` | ||||||
where `state` and `iteration` are the current state and iteration of the sampler, respectively. | ||||||
It should return `true` when sampling should end, and `false` otherwise. | ||||||
|
||||||
# Keyword arguments | ||||||
|
||||||
See https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments for common keyword | ||||||
arguments. | ||||||
""" | ||||||
function StatsBase.sample( | ||||||
rng::Random.AbstractRNG, | ||||||
|
@@ -77,6 +82,11 @@ | |||||
|
||||||
Sample `nchains` Monte Carlo Markov chains from the `model` with the `sampler` in parallel | ||||||
using the `parallel` algorithm, and combine them into a single chain. | ||||||
|
||||||
# Keyword arguments | ||||||
|
||||||
See https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments for common keyword | ||||||
arguments. | ||||||
""" | ||||||
function StatsBase.sample( | ||||||
rng::Random.AbstractRNG, | ||||||
|
@@ -91,7 +101,6 @@ | |||||
end | ||||||
|
||||||
# Default implementations of regular and parallel sampling. | ||||||
|
||||||
function mcmcsample( | ||||||
rng::Random.AbstractRNG, | ||||||
model::AbstractModel, | ||||||
|
@@ -100,15 +109,28 @@ | |||||
progress=PROGRESS[], | ||||||
progressname="Sampling", | ||||||
callback=nothing, | ||||||
discard_initial=0, | ||||||
num_warmup::Int=0, | ||||||
discard_initial::Int=num_warmup, | ||||||
thinning=1, | ||||||
chain_type::Type=Any, | ||||||
initial_state=nothing, | ||||||
kwargs..., | ||||||
) | ||||||
# Check the number of requested samples. | ||||||
N > 0 || error("the number of samples must be ≥ 1") | ||||||
discard_initial >= 0 || | ||||||
throw(ArgumentError("number of discarded samples must be non-negative")) | ||||||
num_warmup >= 0 || | ||||||
throw(ArgumentError("number of warm-up samples must be non-negative")) | ||||||
Ntotal = thinning * (N - 1) + discard_initial + 1 | ||||||
Ntotal >= num_warmup || throw( | ||||||
ArgumentError("number of warm-up samples exceeds the total number of samples") | ||||||
) | ||||||
|
||||||
# Determine how many samples to drop from `num_warmup` and the | ||||||
# main sampling process before we start saving samples. | ||||||
discard_from_warmup = min(num_warmup, discard_initial) | ||||||
keep_from_warmup = num_warmup - discard_from_warmup | ||||||
|
||||||
# Start the timer | ||||||
start = time() | ||||||
|
@@ -123,22 +145,41 @@ | |||||
end | ||||||
|
||||||
# Obtain the initial sample and state. | ||||||
sample, state = if initial_state === nothing | ||||||
step(rng, model, sampler; kwargs...) | ||||||
sample, state = if num_warmup > 0 | ||||||
if initial_state === nothing | ||||||
step_warmup(rng, model, sampler; kwargs...) | ||||||
else | ||||||
step_warmup(rng, model, sampler, initial_state; kwargs...) | ||||||
end | ||||||
else | ||||||
step(rng, model, sampler, initial_state; kwargs...) | ||||||
if initial_state === nothing | ||||||
step(rng, model, sampler; kwargs...) | ||||||
else | ||||||
step(rng, model, sampler, initial_state; kwargs...) | ||||||
end | ||||||
end | ||||||
|
||||||
# Update the progress bar. | ||||||
itotal = 1 | ||||||
if progress && itotal >= next_update | ||||||
ProgressLogging.@logprogress itotal / Ntotal | ||||||
next_update = itotal + threshold | ||||||
end | ||||||
|
||||||
# Discard initial samples. | ||||||
for i in 1:discard_initial | ||||||
# Update the progress bar. | ||||||
if progress && i >= next_update | ||||||
ProgressLogging.@logprogress i / Ntotal | ||||||
next_update = i + threshold | ||||||
for j in 1:discard_initial | ||||||
# Obtain the next sample and state. | ||||||
sample, state = if j ≤ num_warmup | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be
Suggested change
shouldn't it? Maybe it could even be split into two sequential for loops? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think technically it doesn't matter, right? Since we have either
In both of those cases we get the same behavior in the above. But I think for readability's sake, I agree we should make the change! Just pointing out it shouldn't been a cause of a bug.
Wait what, wasn't that what I had before? 😕 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Really? I think you used a different logic initially but maybe I misremember 😄 In any case, I guess it does not matter. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean for i = 1:discard_num_warmup
# ...
end
for i = discard_num_warmup + 1:discard_initial
# ...
end ? Because you're probably right, I don't think I ever did this exactly 😬 I'm preferential to the current code for readability's sake because it means the discard stepping is looks the same as the proper stepping, code-wise. |
||||||
step_warmup(rng, model, sampler, state; kwargs...) | ||||||
else | ||||||
step(rng, model, sampler, state; kwargs...) | ||||||
end | ||||||
|
||||||
# Obtain the next sample and state. | ||||||
sample, state = step(rng, model, sampler, state; kwargs...) | ||||||
# Update the progress bar. | ||||||
if progress && (itotal += 1) >= next_update | ||||||
ProgressLogging.@logprogress itotal / Ntotal | ||||||
next_update = itotal + threshold | ||||||
end | ||||||
end | ||||||
|
||||||
# Run callback. | ||||||
|
@@ -148,19 +189,16 @@ | |||||
samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) | ||||||
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...) | ||||||
|
||||||
# Update the progress bar. | ||||||
itotal = 1 + discard_initial | ||||||
if progress && itotal >= next_update | ||||||
ProgressLogging.@logprogress itotal / Ntotal | ||||||
next_update = itotal + threshold | ||||||
end | ||||||
|
||||||
# Step through the sampler. | ||||||
for i in 2:N | ||||||
# Discard thinned samples. | ||||||
for _ in 1:(thinning - 1) | ||||||
# Obtain the next sample and state. | ||||||
sample, state = step(rng, model, sampler, state; kwargs...) | ||||||
sample, state = if i ≤ keep_from_warmup | ||||||
step_warmup(rng, model, sampler, state; kwargs...) | ||||||
else | ||||||
step(rng, model, sampler, state; kwargs...) | ||||||
end | ||||||
|
||||||
# Update progress bar. | ||||||
if progress && (itotal += 1) >= next_update | ||||||
|
@@ -170,7 +208,11 @@ | |||||
end | ||||||
|
||||||
# Obtain the next sample and state. | ||||||
sample, state = step(rng, model, sampler, state; kwargs...) | ||||||
sample, state = if i ≤ keep_from_warmup | ||||||
step_warmup(rng, model, sampler, state; kwargs...) | ||||||
else | ||||||
step(rng, model, sampler, state; kwargs...) | ||||||
end | ||||||
|
||||||
# Run callback. | ||||||
callback === nothing || | ||||||
|
@@ -179,6 +221,9 @@ | |||||
# Save the sample. | ||||||
samples = save!!(samples, sample, i, model, sampler, N; kwargs...) | ||||||
|
||||||
# Increment iteration counter. | ||||||
i += 1 | ||||||
|
||||||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
# Update the progress bar. | ||||||
if progress && (itotal += 1) >= next_update | ||||||
ProgressLogging.@logprogress itotal / Ntotal | ||||||
|
@@ -214,28 +259,45 @@ | |||||
progress=PROGRESS[], | ||||||
progressname="Convergence sampling", | ||||||
callback=nothing, | ||||||
discard_initial=0, | ||||||
num_warmup=0, | ||||||
discard_initial=num_warmup, | ||||||
thinning=1, | ||||||
initial_state=nothing, | ||||||
kwargs..., | ||||||
) | ||||||
# Determine how many samples to drop from `num_warmup` and the | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add the same/similar error checks as above? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done 👍 |
||||||
# main sampling process before we start saving samples. | ||||||
discard_from_warmup = min(num_warmup, discard_initial) | ||||||
keep_from_warmup = num_warmup - discard_from_warmup | ||||||
|
||||||
# Start the timer | ||||||
start = time() | ||||||
local state | ||||||
|
||||||
@ifwithprogresslogger progress name = progressname begin | ||||||
# Obtain the initial sample and state. | ||||||
sample, state = if initial_state === nothing | ||||||
step(rng, model, sampler; kwargs...) | ||||||
sample, state = if num_warmup > 0 | ||||||
if initial_state === nothing | ||||||
step_warmup(rng, model, sampler; kwargs...) | ||||||
else | ||||||
step_warmup(rng, model, sampler, initial_state; kwargs...) | ||||||
end | ||||||
else | ||||||
step(rng, model, sampler, state; kwargs...) | ||||||
if initial_state === nothing | ||||||
step(rng, model, sampler; kwargs...) | ||||||
else | ||||||
step(rng, model, sampler, initial_state; kwargs...) | ||||||
end | ||||||
end | ||||||
|
||||||
# Discard initial samples. | ||||||
for _ in 1:discard_initial | ||||||
for j in 1:discard_initial | ||||||
# Obtain the next sample and state. | ||||||
sample, state = step(rng, model, sampler, state; kwargs...) | ||||||
sample, state = if j ≤ discard_from_warmup | ||||||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
step_warmup(rng, model, sampler, state; kwargs...) | ||||||
else | ||||||
step(rng, model, sampler, state; kwargs...) | ||||||
end | ||||||
end | ||||||
|
||||||
# Run callback. | ||||||
|
@@ -247,16 +309,23 @@ | |||||
|
||||||
# Step through the sampler until stopping. | ||||||
i = 2 | ||||||
|
||||||
while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...) | ||||||
# Discard thinned samples. | ||||||
for _ in 1:(thinning - 1) | ||||||
# Obtain the next sample and state. | ||||||
sample, state = step(rng, model, sampler, state; kwargs...) | ||||||
sample, state = if i ≤ keep_from_warmup | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yes, nice catch! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just reverted the initialization of |
||||||
step_warmup(rng, model, sampler, state; kwargs...) | ||||||
else | ||||||
step(rng, model, sampler, state; kwargs...) | ||||||
end | ||||||
end | ||||||
|
||||||
# Obtain the next sample and state. | ||||||
sample, state = step(rng, model, sampler, state; kwargs...) | ||||||
sample, state = if i ≤ keep_from_warmup | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could merge this with the for-loop above AFAICT? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aye, but we can do that everywhere here no? I can make this change, but I'll wait until you've had a final look (to make the diff clearer). |
||||||
step_warmup(rng, model, sampler, state; kwargs...) | ||||||
else | ||||||
step(rng, model, sampler, state; kwargs...) | ||||||
end | ||||||
|
||||||
# Run callback. | ||||||
callback === nothing || | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think these should be accounted for in the progress logger as well (as done currently).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be good now 👍