diff --git a/Project.toml b/Project.toml index cdffc3b2..991eb18b 100644 --- a/Project.toml +++ b/Project.toml @@ -5,11 +5,15 @@ version = "0.4.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" diff --git a/src/adaptation/Adaptation.jl b/src/adaptation/Adaptation.jl index 3f87c206..873988eb 100644 --- a/src/adaptation/Adaptation.jl +++ b/src/adaptation/Adaptation.jl @@ -5,16 +5,17 @@ using LinearAlgebra: LinearAlgebra using Statistics: Statistics using UnPack: @unpack, @pack! -using ..AdvancedHMC: DEBUG, AbstractScalarOrVec +using ..AdvancedHMC: DEBUG, AbstractScalarOrVec, PhasePoint abstract type AbstractAdaptor end +abstract type AbstractHMCAdaptorWithGradients <: AbstractAdaptor end function getM⁻¹ end function getϵ end function adapt! end function reset! end function initialize! end function finalize! 
end -export AbstractAdaptor, adapt!, initialize!, finalize!, reset!, getϵ, getM⁻¹ +export AbstractAdaptor, AbstractHMCAdaptorWithGradients, adapt!, initialize!, finalize!, reset!, getϵ, getM⁻¹ struct NoAdaptation <: AbstractAdaptor end export NoAdaptation @@ -57,4 +58,11 @@ finalize!(aca::NaiveHMCAdaptor) = finalize!(aca.ssa) include("stan_adaptor.jl") export NaiveHMCAdaptor, StanHMCAdaptor +const LOWER_LIMIT::Float64 = 1e-10 +const UPPER_LIMIT::Float64 = 1e10 + +include("nutpie_adaptor.jl") +export NutpieHMCAdaptor, ExpWeightedWelfordVar, NutpieHMCAdaptorNoSwitch, NutpieHMCAdaptorNoGradInit,NutpieHMCAdaptorNoSwitchNoGradInit + + end # module diff --git a/src/adaptation/massmatrix.jl b/src/adaptation/massmatrix.jl index 9e550c3e..f214eea7 100644 --- a/src/adaptation/massmatrix.jl +++ b/src/adaptation/massmatrix.jl @@ -11,7 +11,7 @@ function adapt!( adaptor::MassMatrixAdaptor, θ::AbstractVecOrMat{<:AbstractFloat}, α::AbstractScalarOrVec{<:AbstractFloat}, - is_update::Bool = true, + is_update::Bool=true, ) resize!(adaptor, θ) push!(adaptor, θ) @@ -38,7 +38,7 @@ adapt!( ::UnitMassMatrix, ::AbstractVecOrMat{<:AbstractFloat}, ::AbstractScalarOrVec{<:AbstractFloat}, - is_update::Bool = true, + is_update::Bool=true, ) = nothing ## Diagonal mass matrix adaptor @@ -91,8 +91,8 @@ Base.show(io::IO, ::WelfordVar) = print(io, "WelfordVar") function WelfordVar{T}( sz::Union{Tuple{Int},Tuple{Int,Int}}; - n_min::Int = 10, - var = ones(T, sz), + n_min::Int=10, + var=ones(T, sz) ) where {T<:AbstractFloat} return WelfordVar(0, n_min, zeros(T, sz), zeros(T, sz), zeros(T, sz), var) end @@ -133,6 +133,96 @@ function get_estimation(wv::WelfordVar{T}) where {T<:AbstractFloat} return n / ((n + 5) * (n - 1)) * M .+ ϵ * (5 / (n + 5)) end +# Rust implementation of NUTS used in Nutpie (comes from nuts-rs crate) +# Source: https://github.com/pymc-devs/nuts-rs/blob/main/src/adapt_strategy.rs + +mutable struct ExpWeightedWelfordVar{T<:AbstractFloat,E<:AbstractVecOrMat{T}} <: 
DiagMatrixEstimator{T} + exp_variance_draw::WelfordVar{T,E} + exp_variance_grad::WelfordVar{T,E} + exp_variance_draw_bg::WelfordVar{T,E} + exp_variance_grad_bg::WelfordVar{T,E} + function ExpWeightedWelfordVar(exp_variance_draw::WelfordVar{T,E}, exp_variance_grad::WelfordVar{T,E}, exp_variance_draw_bg::WelfordVar{T,E}, exp_variance_grad_bg::WelfordVar{T,E}) where {T,E} + return new{eltype(E),E}(exp_variance_draw, exp_variance_grad, exp_variance_draw_bg, exp_variance_grad_bg) + end +end + +# save the best estimate of the variance in the "current" WelfordVar +getM⁻¹(ve::ExpWeightedWelfordVar) = ve.exp_variance_draw.var + +Base.show(io::IO, ::ExpWeightedWelfordVar) = print(io, "ExpWeightedWelfordVar") + +function ExpWeightedWelfordVar{T}( + sz::Union{Tuple{Int},Tuple{Int,Int}}; + n_min::Int=4, var=ones(T, sz) +) where {T<:AbstractFloat} + # return ExpWeightedWelfordVar(0, n_min, zeros(T, sz), zeros(T, sz), zeros(T, sz), var) + return ExpWeightedWelfordVar(WelfordVar{T}(sz; n_min, var), WelfordVar{T}(sz; n_min, var), WelfordVar{T}(sz; n_min, var), WelfordVar{T}(sz; n_min, var)) +end + +ExpWeightedWelfordVar(sz::Union{Tuple{Int},Tuple{Int,Int}}; kwargs...) = ExpWeightedWelfordVar{Float64}(sz; kwargs...) + +function Base.resize!(wv::ExpWeightedWelfordVar, θ::AbstractVecOrMat{T}, ∇logπ::AbstractVecOrMat{T}) where {T<:AbstractFloat} + @assert size(θ) == size(∇logπ) "Size of draw and grad must be the same." + resize!(wv.exp_variance_draw, θ) + resize!(wv.exp_variance_grad, ∇logπ) + resize!(wv.exp_variance_draw_bg, θ) + resize!(wv.exp_variance_grad_bg, ∇logπ) +end + +function reset!(wv::ExpWeightedWelfordVar{T}) where {T<:AbstractFloat} + reset!(wv.exp_variance_draw) + reset!(wv.exp_variance_grad) + reset!(wv.exp_variance_draw_bg) + reset!(wv.exp_variance_grad_bg) +end + +function Base.push!(wv::ExpWeightedWelfordVar, θ::AbstractVecOrMat{T}, ∇logπ::AbstractVecOrMat{T}) where {T} + @assert size(θ) == size(∇logπ) "Size of draw and grad must be the same." 
+ push!(wv.exp_variance_draw, θ) + push!(wv.exp_variance_grad, ∇logπ) + push!(wv.exp_variance_draw_bg, θ) + push!(wv.exp_variance_grad_bg, ∇logπ) +end + +# swap the background and foreground estimators for both _draw and _grad variance +# unlike the Rust implementation, we don't update the estimators inside of the switch as well (called separately) +function switch!(wv::ExpWeightedWelfordVar) + wv.exp_variance_draw = wv.exp_variance_draw_bg + reset!(wv.exp_variance_draw_bg) + wv.exp_variance_grad = wv.exp_variance_grad_bg + reset!(wv.exp_variance_grad_bg) +end +current_count(wv) = wv.exp_variance_draw.n +background_count(wv) = wv.exp_variance_draw_bg.n + +function adapt!( + adaptor::ExpWeightedWelfordVar, + θ::AbstractVecOrMat{<:AbstractFloat}, + α::AbstractScalarOrVec{<:AbstractFloat}, + ∇logπ::AbstractVecOrMat{<:AbstractFloat}, + is_update::Bool=true +) + resize!(adaptor, θ, ∇logπ) + push!(adaptor, θ, ∇logπ) + is_update && update!(adaptor) +end + +# TODO: handle NaN +function get_estimation(ad::ExpWeightedWelfordVar{T}) where {T<:AbstractFloat} + var_draw = get_estimation(ad.exp_variance_draw) + var_grad = get_estimation(ad.exp_variance_grad) + # mimics: let val = (draw / grad).sqrt().clamp(LOWER_LIMIT, UPPER_LIMIT); + var = (var_draw ./ var_grad) .|> sqrt .|> x -> clamp(x, LOWER_LIMIT, UPPER_LIMIT) + # re-use the last estimate `var` if the current estimate is not valid + return all(isfinite.(var)) ? 
var : ad.exp_variance_draw.var +end +# reuse the `var` slot in the `exp_variance_draw` (which is `WelfordVar`) +# to store the estimated variance of the draw (the "current" / "foreground" one) +function update!(ad::ExpWeightedWelfordVar) + current_count(ad) >= ad.exp_variance_draw.n_min && (ad.exp_variance_draw.var .= get_estimation(ad)) +end + + ## Dense mass matrix adaptor abstract type DenseMatrixEstimator{T} <: MassMatrixAdaptor end @@ -179,8 +269,8 @@ Base.show(io::IO, ::WelfordCov) = print(io, "WelfordCov") function WelfordCov{T}( sz::Tuple{Int}; - n_min::Int = 10, - cov = LinearAlgebra.diagm(0 => ones(T, first(sz))), + n_min::Int=10, + cov=LinearAlgebra.diagm(0 => ones(T, first(sz))) ) where {T<:AbstractFloat} d = first(sz) return WelfordCov(0, n_min, zeros(T, d), zeros(T, d, d), zeros(T, d), cov) diff --git a/src/adaptation/nutpie_adaptor.jl b/src/adaptation/nutpie_adaptor.jl new file mode 100644 index 00000000..225bc558 --- /dev/null +++ b/src/adaptation/nutpie_adaptor.jl @@ -0,0 +1,264 @@ +####################################3 +### General methods +# it will then be forwarded to `adaptor` (there it resides in `exp_variance_draw.var`) +getM⁻¹(ca::AbstractHMCAdaptorWithGradients) = getM⁻¹(ca.pc) +getϵ(ca::AbstractHMCAdaptorWithGradients) = getϵ(ca.ssa) +finalize!(adaptor::AbstractHMCAdaptorWithGradients) = finalize!(adaptor.ssa) + +### Mutable states +mutable struct NutpieHMCAdaptorState + i::Int + n_adapts::Int + # The number of draws in the the early window + early_end::Int + # The first draw number for the final step size adaptation window + final_step_size_window::Int + + function NutpieHMCAdaptorState(i, n_adapts, early_end, final_step_size_window) + @assert (early_end < n_adapts) "Early_end must be less than num_tune (provided $early_end and $n_adapts)" + return new(i, n_adapts, early_end, final_step_size_window) + end +end +function NutpieHMCAdaptorState() + return NutpieHMCAdaptorState(0, 1000, 300, 800) +end + +function 
initialize!(state::NutpieHMCAdaptorState, early_window_share::Float64, + final_step_size_window_share::Float64, + n_adapts::Int) + + early_end = ceil(UInt64, early_window_share * n_adapts) + step_size_window = ceil(UInt64, final_step_size_window_share * n_adapts) + final_step_size_window = max(n_adapts - step_size_window, 0) + 1 + + state.early_end = early_end + state.n_adapts = n_adapts + state.final_step_size_window = final_step_size_window +end + +# function Base.show(io::IO, state::NutpieHMCAdaptorState) +# print(io, "window($(state.window_start), $(state.window_end)), window_splits(" * string(join(state.window_splits, ", ")) * ")") +# end + +### Nutpie's adaptation +# Acknowledgement: ... +struct NutpieHMCAdaptor{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractHMCAdaptorWithGradients + pc::M + ssa::Tssa + early_window_share::Float64 + final_step_size_window_share::Float64 + mass_matrix_switch_freq::Int + early_mass_matrix_switch_freq::Int + state::NutpieHMCAdaptorState +end +# Base.show(io::IO, a::NutpieHMCAdaptor) = +# print(io, "NutpieHMCAdaptor(\n pc=$(a.pc),\n ssa=$(a.ssa),\n init_buffer=$(a.init_buffer), term_buffer=$(a.term_buffer), window_size=$(a.window_size),\n state=$(a.state)\n)") + +function NutpieHMCAdaptor( + pc::ExpWeightedWelfordVar, + ssa::StepSizeAdaptor; + early_window_share::Float64=0.3, + final_step_size_window_share::Float64=0.2, + mass_matrix_switch_freq::Int=60, + early_mass_matrix_switch_freq::Int=10 +) + return NutpieHMCAdaptor(pc, ssa, early_window_share, final_step_size_window_share, mass_matrix_switch_freq, early_mass_matrix_switch_freq, NutpieHMCAdaptorState()) +end + +function initialize!(adaptor::NutpieHMCAdaptor, n_adapts::Int, ∇logπ::AbstractVecOrMat{<:AbstractFloat}) + initialize!(adaptor.state, adaptor.early_window_share, adaptor.final_step_size_window_share, n_adapts) + # !Q: Shall we initialize from the gradient? 
+ # Nutpie initializes the variance estimate with reciprocal of the gradient + # Like: Some((grad).abs().recip().clamp(LOWER_LIMIT, UPPER_LIMIT)) + # TODO: point to var more dynamically + adaptor.pc.exp_variance_draw.var = (1 ./ abs.(∇logπ)) |> x -> clamp.(x, LOWER_LIMIT, UPPER_LIMIT) + return adaptor +end + +#################################### +## Special case: Skip the initiation of the mass matrix with gradient +struct NutpieHMCAdaptorNoGradInit{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractHMCAdaptorWithGradients + pc::M + ssa::Tssa + early_window_share::Float64 + final_step_size_window_share::Float64 + mass_matrix_switch_freq::Int + early_mass_matrix_switch_freq::Int + state::NutpieHMCAdaptorState +end +# Base.show(io::IO, a::NutpieHMCAdaptor) = +# print(io, "NutpieHMCAdaptor(\n pc=$(a.pc),\n ssa=$(a.ssa),\n init_buffer=$(a.init_buffer), term_buffer=$(a.term_buffer), window_size=$(a.window_size),\n state=$(a.state)\n)") + +function NutpieHMCAdaptorNoGradInit( + pc::ExpWeightedWelfordVar, + ssa::StepSizeAdaptor; + early_window_share::Float64=0.3, + final_step_size_window_share::Float64=0.2, + mass_matrix_switch_freq::Int=60, + early_mass_matrix_switch_freq::Int=10 +) + return NutpieHMCAdaptorNoGradInit(pc, ssa, early_window_share, final_step_size_window_share, mass_matrix_switch_freq, early_mass_matrix_switch_freq, NutpieHMCAdaptorState()) +end +function initialize!(adaptor::NutpieHMCAdaptorNoGradInit, n_adapts::Int, ∇logπ::AbstractVecOrMat{<:AbstractFloat}) + initialize!(adaptor.state, adaptor.early_window_share, adaptor.final_step_size_window_share, n_adapts) + return adaptor +end +#################################### +## Special case: No switching, use StanHMCAdaptor-like strategy (but keep var+gradients) +struct NutpieHMCAdaptorNoSwitch{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractHMCAdaptorWithGradients + pc::M + ssa::Tssa + init_buffer::Int + term_buffer::Int + window_size::Int + state::StanHMCAdaptorState +end + +function 
NutpieHMCAdaptorNoSwitch( + pc::ExpWeightedWelfordVar, + ssa::StepSizeAdaptor; + init_buffer::Int = 75, + term_buffer::Int = 50, + window_size::Int = 25, +) + return NutpieHMCAdaptorNoSwitch( + pc, + ssa, + init_buffer, + term_buffer, + window_size, + StanHMCAdaptorState(), + ) +end + +function initialize!(adaptor::NutpieHMCAdaptorNoSwitch, n_adapts::Int,∇logπ::AbstractVecOrMat{<:AbstractFloat}) + initialize!( + adaptor.state, + adaptor.init_buffer, + adaptor.term_buffer, + adaptor.window_size, + n_adapts, + ) + adaptor.pc.exp_variance_draw.var = (1 ./ abs.(∇logπ)) |> x -> clamp.(x, LOWER_LIMIT, UPPER_LIMIT) + return adaptor +end + +############################################ +## Special case: No switching, use StanHMCAdaptor-like strategy (but keep var+gradients) +## Both switching and grad init disabled +struct NutpieHMCAdaptorNoSwitchNoGradInit{M<:MassMatrixAdaptor,Tssa<:StepSizeAdaptor} <: AbstractHMCAdaptorWithGradients + pc::M + ssa::Tssa + init_buffer::Int + term_buffer::Int + window_size::Int + state::StanHMCAdaptorState +end + +function NutpieHMCAdaptorNoSwitchNoGradInit( + pc::ExpWeightedWelfordVar, + ssa::StepSizeAdaptor; + init_buffer::Int = 75, + term_buffer::Int = 50, + window_size::Int = 25, +) + return NutpieHMCAdaptorNoSwitchNoGradInit( + pc, + ssa, + init_buffer, + term_buffer, + window_size, + StanHMCAdaptorState(), + ) +end + +function initialize!(adaptor::NutpieHMCAdaptorNoSwitchNoGradInit, n_adapts::Int,∇logπ::AbstractVecOrMat{<:AbstractFloat}) + initialize!( + adaptor.state, + adaptor.init_buffer, + adaptor.term_buffer, + adaptor.window_size, + n_adapts, + ) + return adaptor +end + + +##################################### +# Adaptation: main case ala Nutpie +# +# Changes vs Rust implementation +# - step_size_adapt is at the top +# - several checks are handled in sampler (finalize adaptation, does not adapt during normal sampling) +# - switch and push/update are handled separately to mimic the StanHMCAdaptor +# +# Missing: +# - collector 
checks on divergences or terminating at idx=0 // finite and None estimator + - adapt! estimator only if collected stuff is good (divergences) + - init for mass_matrix is grad.abs().recip.clamp(LOWER_LIMIT, UPPER_LIMIT) // init of ExpWindowDiagAdapt + +is_in_first_step_size_window(tp::AbstractHMCAdaptorWithGradients) = tp.state.i <= tp.state.final_step_size_window +is_in_early_window(tp::AbstractHMCAdaptorWithGradients) = tp.state.i <= tp.state.early_end +switch_freq(tp::AbstractHMCAdaptorWithGradients) = is_in_early_window(tp) ? tp.early_mass_matrix_switch_freq : tp.mass_matrix_switch_freq +# +function adapt!( + tp::Union{NutpieHMCAdaptor,NutpieHMCAdaptorNoGradInit}, + θ::AbstractVecOrMat{<:AbstractFloat}, + α::AbstractScalarOrVec{<:AbstractFloat}, + ∇logπ::AbstractVecOrMat{<:AbstractFloat} +) + tp.state.i += 1 + + adapt!(tp.ssa, θ, α) + + # TODO: do we resize twice? also in update? + # !Q: Why do we check resizing several times during iteration? (also in adapt!) + resize!(tp.pc, θ, ∇logπ) # Resize pre-conditioner if necessary. + + # determine whether to update mass matrix + if is_in_first_step_size_window(tp) + + # Switch swaps the background (_bg) values for current, and resets the background values + # Frequency of the switch depends on the phase + background_count(tp.pc) >= switch_freq(tp) && switch!(tp.pc) + + # TODO: implement a skipper for bad draws + # !Q: Why does it always update? 
(as per Nuts-rs/Nutpie) + adapt!(tp.pc, θ, α, ∇logπ, true) + end +end + +##################################### +# Adaptation: No switching - ala StanHMCAdaptor +# +is_in_window(tp::Union{NutpieHMCAdaptorNoSwitch,NutpieHMCAdaptorNoSwitchNoGradInit}) = + tp.state.i >= tp.state.window_start && tp.state.i <= tp.state.window_end +is_window_end(tp::Union{NutpieHMCAdaptorNoSwitch,NutpieHMCAdaptorNoSwitchNoGradInit}) = tp.state.i in tp.state.window_splits +# +function adapt!( + tp::Union{NutpieHMCAdaptorNoSwitch,NutpieHMCAdaptorNoSwitchNoGradInit}, + θ::AbstractVecOrMat{<:AbstractFloat}, + α::AbstractScalarOrVec{<:AbstractFloat}, + ∇logπ::AbstractVecOrMat{<:AbstractFloat} +) + tp.state.i += 1 + + adapt!(tp.ssa, θ, α) + + resize!(tp.pc, θ, ∇logπ) # Resize pre-conditioner if necessary. + + # Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp + if is_in_window(tp) + # We accumulate stats from θ online and only trigger the update of M⁻¹ in the end of window. 
+ is_update_M⁻¹ = is_window_end(tp) + adapt!(tp.pc, θ, α, ∇logπ, is_update_M⁻¹) + end + + if is_window_end(tp) + reset!(tp.ssa) + reset!(tp.pc) + end +end + + + diff --git a/src/sampler.jl b/src/sampler.jl index 7d1b7eb5..b0bf1b43 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -75,9 +75,11 @@ function Adaptation.adapt!( adaptor::AbstractAdaptor, i::Int, n_adapts::Int, - θ::AbstractVecOrMat{<:AbstractFloat}, + z::PhasePoint, + # θ::AbstractVecOrMat{<:AbstractFloat}, α::AbstractScalarOrVec{<:AbstractFloat}, ) + θ = z.θ isadapted = false if i <= n_adapts i == 1 && Adaptation.initialize!(adaptor, n_adapts) @@ -90,11 +92,36 @@ function Adaptation.adapt!( return h, κ, isadapted end +# Nutpie adaptor requires access to gradients in the Hamiltonian +function Adaptation.adapt!( + h::Hamiltonian, + κ::AbstractMCMCKernel, + adaptor::AbstractHMCAdaptorWithGradients, + i::Int, + n_adapts::Int, + z::PhasePoint, + # θ::AbstractVecOrMat{<:AbstractFloat}, + α::AbstractScalarOrVec{<:AbstractFloat}, +) + θ = z.θ + ∇logπ = z.ℓπ.gradient + isadapted = false + if i <= n_adapts + i == 1 && Adaptation.initialize!(adaptor, n_adapts,∇logπ) + adapt!(adaptor, θ, α,∇logπ) + i == n_adapts && finalize!(adaptor) + h = update(h, adaptor) + κ = update(κ, adaptor) + isadapted = true + end + return h, κ, isadapted +end + """ Progress meter update with all trajectory stats, iteration number and metric shown. """ function pm_next!(pm, stat::NamedTuple) - ProgressMeter.next!(pm; showvalues = [tuple(s...) for s in pairs(stat)]) + ProgressMeter.next!(pm; showvalues=[tuple(s...) 
for s in pairs(stat)]) end """ @@ -111,12 +138,12 @@ sample( κ::AbstractMCMCKernel, θ::AbstractVecOrMat{<:AbstractFloat}, n_samples::Int, - adaptor::AbstractAdaptor = NoAdaptation(), - n_adapts::Int = min(div(n_samples, 10), 1_000); - drop_warmup = false, - verbose::Bool = true, - progress::Bool = false, - (pm_next!)::Function = pm_next!, + adaptor::AbstractAdaptor=NoAdaptation(), + n_adapts::Int=min(div(n_samples, 10), 1_000); + drop_warmup=false, + verbose::Bool=true, + progress::Bool=false, + (pm_next!)::Function=pm_next! ) = sample( GLOBAL_RNG, h, @@ -125,10 +152,10 @@ sample( n_samples, adaptor, n_adapts; - drop_warmup = drop_warmup, - verbose = verbose, - progress = progress, - (pm_next!) = pm_next!, + drop_warmup=drop_warmup, + verbose=verbose, + progress=progress, + (pm_next!)=pm_next! ) """ @@ -160,12 +187,12 @@ function sample( κ::HMCKernel, θ::T, n_samples::Int, - adaptor::AbstractAdaptor = NoAdaptation(), - n_adapts::Int = min(div(n_samples, 10), 1_000); - drop_warmup = false, - verbose::Bool = true, - progress::Bool = false, - (pm_next!)::Function = pm_next!, + adaptor::AbstractAdaptor=NoAdaptation(), + n_adapts::Int=min(div(n_samples, 10), 1_000); + drop_warmup=false, + verbose::Bool=true, + progress::Bool=false, + (pm_next!)::Function=pm_next! ) where {T<:AbstractVecOrMat{<:AbstractFloat}} @assert !(drop_warmup && (adaptor isa Adaptation.NoAdaptation)) "Cannot drop warmup samples if there is no adaptation phase." # Prepare containers to store sampling results @@ -177,7 +204,7 @@ function sample( h, t = sample_init(rng, h, θ) # Progress meter pm = - progress ? ProgressMeter.Progress(n_samples, desc = "Sampling", barlen = 31) : + progress ? 
ProgressMeter.Progress(n_samples, desc="Sampling", barlen=31) : nothing time = @elapsed for i = 1:n_samples # Make a transition @@ -185,13 +212,13 @@ function sample( # Adapt h and κ; what mutable is the adaptor tstat = stat(t) h, κ, isadapted = - adapt!(h, κ, adaptor, i, n_adapts, t.z.θ, tstat.acceptance_rate) + adapt!(h, κ, adaptor, i, n_adapts, t.z, tstat.acceptance_rate) if isadapted num_divergent_transitions_during_adaption += tstat.numerical_error else num_divergent_transitions += tstat.numerical_error end - tstat = merge(tstat, (is_adapt = isadapted,)) + tstat = merge(tstat, (is_adapt=isadapted,)) # Update progress meter if progress percentage_divergent_transitions = num_divergent_transitions / i @@ -205,17 +232,17 @@ function sample( pm_next!( pm, ( - iterations = i, - ratio_divergent_transitions = round( + iterations=i, + ratio_divergent_transitions=round( percentage_divergent_transitions; - digits = 2, + digits=2 ), - ratio_divergent_transitions_during_adaption = round( + ratio_divergent_transitions_during_adaption=round( percentage_divergent_transitions_during_adaption; - digits = 2, + digits=2 ), tstat..., - mass_matrix = h.metric, + mass_matrix=h.metric, ), ) # Report finish of adapation